Medial Code Documentation
Loading...
Searching...
No Matches
rabit_communicator.h
1
4#pragma once
5#include <rabit/rabit.h>
6
7#include <string>
8#include <vector>
9
10#include "communicator.h"
11#include "xgboost/json.h"
12
13namespace xgboost {
14namespace collective {
15
17 public:
18 static Communicator *Create(Json const &config) {
19 std::vector<std::string> args_str;
20 for (auto &items : get<Object const>(config)) {
21 switch (items.second.GetValue().Type()) {
22 case xgboost::Value::ValueKind::kString: {
23 args_str.push_back(items.first + "=" + get<String const>(items.second));
24 break;
25 }
26 case xgboost::Value::ValueKind::kInteger: {
27 args_str.push_back(items.first + "=" + std::to_string(get<Integer const>(items.second)));
28 break;
29 }
30 case xgboost::Value::ValueKind::kBoolean: {
31 if (get<Boolean const>(items.second)) {
32 args_str.push_back(items.first + "=1");
33 } else {
34 args_str.push_back(items.first + "=0");
35 }
36 break;
37 }
38 default:
39 break;
40 }
41 }
42 std::vector<char *> args;
43 for (auto &key_value : args_str) {
44 args.push_back(&key_value[0]);
45 }
46 if (!rabit::Init(static_cast<int>(args.size()), &args[0])) {
47 LOG(FATAL) << "Failed to initialize Rabit";
48 }
50 }
51
52 RabitCommunicator(int world_size, int rank) : Communicator(world_size, rank) {}
53
54 bool IsDistributed() const override { return rabit::IsDistributed(); }
55
56 bool IsFederated() const override { return false; }
57
58 void AllGather(void *send_receive_buffer, std::size_t size) override {
59 auto const per_rank = size / GetWorldSize();
60 auto const index = per_rank * GetRank();
61 rabit::Allgather(static_cast<char *>(send_receive_buffer), size, index, per_rank, per_rank);
62 }
63
64 void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
65 Operation op) override {
66 switch (data_type) {
67 case DataType::kInt8:
68 DoAllReduce<char>(send_receive_buffer, count, op);
69 break;
70 case DataType::kUInt8:
71 DoAllReduce<unsigned char>(send_receive_buffer, count, op);
72 break;
73 case DataType::kInt32:
74 DoAllReduce<std::int32_t>(send_receive_buffer, count, op);
75 break;
76 case DataType::kUInt32:
77 DoAllReduce<std::uint32_t>(send_receive_buffer, count, op);
78 break;
79 case DataType::kInt64:
80 DoAllReduce<std::int64_t>(send_receive_buffer, count, op);
81 break;
82 case DataType::kUInt64:
83 DoAllReduce<std::uint64_t>(send_receive_buffer, count, op);
84 break;
85 case DataType::kFloat:
86 DoAllReduce<float>(send_receive_buffer, count, op);
87 break;
88 case DataType::kDouble:
89 DoAllReduce<double>(send_receive_buffer, count, op);
90 break;
91 default:
92 LOG(FATAL) << "Unknown data type";
93 }
94 }
95
96 void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {
97 rabit::Broadcast(send_receive_buffer, size, root);
98 }
99
100 std::string GetProcessorName() override { return rabit::GetProcessorName(); }
101
102 void Print(const std::string &message) override { rabit::TrackerPrint(message); }
103
104 protected:
105 void Shutdown() override { rabit::Finalize(); }
106
107 private:
108 template <typename DType, std::enable_if_t<std::is_integral<DType>::value> * = nullptr>
109 void DoBitwiseAllReduce(void *send_receive_buffer, std::size_t count, Operation op) {
110 switch (op) {
111 case Operation::kBitwiseAND:
112 rabit::Allreduce<rabit::op::BitAND, DType>(static_cast<DType *>(send_receive_buffer),
113 count);
114 break;
115 case Operation::kBitwiseOR:
116 rabit::Allreduce<rabit::op::BitOR, DType>(static_cast<DType *>(send_receive_buffer), count);
117 break;
118 case Operation::kBitwiseXOR:
119 rabit::Allreduce<rabit::op::BitXOR, DType>(static_cast<DType *>(send_receive_buffer),
120 count);
121 break;
122 default:
123 LOG(FATAL) << "Unknown allreduce operation";
124 }
125 }
126
127 template <typename DType, std::enable_if_t<std::is_floating_point<DType>::value> * = nullptr>
128 void DoBitwiseAllReduce(void *, std::size_t, Operation) {
129 LOG(FATAL) << "Floating point types do not support bitwise operations.";
130 }
131
132 template <typename DType>
133 void DoAllReduce(void *send_receive_buffer, std::size_t count, Operation op) {
134 switch (op) {
135 case Operation::kMax:
136 rabit::Allreduce<rabit::op::Max, DType>(static_cast<DType *>(send_receive_buffer), count);
137 break;
138 case Operation::kMin:
139 rabit::Allreduce<rabit::op::Min, DType>(static_cast<DType *>(send_receive_buffer), count);
140 break;
141 case Operation::kSum:
142 rabit::Allreduce<rabit::op::Sum, DType>(static_cast<DType *>(send_receive_buffer), count);
143 break;
144 case Operation::kBitwiseAND:
145 case Operation::kBitwiseOR:
146 case Operation::kBitwiseXOR:
147 DoBitwiseAllReduce<DType>(send_receive_buffer, count, op);
148 break;
149 default:
150 LOG(FATAL) << "Unknown allreduce operation";
151 }
152 }
153};
154} // namespace collective
155} // namespace xgboost
Data structure representing JSON format.
Definition json.h:357
A communicator class that handles collective communication.
Definition communicator.h:86
int GetRank() const
Get the rank of the current processes.
Definition communicator.h:117
int GetWorldSize() const
Get the total number of processes.
Definition communicator.h:114
Definition rabit_communicator.h:16
std::string GetProcessorName() override
Gets the name of the processor.
Definition rabit_communicator.h:100
bool IsFederated() const override
Whether the communicator is running in federated mode.
Definition rabit_communicator.h:56
void Print(const std::string &message) override
Prints the message.
Definition rabit_communicator.h:102
bool IsDistributed() const override
Whether the communicator is running in distributed mode.
Definition rabit_communicator.h:54
void Shutdown() override
Shuts down the communicator.
Definition rabit_communicator.h:105
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override
Broadcasts a message from the process with rank root to all other processes of the group.
Definition rabit_communicator.h:96
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, Operation op) override
Combines values from all processes and distributes the result back to all processes.
Definition rabit_communicator.h:64
void AllGather(void *send_receive_buffer, std::size_t size) override
Gathers data from all processes and distributes it to all processes.
Definition rabit_communicator.h:58
std::string GetProcessorName()
gets processor's name
Definition rabit-inl.h:144
bool Finalize()
finalizes the rabit engine; call this function after you have finished all the jobs
Definition rabit-inl.h:124
void TrackerPrint(const std::string &msg)
prints the msg to the tracker; this function can be used to communicate progress information to the user
Definition rabit-inl.h:208
void Allgather(DType *sendrecvbuf_, size_t total_size, size_t slice_begin, size_t slice_end, size_t size_prev_slice)
Allgather function: each node has a segment of data in the ring of sendrecvbuf; the data provided by...
void Broadcast(void *sendrecv_data, size_t size, int root)
broadcasts a memory region to every node from the root
Definition rabit-inl.h:148
bool IsDistributed()
whether rabit env is in distributed mode
Definition rabit-inl.h:140
bool Init(int argc, char *argv[])
initializes rabit; call this once at the beginning of your program
Definition rabit-inl.h:120
int GetRank()
gets rank of the current process
Definition rabit-inl.h:132
int GetWorldSize()
gets total number of processes
Definition rabit-inl.h:136
Operation
Defines the reduction operation.
Definition communicator.h:61
DataType
Defines the integral and floating data types.
Definition communicator.h:15
namespace of xgboost
Definition base.h:90
This file defines rabit's Allreduce/Broadcast interface. The rabit engine contains the actual implementation.