Medial Code Documentation
Loading...
Searching...
No Matches
in_memory_communicator.h
1
4#pragma once
5#include <xgboost/json.h>
6
7#include <string>
8
9#include "../c_api/c_api_utils.h"
10#include "in_memory_handler.h"
11
12namespace xgboost {
13namespace collective {
14
19 public:
25 static Communicator* Create(Json const& config) {
26 int world_size{0};
27 int rank{-1};
28
29 // Parse environment variables first.
30 auto* value = getenv("IN_MEMORY_WORLD_SIZE");
31 if (value != nullptr) {
32 world_size = std::stoi(value);
33 }
34 value = getenv("IN_MEMORY_RANK");
35 if (value != nullptr) {
36 rank = std::stoi(value);
37 }
38
39 // Runtime configuration overrides, optional as users can specify them as env vars.
40 world_size = static_cast<int>(OptionalArg<Integer>(config, "in_memory_world_size",
41 static_cast<Integer::Int>(world_size)));
42 rank = static_cast<int>(
43 OptionalArg<Integer>(config, "in_memory_rank", static_cast<Integer::Int>(rank)));
44
45 if (world_size == 0) {
46 LOG(FATAL) << "Federated world size must be set.";
47 }
48 if (rank == -1) {
49 LOG(FATAL) << "Federated rank must be set.";
50 }
51 return new InMemoryCommunicator(world_size, rank);
52 }
53
54 InMemoryCommunicator(int world_size, int rank) : Communicator(world_size, rank) {
55 handler_.Init(world_size, rank);
56 }
57
58 ~InMemoryCommunicator() override { handler_.Shutdown(sequence_number_++, GetRank()); }
59
60 bool IsDistributed() const override { return true; }
61 bool IsFederated() const override { return false; }
62
63 void AllGather(void* in_out, std::size_t size) override {
64 std::string output;
65 handler_.Allgather(static_cast<const char*>(in_out), size, &output, sequence_number_++,
66 GetRank());
67 output.copy(static_cast<char*>(in_out), size);
68 }
69
70 void AllReduce(void* in_out, std::size_t size, DataType data_type, Operation operation) override {
71 auto const bytes = size * GetTypeSize(data_type);
72 std::string output;
73 handler_.Allreduce(static_cast<const char*>(in_out), bytes, &output, sequence_number_++,
74 GetRank(), data_type, operation);
75 output.copy(static_cast<char*>(in_out), bytes);
76 }
77
78 void Broadcast(void* in_out, std::size_t size, int root) override {
79 std::string output;
80 handler_.Broadcast(static_cast<const char*>(in_out), size, &output, sequence_number_++,
81 GetRank(), root);
82 output.copy(static_cast<char*>(in_out), size);
83 }
84
85 std::string GetProcessorName() override { return "rank" + std::to_string(GetRank()); }
86
87 void Print(const std::string& message) override { LOG(CONSOLE) << message; }
88
89 protected:
90 void Shutdown() override {}
91
92 private:
93 static InMemoryHandler handler_;
94 uint64_t sequence_number_{};
95};
96
97} // namespace collective
98} // namespace xgboost
Data structure representing a JSON value.
Definition json.h:357
A communicator class that handles collective communication.
Definition communicator.h:86
int GetRank() const
Get the rank of the current processes.
Definition communicator.h:117
An in-memory communicator, useful for testing.
Definition in_memory_communicator.h:18
bool IsDistributed() const override
Whether the communicator is running in distributed mode.
Definition in_memory_communicator.h:60
static Communicator * Create(Json const &config)
Create a new communicator based on JSON configuration.
Definition in_memory_communicator.h:25
void Broadcast(void *in_out, std::size_t size, int root) override
Broadcasts a message from the process with rank root to all other processes of the group.
Definition in_memory_communicator.h:78
bool IsFederated() const override
Whether the communicator is running in federated mode.
Definition in_memory_communicator.h:61
void Shutdown() override
Shuts down the communicator.
Definition in_memory_communicator.h:90
void Print(const std::string &message) override
Prints the message.
Definition in_memory_communicator.h:87
void AllReduce(void *in_out, std::size_t size, DataType data_type, Operation operation) override
Combines values from all processes and distributes the result back to all processes.
Definition in_memory_communicator.h:70
void AllGather(void *in_out, std::size_t size) override
Gathers data from all processes and distributes it to all processes.
Definition in_memory_communicator.h:63
std::string GetProcessorName() override
Gets the name of the processor.
Definition in_memory_communicator.h:85
Handles collective communication primitives in memory.
Definition in_memory_handler.h:18
void Broadcast(char const *input, std::size_t bytes, std::string *output, std::size_t sequence_number, int rank, int root)
Perform broadcast.
Definition in_memory_handler.cc:207
void Allgather(char const *input, std::size_t bytes, std::string *output, std::size_t sequence_number, int rank)
Perform allgather.
Definition in_memory_handler.cc:196
void Init(int world_size, int rank)
Initialize the handler with the world size and rank.
Definition in_memory_handler.cc:171
void Shutdown(uint64_t sequence_number, int rank)
Shut down the handler.
Definition in_memory_handler.cc:181
void Allreduce(char const *input, std::size_t bytes, std::string *output, std::size_t sequence_number, int rank, DataType data_type, Operation op)
Perform allreduce.
Definition in_memory_handler.cc:201
Operation
Defines the reduction operation.
Definition communicator.h:61
DataType
Defines the integral and floating-point data types.
Definition communicator.h:15
std::size_t GetTypeSize(DataType data_type)
Get the size of the data type.
Definition communicator.h:27
Namespace of XGBoost.
Definition base.h:90