10#include "communicator.h"
11#include "xgboost/json.h"
19 std::vector<std::string> args_str;
20 for (
auto &items : get<Object const>(config)) {
21 switch (items.second.GetValue().Type()) {
22 case xgboost::Value::ValueKind::kString: {
23 args_str.push_back(items.first +
"=" + get<String const>(items.second));
26 case xgboost::Value::ValueKind::kInteger: {
27 args_str.push_back(items.first +
"=" + std::to_string(get<Integer const>(items.second)));
30 case xgboost::Value::ValueKind::kBoolean: {
31 if (get<Boolean const>(items.second)) {
32 args_str.push_back(items.first +
"=1");
34 args_str.push_back(items.first +
"=0");
42 std::vector<char *> args;
43 for (
auto &key_value : args_str) {
44 args.push_back(&key_value[0]);
46 if (!
rabit::Init(
static_cast<int>(args.size()), &args[0])) {
47 LOG(FATAL) <<
"Failed to initialize Rabit";
58 void AllGather(
void *send_receive_buffer, std::size_t size)
override {
60 auto const index = per_rank *
GetRank();
61 rabit::Allgather(
static_cast<char *
>(send_receive_buffer), size, index, per_rank, per_rank);
68 DoAllReduce<char>(send_receive_buffer, count, op);
70 case DataType::kUInt8:
71 DoAllReduce<unsigned char>(send_receive_buffer, count, op);
73 case DataType::kInt32:
74 DoAllReduce<std::int32_t>(send_receive_buffer, count, op);
76 case DataType::kUInt32:
77 DoAllReduce<std::uint32_t>(send_receive_buffer, count, op);
79 case DataType::kInt64:
80 DoAllReduce<std::int64_t>(send_receive_buffer, count, op);
82 case DataType::kUInt64:
83 DoAllReduce<std::uint64_t>(send_receive_buffer, count, op);
85 case DataType::kFloat:
86 DoAllReduce<float>(send_receive_buffer, count, op);
88 case DataType::kDouble:
89 DoAllReduce<double>(send_receive_buffer, count, op);
92 LOG(FATAL) <<
"Unknown data type";
96 void Broadcast(
void *send_receive_buffer, std::size_t size,
int root)
override {
108 template <typename DType, std::enable_if_t<std::is_integral<DType>::value> * =
nullptr>
109 void DoBitwiseAllReduce(
void *send_receive_buffer, std::size_t count,
Operation op) {
111 case Operation::kBitwiseAND:
112 rabit::Allreduce<rabit::op::BitAND, DType>(
static_cast<DType *
>(send_receive_buffer),
115 case Operation::kBitwiseOR:
116 rabit::Allreduce<rabit::op::BitOR, DType>(
static_cast<DType *
>(send_receive_buffer), count);
118 case Operation::kBitwiseXOR:
119 rabit::Allreduce<rabit::op::BitXOR, DType>(
static_cast<DType *
>(send_receive_buffer),
123 LOG(FATAL) <<
"Unknown allreduce operation";
127 template <typename DType, std::enable_if_t<std::is_floating_point<DType>::value> * =
nullptr>
128 void DoBitwiseAllReduce(
void *, std::size_t,
Operation) {
129 LOG(FATAL) <<
"Floating point types do not support bitwise operations.";
132 template <
typename DType>
133 void DoAllReduce(
void *send_receive_buffer, std::size_t count,
Operation op) {
135 case Operation::kMax:
136 rabit::Allreduce<rabit::op::Max, DType>(
static_cast<DType *
>(send_receive_buffer), count);
138 case Operation::kMin:
139 rabit::Allreduce<rabit::op::Min, DType>(
static_cast<DType *
>(send_receive_buffer), count);
141 case Operation::kSum:
142 rabit::Allreduce<rabit::op::Sum, DType>(
static_cast<DType *
>(send_receive_buffer), count);
144 case Operation::kBitwiseAND:
145 case Operation::kBitwiseOR:
146 case Operation::kBitwiseXOR:
147 DoBitwiseAllReduce<DType>(send_receive_buffer, count, op);
150 LOG(FATAL) <<
"Unknown allreduce operation";
Data structure representing JSON format.
Definition json.h:357
A communicator class that handles collective communication.
Definition communicator.h:86
int GetRank() const
Get the rank of the current processes.
Definition communicator.h:117
int GetWorldSize() const
Get the total number of processes.
Definition communicator.h:114
Definition rabit_communicator.h:16
std::string GetProcessorName() override
Gets the name of the processor.
Definition rabit_communicator.h:100
bool IsFederated() const override
Whether the communicator is running in federated mode.
Definition rabit_communicator.h:56
void Print(const std::string &message) override
Prints the message.
Definition rabit_communicator.h:102
bool IsDistributed() const override
Whether the communicator is running in distributed mode.
Definition rabit_communicator.h:54
void Shutdown() override
Shuts down the communicator.
Definition rabit_communicator.h:105
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override
Broadcasts a message from the process with rank root to all other processes of the group.
Definition rabit_communicator.h:96
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, Operation op) override
Combines values from all processes and distributes the result back to all processes.
Definition rabit_communicator.h:64
void AllGather(void *send_receive_buffer, std::size_t size) override
Gathers data from all processes and distributes it to all processes.
Definition rabit_communicator.h:58
std::string GetProcessorName()
Gets the processor's name.
Definition rabit-inl.h:144
bool Finalize()
Finalizes the rabit engine; call this function after you have finished all the jobs.
Definition rabit-inl.h:124
void TrackerPrint(const std::string &msg)
Prints the msg to the tracker; this function can be used to communicate progress information to the user.
Definition rabit-inl.h:208
void Allgather(DType *sendrecvbuf_, size_t total_size, size_t slice_begin, size_t slice_end, size_t size_prev_slice)
Allgather function: each node has a segment of data in the ring of sendrecvbuf, the data provided by...
void Broadcast(void *sendrecv_data, size_t size, int root)
broadcasts a memory region to every node from the root
Definition rabit-inl.h:148
bool IsDistributed()
Whether the rabit environment is in distributed mode.
Definition rabit-inl.h:140
bool Init(int argc, char *argv[])
Initializes rabit; call this once at the beginning of your program.
Definition rabit-inl.h:120
int GetRank()
gets rank of the current process
Definition rabit-inl.h:132
int GetWorldSize()
gets total number of processes
Definition rabit-inl.h:136
Operation
Defines the reduction operation.
Definition communicator.h:61
DataType
Defines the integral and floating-point data types.
Definition communicator.h:15
namespace of xgboost
Definition base.h:90
This file defines rabit's Allreduce/Broadcast interface. The rabit engine contains the actual implementation.