5#include <xgboost/json.h>
9#include "../c_api/c_api_utils.h"
10#include "in_memory_handler.h"
30 auto* value = getenv(
"IN_MEMORY_WORLD_SIZE");
31 if (value !=
nullptr) {
32 world_size = std::stoi(value);
34 value = getenv(
"IN_MEMORY_RANK");
35 if (value !=
nullptr) {
36 rank = std::stoi(value);
40 world_size =
static_cast<int>(OptionalArg<Integer>(config,
"in_memory_world_size",
41 static_cast<Integer::Int
>(world_size)));
42 rank =
static_cast<int>(
43 OptionalArg<Integer>(config,
"in_memory_rank",
static_cast<Integer::Int
>(rank)));
45 if (world_size == 0) {
46 LOG(FATAL) <<
"Federated world size must be set.";
49 LOG(FATAL) <<
"Federated rank must be set.";
55 handler_.
Init(world_size, rank);
58 ~InMemoryCommunicator()
override { handler_.
Shutdown(sequence_number_++,
GetRank()); }
63 void AllGather(
void* in_out, std::size_t size)
override {
65 handler_.
Allgather(
static_cast<const char*
>(in_out), size, &output, sequence_number_++,
67 output.copy(
static_cast<char*
>(in_out), size);
73 handler_.
Allreduce(
static_cast<const char*
>(in_out), bytes, &output, sequence_number_++,
74 GetRank(), data_type, operation);
75 output.copy(
static_cast<char*
>(in_out), bytes);
78 void Broadcast(
void* in_out, std::size_t size,
int root)
override {
80 handler_.
Broadcast(
static_cast<const char*
>(in_out), size, &output, sequence_number_++,
82 output.copy(
static_cast<char*
>(in_out), size);
87 void Print(
const std::string& message)
override { LOG(CONSOLE) << message; }
94 uint64_t sequence_number_{};
Data structure representing JSON format.
Definition json.h:357
A communicator class that handles collective communication.
Definition communicator.h:86
int GetRank() const
Get the rank of the current processes.
Definition communicator.h:117
An in-memory communicator, useful for testing.
Definition in_memory_communicator.h:18
bool IsDistributed() const override
Whether the communicator is running in distributed mode.
Definition in_memory_communicator.h:60
static Communicator * Create(Json const &config)
Create a new communicator based on JSON configuration.
Definition in_memory_communicator.h:25
void Broadcast(void *in_out, std::size_t size, int root) override
Broadcasts a message from the process with rank root to all other processes of the group.
Definition in_memory_communicator.h:78
bool IsFederated() const override
Whether the communicator is running in federated mode.
Definition in_memory_communicator.h:61
void Shutdown() override
Shuts down the communicator.
Definition in_memory_communicator.h:90
void Print(const std::string &message) override
Prints the message.
Definition in_memory_communicator.h:87
void AllReduce(void *in_out, std::size_t size, DataType data_type, Operation operation) override
Combines values from all processes and distributes the result back to all processes.
Definition in_memory_communicator.h:70
void AllGather(void *in_out, std::size_t size) override
Gathers data from all processes and distributes it to all processes.
Definition in_memory_communicator.h:63
std::string GetProcessorName() override
Gets the name of the processor.
Definition in_memory_communicator.h:85
Handles collective communication primitives in memory.
Definition in_memory_handler.h:18
void Broadcast(char const *input, std::size_t bytes, std::string *output, std::size_t sequence_number, int rank, int root)
Perform broadcast.
Definition in_memory_handler.cc:207
void Allgather(char const *input, std::size_t bytes, std::string *output, std::size_t sequence_number, int rank)
Perform allgather.
Definition in_memory_handler.cc:196
void Init(int world_size, int rank)
Initialize the handler with the world size and rank.
Definition in_memory_handler.cc:171
void Shutdown(uint64_t sequence_number, int rank)
Shut down the handler.
Definition in_memory_handler.cc:181
void Allreduce(char const *input, std::size_t bytes, std::string *output, std::size_t sequence_number, int rank, DataType data_type, Operation op)
Perform allreduce.
Definition in_memory_handler.cc:201
Operation
Defines the reduction operation.
Definition communicator.h:61
DataType
Defines the integral and floating data types.
Definition communicator.h:15
std::size_t GetTypeSize(DataType data_type)
Get the size of the data type.
Definition communicator.h:27
namespace of xgboost
Definition base.h:90