8#ifndef RABIT_INTERNAL_RABIT_INL_H_
9#define RABIT_INTERNAL_RABIT_INL_H_
21template<
typename DType>
28inline DataType GetType<unsigned char>() {
36inline DataType GetType<unsigned int>() {
44inline DataType GetType<unsigned long>() {
56inline DataType GetType<long long>() {
60inline DataType GetType<unsigned long long>() {
69 template<
typename DType>
70 inline static void Reduce(DType &dst,
const DType &src) {
71 if (dst < src) dst = src;
76 template<
typename DType>
77 inline static void Reduce(DType &dst,
const DType &src) {
78 if (dst > src) dst = src;
83 template<
typename DType>
84 inline static void Reduce(DType &dst,
const DType &src) {
90 template<
typename DType>
91 inline static void Reduce(DType &dst,
const DType &src) {
97 template<
typename DType>
98 inline static void Reduce(DType &dst,
const DType &src) {
104 template<
typename DType>
105 inline static void Reduce(DType &dst,
const DType &src) {
109template <
typename OP,
typename DType>
110inline void Reducer(
const void *src_,
void *dst_,
int len,
const MPI::Datatype &) {
111 const DType *src =
static_cast<const DType *
>(src_);
112 DType *dst = (DType *)dst_;
113 for (
int i = 0; i < len; i++) {
114 OP::Reduce(dst[i], src[i]);
120inline bool Init(
int argc,
char *argv[]) {
128inline int GetRingPrevRank() {
148inline void Broadcast(
void *sendrecv_data,
size_t size,
int root) {
151template<
typename DType>
152inline void Broadcast(std::vector<DType> *sendrecv_data,
int root) {
153 size_t size = sendrecv_data->size();
155 if (sendrecv_data->size() != size) {
156 sendrecv_data->resize(size);
159 Broadcast(&(*sendrecv_data)[0], size *
sizeof(DType), root);
162inline void Broadcast(std::string *sendrecv_data,
int root) {
163 size_t size = sendrecv_data->length();
165 if (sendrecv_data->length() != size) {
166 sendrecv_data->resize(size);
169 Broadcast(&(*sendrecv_data)[0], size *
sizeof(
char), root);
174template<
typename OP,
typename DType>
175inline void Allreduce(DType *sendrecvbuf,
size_t count,
176 void (*prepare_fun)(
void *arg),
179 engine::mpi::GetType<DType>(), OP::kType, prepare_fun, prepare_arg);
184inline void InvokeLambda(
void *fun) {
185 (*
static_cast<std::function<
void()
>*>(fun))();
187template<
typename OP,
typename DType>
188inline void Allreduce(DType *sendrecvbuf,
size_t count,
189 std::function<
void()> prepare_fun) {
191 engine::mpi::GetType<DType>(), OP::kType, InvokeLambda, &prepare_fun);
195template<
typename DType>
199 size_t sizeNodeSlice,
200 size_t sizePrevSlice) {
202 (beginIndex + sizeNodeSlice) *
sizeof(DType),
203 sizePrevSlice *
sizeof(DType));
211#ifndef RABIT_STRICT_CXX98_
213 const int kPrintBuffer = 1 << 10;
214 std::string msg(kPrintBuffer,
'\0');
217 vsnprintf(&msg[0], kPrintBuffer, fmt, args);
219 msg.resize(strlen(msg.c_str()));
Definition allreduce_base.h:32
virtual bool IsDistributed() const =0
whether we run in distributed mode
virtual void TrackerPrint(const std::string &msg)=0
prints the msg in the tracker, this function can be used to communicate progress information to the u...
virtual int GetRank() const =0
gets rank of current node
virtual int LoadCheckPoint()=0
virtual int VersionNumber() const =0
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root)=0
broadcasts data from root to every other node
virtual int GetWorldSize() const =0
gets total number of nodes
virtual std::string GetHost() const =0
gets the host name of the current node
virtual int GetRingPrevRank() const =0
gets rank of previous node in ring topology
virtual void CheckPoint()=0
Increase internal version number. Deprecated.
virtual void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin, size_t slice_end, size_t size_prev_slice)=0
Allgather function: each node has a segment of data in the ring of sendrecvbuf, the data provided by...
DataType
enum of supported data types
Definition engine.h:141
OpType
enum of all operators
Definition engine.h:132
bool Init(int argc, char *argv[])
initializes the engine module
Definition engine.cc:43
void Allreduce_(void *sendrecvbuf, size_t type_nbytes, size_t count, IEngine::ReduceFunction red, mpi::DataType dtype, mpi::OpType op, IEngine::PreprocFunction prepare_fun=nullptr, void *prepare_arg=nullptr)
perform in-place Allreduce on sendrecvbuf; this is an internal function used by rabit to be able to c...
Definition engine.cc:95
bool Finalize()
finalizes the engine module
Definition engine.cc:55
IEngine * GetEngine()
singleton method to get engine
Definition engine.cc:71
namespace of rabit
Definition engine.h:18
void CheckPoint()
deprecated, planned for removal after checkpointing from the JVM package is removed.
Definition rabit-inl.h:228
std::string GetProcessorName()
gets processor's name
Definition rabit-inl.h:144
bool Finalize()
finalizes the rabit engine; call this function after you have finished all the jobs
Definition rabit-inl.h:124
int LoadCheckPoint()
deprecated, planned for removal after checkpointing from the JVM package is removed.
Definition rabit-inl.h:226
void TrackerPrint(const std::string &msg)
prints the msg to the tracker, this function can be used to communicate progress information to the u...
Definition rabit-inl.h:208
void Allgather(DType *sendrecvbuf_, size_t total_size, size_t slice_begin, size_t slice_end, size_t size_prev_slice)
Allgather function: each node has a segment of data in the ring of sendrecvbuf, the data provided by...
void Broadcast(void *sendrecv_data, size_t size, int root)
broadcasts a memory region to every node from the root
Definition rabit-inl.h:148
void TrackerPrintf(const char *fmt,...)
prints the msg to the tracker, this function may not be available in very strict c++98 compilers,...
Definition rabit-inl.h:212
bool IsDistributed()
whether rabit env is in distributed mode
Definition rabit-inl.h:140
int VersionNumber()
Definition rabit-inl.h:230
bool Init(int argc, char *argv[])
initializes rabit; call this once at the beginning of your program
Definition rabit-inl.h:120
int GetRank()
gets rank of the current process
Definition rabit-inl.h:132
int GetWorldSize()
gets total number of processes
Definition rabit-inl.h:136
Copyright 2014-2023, XGBoost Contributors.
This file defines rabit's Allreduce/Broadcast interface The rabit engine contains the actual implemen...
bitwise AND reduction operator
Definition rabit-inl.h:88
bitwise OR reduction operator
Definition rabit-inl.h:95
bitwise XOR reduction operator
Definition rabit-inl.h:102
maximum reduction operator
Definition rabit-inl.h:67
minimum reduction operator
Definition rabit-inl.h:74
sum reduction operator
Definition rabit-inl.h:81
simple utils to support the code