Medial Code Documentation
Loading...
Searching...
No Matches
communicator-inl.h
1
4#pragma once
5#include <string>
6#include <vector>
7
8#include "communicator.h"
9
10namespace xgboost {
11namespace collective {
12
60inline void Init(Json const& config) {
61 Communicator::Init(config);
62}
63
69inline void Finalize() { Communicator::Finalize(); }
70
76inline int GetRank() { return Communicator::Get()->GetRank(); }
77
83inline int GetWorldSize() { return Communicator::Get()->GetWorldSize(); }
84
90inline bool IsDistributed() { return Communicator::Get()->IsDistributed(); }
91
97inline bool IsFederated() { return Communicator::Get()->IsFederated(); }
98
107inline void Print(char const *message) { Communicator::Get()->Print(message); }
108
109inline void Print(std::string const &message) { Communicator::Get()->Print(message); }
110
116inline std::string GetProcessorName() { return Communicator::Get()->GetProcessorName(); }
117
129inline void Broadcast(void *send_receive_buffer, size_t size, int root) {
130 Communicator::Get()->Broadcast(send_receive_buffer, size, root);
131}
132
133inline void Broadcast(std::string *sendrecv_data, int root) {
134 size_t size = sendrecv_data->length();
135 Broadcast(&size, sizeof(size), root);
136 if (sendrecv_data->length() != size) {
137 sendrecv_data->resize(size);
138 }
139 if (size != 0) {
140 Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root);
141 }
142}
143
153inline void Allgather(void *send_receive_buffer, std::size_t size) {
154 Communicator::Get()->AllGather(send_receive_buffer, size);
155}
156
170inline void Allreduce(void *send_receive_buffer, size_t count, int data_type, int op) {
171 Communicator::Get()->AllReduce(send_receive_buffer, count, static_cast<DataType>(data_type),
172 static_cast<Operation>(op));
173}
174
175inline void Allreduce(void *send_receive_buffer, size_t count, DataType data_type, Operation op) {
176 Communicator::Get()->AllReduce(send_receive_buffer, count, data_type, op);
177}
178
179template <Operation op>
180inline void Allreduce(int8_t *send_receive_buffer, size_t count) {
181 Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt8, op);
182}
183
184template <Operation op>
185inline void Allreduce(uint8_t *send_receive_buffer, size_t count) {
186 Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt8, op);
187}
188
189template <Operation op>
190inline void Allreduce(int32_t *send_receive_buffer, size_t count) {
191 Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt32, op);
192}
193
194template <Operation op>
195inline void Allreduce(uint32_t *send_receive_buffer, size_t count) {
196 Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt32, op);
197}
198
199template <Operation op>
200inline void Allreduce(int64_t *send_receive_buffer, size_t count) {
201 Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt64, op);
202}
203
204template <Operation op>
205inline void Allreduce(uint64_t *send_receive_buffer, size_t count) {
206 Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
207}
208
209// Specialization for size_t, which is implementation defined, so it might or might not
210// be one of uint64_t/uint32_t/unsigned long long/unsigned long.
211template <Operation op, typename T,
212 typename = std::enable_if_t<std::is_same<size_t, T>{} && !std::is_same<uint64_t, T>{}> >
213inline void Allreduce(T *send_receive_buffer, size_t count) {
214 static_assert(sizeof(T) == sizeof(uint64_t));
215 Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
216}
217
218template <Operation op>
219inline void Allreduce(float *send_receive_buffer, size_t count) {
220 Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kFloat, op);
221}
222
223template <Operation op>
224inline void Allreduce(double *send_receive_buffer, size_t count) {
225 Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
226}
227
/**
 * @brief Result bundle returned by AllgatherV.
 *
 * Kept as a plain aggregate (member order: offsets, sizes, result) so it can be
 * brace-initialised.
 *
 * @tparam T Element type of the gathered data.
 */
