Medial Code Documentation
Loading...
Searching...
No Matches
allreduce_base.h
Go to the documentation of this file.
1
12#ifndef RABIT_ALLREDUCE_BASE_H_
13#define RABIT_ALLREDUCE_BASE_H_
14
15#include <functional>
16#include <future>
17#include <vector>
18#include <string>
19#include <algorithm>
23
24#ifdef RABIT_CXXTESTDEFS_H
25#define private public
26#define protected public
27#endif // RABIT_CXXTESTDEFS_H
28
29
namespace MPI {  // NOLINT
// Minimal stand-in for MPI::Datatype so reducers keep the MPI-compatible
// signature used by the existing MPI interface.
class Datatype {
 public:
  // element size handed to reduce functions (presumably bytes -- confirm at call sites)
  size_t type_size;
  // explicit: a bare size must not silently convert into a Datatype
  explicit Datatype(size_t size) : type_size{size} {}
};
}  // namespace MPI
38namespace rabit {
39namespace engine {
40
42class AllreduceBase : public IEngine {
43 public:
44 // magic number to verify server
45 static const int kMagic = 0xff99;
46 // constant one byte out of band message to indicate error happening
48 virtual ~AllreduceBase() = default;
49 // initialize the manager
50 virtual bool Init(int argc, char* argv[]);
51 // shutdown the engine
52 virtual bool Shutdown();
58 virtual void SetParam(const char *name, const char *val);
65 void TrackerPrint(const std::string &msg) override;
66
68 int GetRingPrevRank() const override {
69 return ring_prev->rank;
70 }
72 int GetRank() const override {
73 return rank;
74 }
76 int GetWorldSize() const override {
77 if (world_size == -1) return 1;
78 return world_size;
79 }
81 bool IsDistributed() const override {
82 return tracker_uri != "NULL";
83 }
85 std::string GetHost() const override {
86 return host_uri;
87 }
88
102 void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin,
103 size_t slice_end, size_t size_prev_slice) override {
104 if (world_size == 1 || world_size == -1) {
105 return;
106 }
107 utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size, slice_begin,
108 slice_end, size_prev_slice) == kSuccess,
109 "AllgatherRing failed");
110 }
123 void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
124 ReduceFunction reducer, PreprocFunction prepare_fun = nullptr,
125 void *prepare_arg = nullptr) override {
126 if (prepare_fun != nullptr) prepare_fun(prepare_arg);
127 if (world_size == 1 || world_size == -1) return;
128 utils::Assert(TryAllreduce(sendrecvbuf_, type_nbytes, count, reducer) ==
129 kSuccess,
130 "Allreduce failed");
131 }
141 void Broadcast(void *sendrecvbuf_, size_t total_size, int root) override {
142 if (world_size == 1 || world_size == -1) return;
143 utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess,
144 "Broadcast failed");
145 }
150 int LoadCheckPoint() override { return 0; }
151
152 // deprecated, increase internal version number
153 void CheckPoint() override { version_number += 1; }
159 int VersionNumber() const override {
160 return version_number;
161 }
166 inline void ReportStatus() const {
167 if (hadoop_mode != 0) {
168 LOG(CONSOLE) << "reporter:status:Rabit Phase[" << version_number << "] Operation " << seq_counter << "\n";
169 }
170 }
171
172 protected:
190 struct ReturnType {
193 // constructor
194 ReturnType() = default;
195 ReturnType(ReturnTypeEnum value) : value(value) {} // NOLINT(*)
196 inline bool operator==(const ReturnTypeEnum &v) const {
197 return value == v;
198 }
199 inline bool operator!=(const ReturnTypeEnum &v) const {
200 return value != v;
201 }
202 };
205 int errsv = xgboost::system::LastError();
206 if (errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == 0) return kSuccess;
207#ifdef _WIN32
208 if (errsv == WSAEWOULDBLOCK) return kSuccess;
209 if (errsv == WSAECONNRESET) return kConnReset;
210#endif // _WIN32
211 if (errsv == ECONNRESET) return kConnReset;
212 return kSockError;
213 }
214 // link record to a neighbor
215 struct LinkRecord {
216 public:
217 // socket to get data from/to link
219 // rank of the node in this link
220 int rank;
221 // size of data read from link
222 size_t size_read;
223 // size of data sent to the link
224 size_t size_write;
225 // pointer to buffer head
226 char *buffer_head {nullptr};
227 // buffer size, in bytes
228 size_t buffer_size {0};
229 // constructor
230 LinkRecord() = default;
231 // initialize buffer
232 void InitBuffer(size_t type_nbytes, size_t count,
233 size_t reduce_buffer_size) {
234 size_t n = (type_nbytes * count + 7)/ 8;
235 auto to = Min(reduce_buffer_size, n);
236 buffer_.resize(to);
237 // make sure align to type_nbytes
238 buffer_size =
239 buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
240 utils::Assert(type_nbytes <= buffer_size,
241 "too large type_nbytes=%lu, buffer_size=%lu",
242 type_nbytes, buffer_size);
243 // set buffer head
244 buffer_head = reinterpret_cast<char*>(BeginPtr(buffer_));
245 }
246 // reset the recv and sent size
247 inline void ResetSize() {
248 size_write = size_read = 0;
249 }
258 inline ReturnType ReadToRingBuffer(size_t protect_start, size_t max_size_read) {
259 utils::Assert(buffer_head != nullptr, "ReadToRingBuffer: buffer not allocated");
260 utils::Assert(size_read <= max_size_read, "ReadToRingBuffer: max_size_read check");
261 size_t ngap = size_read - protect_start;
262 utils::Assert(ngap <= buffer_size, "Allreduce: boundary check");
263 size_t offset = size_read % buffer_size;
264 size_t nmax = max_size_read - size_read;
265 nmax = Min(nmax, buffer_size - ngap);
266 nmax = Min(nmax, buffer_size - offset);
267 if (nmax == 0) return kSuccess;
268 ssize_t len = sock.Recv(buffer_head + offset, nmax);
269 // length equals 0, remote disconnected
270 if (len == 0) {
271 sock.Close(); return kRecvZeroLen;
272 }
273 if (len == -1) return Errno2Return();
274 size_read += static_cast<size_t>(len);
275 return kSuccess;
276 }
284 inline ReturnType ReadToArray(void *recvbuf_, size_t max_size) {
285 if (max_size == size_read) return kSuccess;
286 char *p = static_cast<char*>(recvbuf_);
287 ssize_t len = sock.Recv(p + size_read, max_size - size_read);
288 // length equals 0, remote disconnected
289 if (len == 0) {
290 sock.Close(); return kRecvZeroLen;
291 }
292 if (len == -1) return Errno2Return();
293 size_read += static_cast<size_t>(len);
294 return kSuccess;
295 }
302 inline ReturnType WriteFromArray(const void *sendbuf_, size_t max_size) {
303 const char *p = static_cast<const char*>(sendbuf_);
304 ssize_t len = sock.Send(p + size_write, max_size - size_write);
305 if (len == -1) return Errno2Return();
306 size_write += static_cast<size_t>(len);
307 return kSuccess;
308 }
309
310 private:
311 // recv buffer to get data from child
312 // aligned with 64 bits, will be able to perform 64 bits operations freely
313 std::vector<uint64_t> buffer_;
314 };
320 std::vector<LinkRecord*> plinks;
321 inline LinkRecord &operator[](size_t i) {
322 return *plinks[i];
323 }
324 inline size_t Size() const {
325 return plinks.size();
326 }
327 };
338 bool ReConnectLinks(const char *cmd = "start");
354 ReturnType TryAllreduce(void *sendrecvbuf_,
355 size_t type_nbytes,
356 size_t count,
357 ReduceFunction reducer);
366 ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root);
378 ReturnType TryAllreduceTree(void *sendrecvbuf_,
379 size_t type_nbytes,
380 size_t count,
381 ReduceFunction reducer);
397 ReturnType TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
398 size_t slice_begin, size_t slice_end,
399 size_t size_prev_slice);
414 ReturnType TryReduceScatterRing(void *sendrecvbuf_,
415 size_t type_nbytes,
416 size_t count,
417 ReduceFunction reducer);
429 ReturnType TryAllreduceRing(void *sendrecvbuf_,
430 size_t type_nbytes,
431 size_t count,
432 ReduceFunction reducer);
439 err_link = link; return err;
440 }
441 //---- data structure related to model ----
442 // call sequence counter, records how many calls we made so far
443 // from last call to CheckPoint, LoadCheckPoint
444 int seq_counter{0}; // NOLINT
445 // version number of model
446 int version_number {0}; // NOLINT
447 // whether the job is running in Hadoop
448 bool hadoop_mode; // NOLINT
449 //---- local data related to link ----
450 // index of parent link, can be -1, meaning this is root of the tree
451 int parent_index; // NOLINT
452 // rank of parent node, can be -1
453 int parent_rank; // NOLINT
454 // sockets of all links this connects to
455 std::vector<LinkRecord> all_links; // NOLINT
456 // used to record the link where things go wrong
457 LinkRecord *err_link; // NOLINT
458 // all the links in the reduction tree connection
459 RefLinkVector tree_links; // NOLINT
460 // pointer to links in the ring
461 LinkRecord *ring_prev, *ring_next; // NOLINT
462 //----- meta information-----
463 // list of environment variables that are of possible interest
464 std::vector<std::string> env_vars; // NOLINT
465 // unique identifier of the possible job this process is doing
466 // used to assign ranks, optional, default to NULL
467 std::string task_id; // NOLINT
468 // uri of current host, to be set by Init
469 std::string host_uri; // NOLINT
470 // uri of tracker
471 std::string tracker_uri; // NOLINT
472 // role in dmlc jobs
473 std::string dmlc_role; // NOLINT
474 // port of tracker address
475 int tracker_port; // NOLINT
476 // reduce buffer size
477 size_t reduce_buffer_size; // NOLINT
478 // reduction method
479 int reduce_method; // NOLINT
480 // minimum count of cells to use ring based method
481 size_t reduce_ring_mincount; // NOLINT
482 // minimum block size per tree reduce
483 size_t tree_reduce_minsize; // NOLINT
484 // current rank
485 int rank; // NOLINT
486 // world size
487 int world_size; // NOLINT
488 // connect retry time
489 int connect_retry; // NOLINT
490 // by default, exit if a rabit worker does not recover within half an hour
491 std::chrono::seconds timeout_sec{std::chrono::seconds{1800}}; // NOLINT
492 // flag to enable rabit_timeout
493 bool rabit_timeout = false; // NOLINT
494 // Enable TCP node delay
495 bool rabit_enable_tcp_no_delay = false; // NOLINT
496};
497} // namespace engine
498} // namespace rabit
499#endif // RABIT_ALLREDUCE_BASE_H_
Definition allreduce_base.h:32
implementation of basic Allreduce engine
Definition allreduce_base.h:42
int GetWorldSize() const override
get world size
Definition allreduce_base.h:76
ReturnType TryReduceScatterRing(void *sendrecvbuf_, size_t type_nbytes, size_t count, ReduceFunction reducer)
perform in-place allreduce, reduce on the sendrecvbuf,
Definition allreduce_base.cc:847
ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root)
broadcast data from root to all nodes; this function can fail, and will return the cause of failure
Definition allreduce_base.cc:668
void ReportStatus() const
report current status to the job tracker depending on the job tracker we are in
Definition allreduce_base.h:166
void TrackerPrint(const std::string &msg) override
print the msg in the tracker, this function can be used to communicate the information of the progres...
Definition allreduce_base.cc:145
ReturnType TryAllreduceTree(void *sendrecvbuf_, size_t type_nbytes, size_t count, ReduceFunction reducer)
perform in-place allreduce on sendrecvbuf; this function implements tree-shaped reduction
Definition allreduce_base.cc:473
void CheckPoint() override
Increase internal version number. Deprecated.
Definition allreduce_base.h:153
ReturnType ReportError(LinkRecord *link, ReturnType err)
function used to report error when a link goes wrong
Definition allreduce_base.h:438
ReturnType TryAllgatherRing(void *sendrecvbuf_, size_t total_size, size_t slice_begin, size_t slice_end, size_t size_prev_slice)
internal Allgather function, each node has a segment of data in the ring of sendrecvbuf,...
Definition allreduce_base.cc:763
ReturnType TryAllreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count, ReduceFunction reducer)
perform in-place allreduce on sendrecvbuf; this function can fail, and will return the cause of fail...
Definition allreduce_base.cc:451
int VersionNumber() const override
Definition allreduce_base.h:159
std::string GetHost() const override
get the URI of the current host
Definition allreduce_base.h:85
int GetRingPrevRank() const override
get rank of previous node in ring topology
Definition allreduce_base.h:68
virtual void SetParam(const char *name, const char *val)
set parameters to the engine
Definition allreduce_base.cc:182
xgboost::collective::TCPSocket ConnectTracker() const
initialize connection to the tracker
Definition allreduce_base.cc:222
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count, ReduceFunction reducer, PreprocFunction prepare_fun=nullptr, void *prepare_arg=nullptr) override
perform in-place allreduce on sendrecvbuf; this function is NOT thread-safe
Definition allreduce_base.h:123
int GetRank() const override
get rank
Definition allreduce_base.h:72
void Broadcast(void *sendrecvbuf_, size_t total_size, int root) override
broadcast data from root to all nodes
Definition allreduce_base.h:141
ReturnType TryAllreduceRing(void *sendrecvbuf_, size_t type_nbytes, size_t count, ReduceFunction reducer)
perform in-place allreduce on sendrecvbuf using a ring-based algorithm (reduce-scatter + allgather)
Definition allreduce_base.cc:948
static ReturnType Errno2Return()
translate errno to return type
Definition allreduce_base.h:204
int LoadCheckPoint() override
deprecated
Definition allreduce_base.h:150
bool ReConnectLinks(const char *cmd="start")
connect to the tracker to fix the missing links; this function is also used when the engine starts ...
Definition allreduce_base.cc:263
bool IsDistributed() const override
whether is distributed or not
Definition allreduce_base.h:81
void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin, size_t slice_end, size_t size_prev_slice) override
internal Allgather function, each node has a segment of data in the ring of sendrecvbuf,...
Definition allreduce_base.h:102
ReturnTypeEnum
enumeration of possible returning results from Try functions
Definition allreduce_base.h:174
@ kRecvZeroLen
received a zero length message
Definition allreduce_base.h:180
@ kGetExcept
another node which is not my neighbor went down; got an out-of-band exception notification from my neighbo...
Definition allreduce_base.h:187
@ kSockError
a neighbor node went down and the connection was dropped
Definition allreduce_base.h:182
@ kConnReset
a link was reset by peer
Definition allreduce_base.h:178
@ kSuccess
execution is successful
Definition allreduce_base.h:176
interface of core Allreduce engine
Definition engine.h:22
void() PreprocFunction(void *arg)
Preprocessing function, that is called before AllReduce, used to prepare the data used by AllReduce.
Definition engine.h:29
void() ReduceFunction(const void *src, void *dst, int count, const MPI::Datatype &dtype)
reduce function, the same form of MPI reduce function is used, to be compatible with MPI interface In...
Definition engine.h:41
TCP socket for simple communication.
Definition socket.h:249
void Close()
Close the socket, called automatically in destructor if the socket is not closed.
Definition socket.h:504
auto Recv(void *buf, std::size_t len, std::int32_t flags=0)
receive data using the socket
Definition socket.h:489
auto Send(const void *buf_, std::size_t len, std::int32_t flags=0)
Send data using the socket.
Definition socket.h:478
This file defines the core interface of rabit library.
void Assert(bool exp, const char *fmt,...)
assert a condition is true, use this to handle debug information
Definition utils.h:79
namespace of rabit
Definition engine.h:18
Definition allreduce_base.h:215
ReturnType ReadToArray(void *recvbuf_, size_t max_size)
read data into an array; this function cannot be used together with ReadToRingBuffer: a link can either ...
Definition allreduce_base.h:284
ReturnType WriteFromArray(const void *sendbuf_, size_t max_size)
write data in array to sock
Definition allreduce_base.h:302
ReturnType ReadToRingBuffer(size_t protect_start, size_t max_size_read)
read data into the ring buffer, taking care not to overwrite existing useful data at positions after protect_sta...
Definition allreduce_base.h:258
simple data structure that works like a vector but takes reference instead of space
Definition allreduce_base.h:319
struct return type to avoid implicit conversion to int/bool
Definition allreduce_base.h:190
ReturnTypeEnum value
internal return type
Definition allreduce_base.h:192
simple utils to support the code