Medial Code Documentation
Loading...
Searching...
No Matches
rabit-inl.h
Go to the documentation of this file.
1
8#ifndef RABIT_INTERNAL_RABIT_INL_H_
9#define RABIT_INTERNAL_RABIT_INL_H_
10// use engine for implementation
11#include <vector>
12#include <string>
13#include "rabit/internal/io.h"
15#include "rabit/rabit.h"
16
17namespace rabit {
18namespace engine {
19namespace mpi {
20// template function to translate type to enum indicator
21template<typename DType>
22inline DataType GetType();
23template<>
24inline DataType GetType<char>() {
25 return kChar;
26}
27template<>
28inline DataType GetType<unsigned char>() {
29 return kUChar;
30}
31template<>
32inline DataType GetType<int>() {
33 return kInt;
34}
35template<>
36inline DataType GetType<unsigned int>() { // NOLINT(*)
37 return kUInt;
38}
39template<>
40inline DataType GetType<long>() { // NOLINT(*)
41 return kLong;
42}
43template<>
44inline DataType GetType<unsigned long>() { // NOLINT(*)
45 return kULong;
46}
47template<>
48inline DataType GetType<float>() {
49 return kFloat;
50}
51template<>
52inline DataType GetType<double>() {
53 return kDouble;
54}
55template<>
56inline DataType GetType<long long>() { // NOLINT(*)
57 return kLongLong;
58}
59template<>
60inline DataType GetType<unsigned long long>() { // NOLINT(*)
61 return kULongLong;
62}
63} // namespace mpi
64} // namespace engine
65
66namespace op {
67struct Max {
68 static const engine::mpi::OpType kType = engine::mpi::kMax;
69 template<typename DType>
70 inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
71 if (dst < src) dst = src;
72 }
73};
74struct Min {
75 static const engine::mpi::OpType kType = engine::mpi::kMin;
76 template<typename DType>
77 inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
78 if (dst > src) dst = src;
79 }
80};
81struct Sum {
82 static const engine::mpi::OpType kType = engine::mpi::kSum;
83 template<typename DType>
84 inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
85 dst += src;
86 }
87};
88struct BitAND {
89 static const engine::mpi::OpType kType = engine::mpi::kBitwiseAND;
90 template<typename DType>
91 inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
92 dst &= src;
93 }
94};
95struct BitOR {
96 static const engine::mpi::OpType kType = engine::mpi::kBitwiseOR;
97 template<typename DType>
98 inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
99 dst |= src;
100 }
101};
102struct BitXOR {
103 static const engine::mpi::OpType kType = engine::mpi::kBitwiseXOR;
104 template<typename DType>
105 inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
106 dst ^= src;
107 }
108};
109template <typename OP, typename DType>
110inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &) {
111 const DType *src = static_cast<const DType *>(src_);
112 DType *dst = (DType *)dst_; // NOLINT(*)
113 for (int i = 0; i < len; i++) {
114 OP::Reduce(dst[i], src[i]);
115 }
116}
117} // namespace op
118
119// initialize the rabit engine
120inline bool Init(int argc, char *argv[]) {
121 return engine::Init(argc, argv);
122}
123// finalize the rabit engine
124inline bool Finalize() {
125 return engine::Finalize();
126}
127// get the rank of the previous worker in ring topology
128inline int GetRingPrevRank() {
130}
131// get the rank of current process
132inline int GetRank() {
133 return engine::GetEngine()->GetRank();
134}
135// the the size of the world
136inline int GetWorldSize() {
138}
139// whether rabit is distributed
140inline bool IsDistributed() {
142}
143// get the name of current processor
144inline std::string GetProcessorName() {
145 return engine::GetEngine()->GetHost();
146}
147// broadcast data to all other nodes from root
148inline void Broadcast(void *sendrecv_data, size_t size, int root) {
149 engine::GetEngine()->Broadcast(sendrecv_data, size, root);
150}
151template<typename DType>
152inline void Broadcast(std::vector<DType> *sendrecv_data, int root) {
153 size_t size = sendrecv_data->size();
154 Broadcast(&size, sizeof(size), root);
155 if (sendrecv_data->size() != size) {
156 sendrecv_data->resize(size);
157 }
158 if (size != 0) {
159 Broadcast(&(*sendrecv_data)[0], size * sizeof(DType), root);
160 }
161}
162inline void Broadcast(std::string *sendrecv_data, int root) {
163 size_t size = sendrecv_data->length();
164 Broadcast(&size, sizeof(size), root);
165 if (sendrecv_data->length() != size) {
166 sendrecv_data->resize(size);
167 }
168 if (size != 0) {
169 Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root);
170 }
171}
172
173// perform inplace Allreduce
174template<typename OP, typename DType>
175inline void Allreduce(DType *sendrecvbuf, size_t count,
176 void (*prepare_fun)(void *arg),
177 void *prepare_arg) {
178 engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP, DType>,
179 engine::mpi::GetType<DType>(), OP::kType, prepare_fun, prepare_arg);
180}
181
182// C++11 support for lambda prepare function
183#if DMLC_USE_CXX11
// Trampoline that lets a std::function<void()> be used where the engine expects
// a plain function pointer taking a void* argument; fun must point to one.
inline void InvokeLambda(void *fun) {
  std::function<void()> &callable = *static_cast<std::function<void()> *>(fun);
  callable();
}
187template<typename OP, typename DType>
188inline void Allreduce(DType *sendrecvbuf, size_t count,
189 std::function<void()> prepare_fun) {
190 engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP, DType>,
191 engine::mpi::GetType<DType>(), OP::kType, InvokeLambda, &prepare_fun);
192}
193
194// Performs inplace Allgather
195template<typename DType>
196inline void Allgather(DType *sendrecvbuf,
197 size_t totalSize,
198 size_t beginIndex,
199 size_t sizeNodeSlice,
200 size_t sizePrevSlice) {
201 engine::GetEngine()->Allgather(sendrecvbuf, totalSize * sizeof(DType), beginIndex * sizeof(DType),
202 (beginIndex + sizeNodeSlice) * sizeof(DType),
203 sizePrevSlice * sizeof(DType));
204}
205#endif // C++11
206
207// print message to the tracker
208inline void TrackerPrint(const std::string &msg) {
210}
211#ifndef RABIT_STRICT_CXX98_
212inline void TrackerPrintf(const char *fmt, ...) {
213 const int kPrintBuffer = 1 << 10;
214 std::string msg(kPrintBuffer, '\0');
215 va_list args;
216 va_start(args, fmt);
217 vsnprintf(&msg[0], kPrintBuffer, fmt, args);
218 va_end(args);
219 msg.resize(strlen(msg.c_str()));
220 TrackerPrint(msg);
221}
222
223#endif // RABIT_STRICT_CXX98_
224
225// deprecated, planned for removal after checkpoing from JVM package is removed.
227// deprecated, increase internal version number
229// return the version number of currently stored model
230inline int VersionNumber() {
232}
233} // namespace rabit
234#endif // RABIT_INTERNAL_RABIT_INL_H_
Definition allreduce_base.h:32
virtual bool IsDistributed() const =0
whether we run in distributed mode
virtual void TrackerPrint(const std::string &msg)=0
prints the msg in the tracker, this function can be used to communicate progress information to the u...
virtual int GetRank() const =0
gets rank of current node
virtual int LoadCheckPoint()=0
virtual int VersionNumber() const =0
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root)=0
broadcasts data from root to every other node
virtual int GetWorldSize() const =0
gets total number of nodes
virtual std::string GetHost() const =0
gets the host name of the current node
virtual int GetRingPrevRank() const =0
gets rank of previous node in ring topology
virtual void CheckPoint()=0
Increase internal version number. Deprecated.
virtual void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin, size_t slice_end, size_t size_prev_slice)=0
Allgather function, each node has a segment of data in the ring of sendrecvbuf, the data provided by...
DataType
enum of supported data types
Definition engine.h:141
OpType
enum of all operators
Definition engine.h:132
bool Init(int argc, char *argv[])
initializes the engine module
Definition engine.cc:43
void Allreduce_(void *sendrecvbuf, size_t type_nbytes, size_t count, IEngine::ReduceFunction red, mpi::DataType dtype, mpi::OpType op, IEngine::PreprocFunction prepare_fun=nullptr, void *prepare_arg=nullptr)
perform in-place Allreduce, on sendrecvbuf this is an internal function used by rabit to be able to c...
Definition engine.cc:95
bool Finalize()
finalizes the engine module
Definition engine.cc:55
IEngine * GetEngine()
singleton method to get engine
Definition engine.cc:71
namespace of rabit
Definition engine.h:18
void CheckPoint()
deprecated, planned for removal after checkpointing from the JVM package is removed.
Definition rabit-inl.h:228
std::string GetProcessorName()
gets processor's name
Definition rabit-inl.h:144
bool Finalize()
finalizes the rabit engine, call this function after you finished with all the jobs
Definition rabit-inl.h:124
int LoadCheckPoint()
deprecated, planned for removal after checkpointing from the JVM package is removed.
Definition rabit-inl.h:226
void TrackerPrint(const std::string &msg)
prints the msg to the tracker, this function can be used to communicate progress information to the u...
Definition rabit-inl.h:208
void Allgather(DType *sendrecvbuf_, size_t total_size, size_t slice_begin, size_t slice_end, size_t size_prev_slice)
Allgather function, each node has a segment of data in the ring of sendrecvbuf, the data provided by...
void Broadcast(void *sendrecv_data, size_t size, int root)
broadcasts a memory region to every node from the root
Definition rabit-inl.h:148
void TrackerPrintf(const char *fmt,...)
prints the msg to the tracker, this function may not be available in very strict c++98 compilers,...
Definition rabit-inl.h:212
bool IsDistributed()
whether rabit env is in distributed mode
Definition rabit-inl.h:140
int VersionNumber()
Definition rabit-inl.h:230
bool Init(int argc, char *argv[])
initializes rabit, call this once at the beginning of your program
Definition rabit-inl.h:120
int GetRank()
gets rank of the current process
Definition rabit-inl.h:132
int GetWorldSize()
gets total number of processes
Definition rabit-inl.h:136
Copyright 2014-2023, XGBoost Contributors.
This file defines rabit's Allreduce/Broadcast interface The rabit engine contains the actual implemen...
bitwise AND reduction operator
Definition rabit-inl.h:88
bitwise OR reduction operator
Definition rabit-inl.h:95
bitwise XOR reduction operator
Definition rabit-inl.h:102
maximum reduction operator
Definition rabit-inl.h:67
minimum reduction operator
Definition rabit-inl.h:74
sum reduction operator
Definition rabit-inl.h:81
simple utils to support the code