8#include "communicator.h"
129inline void Broadcast(
void *send_receive_buffer,
size_t size,
int root) {
133inline void Broadcast(std::string *sendrecv_data,
int root) {
134 size_t size = sendrecv_data->length();
136 if (sendrecv_data->length() != size) {
137 sendrecv_data->resize(size);
140 Broadcast(&(*sendrecv_data)[0], size *
sizeof(
char), root);
153inline void Allgather(
void *send_receive_buffer, std::size_t size) {
170inline void Allreduce(
void *send_receive_buffer,
size_t count,
int data_type,
int op) {
179template <Operation op>
180inline void Allreduce(int8_t *send_receive_buffer,
size_t count) {
184template <Operation op>
185inline void Allreduce(uint8_t *send_receive_buffer,
size_t count) {
189template <Operation op>
190inline void Allreduce(int32_t *send_receive_buffer,
size_t count) {
194template <Operation op>
195inline void Allreduce(uint32_t *send_receive_buffer,
size_t count) {
199template <Operation op>
200inline void Allreduce(int64_t *send_receive_buffer,
size_t count) {
204template <Operation op>
205inline void Allreduce(uint64_t *send_receive_buffer,
size_t count) {
212 typename = std::enable_if_t<std::is_same<size_t, T>{} && !std::is_same<uint64_t, T>{}> >
213inline void Allreduce(T *send_receive_buffer,
size_t count) {
214 static_assert(
sizeof(T) ==
sizeof(uint64_t));
218template <Operation op>
219inline void Allreduce(
float *send_receive_buffer,
size_t count) {
223template <Operation op>
224inline void Allreduce(
double *send_receive_buffer,
size_t count) {
230 std::vector<std::size_t> offsets;
231 std::vector<std::size_t> sizes;
232 std::vector<T> result;
245 std::vector<std::size_t>
const &sizes) {
246 auto num_inputs = sizes.size();
249 std::vector<std::size_t> all_sizes(num_inputs *
GetWorldSize());
250 std::copy_n(sizes.cbegin(), sizes.size(), all_sizes.begin() + num_inputs *
GetRank());
254 std::vector<std::size_t> offsets(all_sizes.size());
255 for (std::size_t i = 1; i < offsets.size(); i++) {
256 offsets[i] = offsets[i - 1] + all_sizes[i - 1];
260 auto total_input_size = offsets.back() + all_sizes.back();
261 std::vector<T> all_inputs(total_input_size);
262 std::copy_n(inputs.cbegin(), inputs.size(), all_inputs.begin() + offsets[num_inputs *
GetRank()]);
264 Allreduce<Operation::kMax>(all_inputs.data(), all_inputs.size());
266 return {offsets, all_sizes, all_inputs};
Data structure representing JSON format.
Definition json.h:357
virtual std::string GetProcessorName()=0
Gets the name of the processor.
static void Init(Json const &config)
Initialize the communicator.
Definition communicator.cc:20
static Communicator * Get()
Get the communicator instance.
Definition communicator.h:99
int GetRank() const
Get the rank of the current processes.
Definition communicator.h:117
virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, Operation op)=0
Combines values from all processes and distributes the result back to all processes.
virtual void AllGather(void *send_receive_buffer, std::size_t size)=0
Gathers data from all processes and distributes it to all processes.
virtual bool IsDistributed() const =0
Whether the communicator is running in distributed mode.
static void Finalize()
Finalize the communicator.
Definition communicator.cc:55
virtual void Broadcast(void *send_receive_buffer, std::size_t size, int root)=0
Broadcasts a message from the process with rank root to all other processes of the group.
virtual void Print(std::string const &message)=0
Prints the message.
int GetWorldSize() const
Get the total number of processes.
Definition communicator.h:114
virtual bool IsFederated() const =0
Whether the communicator is running in federated mode.
void Broadcast(void *send_receive_buffer, size_t size, int root)
Broadcast a memory region to all others from root. This function is NOT thread-safe.
Definition communicator-inl.h:129
Operation
Defines the reduction operation.
Definition communicator.h:61
void Allgather(void *send_receive_buffer, std::size_t size)
Gathers data from all processes and distributes it to all processes.
Definition communicator-inl.h:153
int GetWorldSize()
Get total number of processes.
Definition communicator-inl.h:83
void Init(Json const &config)
Initialize the collective communicator.
Definition communicator-inl.h:60
DataType
Defines the integral and floating data types.
Definition communicator.h:15
bool IsFederated()
Get if the communicator is federated.
Definition communicator-inl.h:97
void Print(char const *message)
Print the message to the communicator.
Definition communicator-inl.h:107
bool IsDistributed()
Get if the communicator is distributed.
Definition communicator-inl.h:90
std::string GetProcessorName()
Get the name of the processor.
Definition communicator-inl.h:116
void Finalize()
Finalize the collective communicator.
Definition communicator-inl.h:69
int GetRank()
Get rank of current process.
Definition communicator-inl.h:76
void Allreduce(void *send_receive_buffer, size_t count, int data_type, int op)
Perform in-place allreduce. This function is NOT thread-safe.
Definition communicator-inl.h:170
AllgatherVResult< T > AllgatherV(std::vector< T > const &inputs, std::vector< std::size_t > const &sizes)
Gathers variable-length data from all processes and distributes it to all processes.
Definition communicator-inl.h:244
namespace of xgboost
Definition base.h:90
DataType
data type accepted by xgboost interface
Definition data.h:33
Definition communicator-inl.h:229