template <typename T>
struct AllgatherVResult {
  /** Start position of every gathered segment inside `result`. */
  std::vector<std::size_t> offsets;
  /** Length of every gathered segment. */
  std::vector<std::size_t> sizes;
  /** All gathered elements, stored contiguously. */
  std::vector<T> result;
};
234
243template <typename T>
244inline AllgatherVResult<T> AllgatherV(std::vector<T> const &inputs,
245 std::vector<std::size_t> const &sizes) {
246 auto num_inputs = sizes.size();
247
248 // Gather the sizes across all workers.
249 std::vector<std::size_t> all_sizes(num_inputs * GetWorldSize());
250 std::copy_n(sizes.cbegin(), sizes.size(), all_sizes.begin() + num_inputs * GetRank());
251 collective::Allgather(all_sizes.data(), all_sizes.size() * sizeof(std::size_t));
252
253 // Calculate input offsets (std::exclusive_scan).
254 std::vector<std::size_t> offsets(all_sizes.size());
255 for (std::size_t i = 1; i < offsets.size(); i++) {
256 offsets[i] = offsets[i - 1] + all_sizes[i - 1];
257 }
258
259 // Gather all the inputs.
260 auto total_input_size = offsets.back() + all_sizes.back();
261 std::vector<T> all_inputs(total_input_size);
262 std::copy_n(inputs.cbegin(), inputs.size(), all_inputs.begin() + offsets[num_inputs * GetRank()]);
263 // We cannot use allgather here, since each worker might have a different size.
264 Allreduce<Operation::kMax>(all_inputs.data(), all_inputs.size());
265
266 return {offsets, all_sizes, all_inputs};
267}
268
269} // namespace collective
270} // namespace xgboost
Data structure representing JSON format.
Definition json.h:357
virtual std::string GetProcessorName()=0
Gets the name of the processor.
static void Init(Json const &config)
Initialize the communicator.
Definition communicator.cc:20
static Communicator * Get()
Get the communicator instance.
Definition communicator.h:99
int GetRank() const
Get the rank of the current process.
Definition communicator.h:117
virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, Operation op)=0
Combines values from all processes and distributes the result back to all processes.
virtual void AllGather(void *send_receive_buffer, std::size_t size)=0
Gathers data from all processes and distributes it to all processes.
virtual bool IsDistributed() const =0
Whether the communicator is running in distributed mode.
static void Finalize()
Finalize the communicator.
Definition communicator.cc:55
virtual void Broadcast(void *send_receive_buffer, std::size_t size, int root)=0
Broadcasts a message from the process with rank root to all other processes of the group.
virtual void Print(std::string const &message)=0
Prints the message.
int GetWorldSize() const
Get the total number of processes.
Definition communicator.h:114
virtual bool IsFederated() const =0
Whether the communicator is running in federated mode.
void Broadcast(void *send_receive_buffer, size_t size, int root)
Broadcast a memory region to all others from root. This function is NOT thread-safe.
Definition communicator-inl.h:129
Operation
Defines the reduction operation.
Definition communicator.h:61
void Allgather(void *send_receive_buffer, std::size_t size)
Gathers data from all processes and distributes it to all processes.
Definition communicator-inl.h:153
int GetWorldSize()
Get total number of processes.
Definition communicator-inl.h:83
void Init(Json const &config)
Initialize the collective communicator.
Definition communicator-inl.h:60
DataType
Defines the integral and floating data types.
Definition communicator.h:15
bool IsFederated()
Check whether the communicator is federated.
Definition communicator-inl.h:97
void Print(char const *message)
Print the message to the communicator.
Definition communicator-inl.h:107
bool IsDistributed()
Check whether the communicator is distributed.
Definition communicator-inl.h:90
std::string GetProcessorName()
Get the name of the processor.
Definition communicator-inl.h:116
void Finalize()
Finalize the collective communicator.
Definition communicator-inl.h:69
int GetRank()
Get rank of current process.
Definition communicator-inl.h:76
void Allreduce(void *send_receive_buffer, size_t count, int data_type, int op)
Perform in-place allreduce. This function is NOT thread-safe.
Definition communicator-inl.h:170
AllgatherVResult< T > AllgatherV(std::vector< T > const &inputs, std::vector< std::size_t > const &sizes)
Gathers variable-length data from all processes and distributes it to all processes.
Definition communicator-inl.h:244
namespace of xgboost
Definition base.h:90
DataType
data type accepted by xgboost interface
Definition data.h:33
Definition communicator-inl.h:229