Medial Code Documentation
Loading...
Searching...
No Matches
communicator.h
1
4#pragma once
#include <xgboost/json.h>
#include <xgboost/logging.h>

#include <cstddef>  // std::size_t
#include <cstdint>  // std::int8_t ... std::uint64_t used by GetTypeSize
#include <cstdlib>  // std::getenv
#include <cstring>  // _stricmp (MSVC)
#include <memory>
#include <string>

#if !defined(_MSC_VER)
#include <strings.h>  // POSIX strcasecmp
#endif  // !defined(_MSC_VER)
10
11namespace xgboost {
12namespace collective {
13
/**
 * @brief Defines the integral and floating-point data types.
 *
 * The explicit values are kept stable; each comment names the C++ type whose size
 * GetTypeSize() reports for that enumerator.
 */
enum class DataType {
  kInt8 = 0,    // std::int8_t
  kUInt8 = 1,   // std::uint8_t
  kInt32 = 2,   // std::int32_t
  kUInt32 = 3,  // std::uint32_t
  kInt64 = 4,   // std::int64_t
  kUInt64 = 5,  // std::uint64_t
  kFloat = 6,   // float
  kDouble = 7   // double
};
25
27inline std::size_t GetTypeSize(DataType data_type) {
28 std::size_t size{0};
29 switch (data_type) {
30 case DataType::kInt8:
31 size = sizeof(std::int8_t);
32 break;
33 case DataType::kUInt8:
34 size = sizeof(std::uint8_t);
35 break;
36 case DataType::kInt32:
37 size = sizeof(std::int32_t);
38 break;
39 case DataType::kUInt32:
40 size = sizeof(std::uint32_t);
41 break;
42 case DataType::kInt64:
43 size = sizeof(std::int64_t);
44 break;
45 case DataType::kUInt64:
46 size = sizeof(std::uint64_t);
47 break;
48 case DataType::kFloat:
49 size = sizeof(float);
50 break;
51 case DataType::kDouble:
52 size = sizeof(double);
53 break;
54 default:
55 LOG(FATAL) << "Unknown data type.";
56 }
57 return size;
58}
59
/**
 * @brief Defines the reduction operation applied by Communicator::AllReduce.
 *
 * Explicit values are kept stable.
 */
enum class Operation {
  kMax = 0,          // maximum
  kMin = 1,          // minimum
  kSum = 2,          // sum
  kBitwiseAND = 3,   // bitwise AND
  kBitwiseOR = 4,    // bitwise OR
  kBitwiseXOR = 5    // bitwise XOR
};
69
// Forward declaration: Communicator only refers to it by pointer (in CUDA builds).
class DeviceCommunicator;

/**
 * @brief The available communicator backends.
 *
 * Names are matched case-insensitively by Communicator's string parser. kUnknown is what the
 * environment/config lookups return when no communicator type is specified.
 */
enum class CommunicatorType {
  kUnknown,      // no communicator type specified
  kRabit,        // "rabit"
  kFederated,    // "federated"
  kInMemory,     // "in-memory"
  kInMemoryNccl  // "in-memory-nccl"
};
73
/**
 * @brief Case-insensitive string comparison.
 *
 * Forwards to `_stricmp` when compiled with MSVC and to POSIX `strcasecmp` otherwise.
 * Both arguments must be non-null, NUL-terminated strings; neither platform function
 * accepts nullptr.
 *
 * @return 0 when the strings are equal ignoring case; otherwise the non-zero result of the
 *         platform function.
 */
inline int CompareStringsCaseInsensitive(const char *s1, const char *s2) {
#if defined(_MSC_VER)
  int const result = _stricmp(s1, s2);
#else   // !defined(_MSC_VER)
  int const result = strcasecmp(s1, s2);
#endif  // defined(_MSC_VER)
  return result;
}
82
87 public:
93 static void Init(Json const &config);
94
96 static void Finalize();
97
99 static Communicator *Get() { return communicator_.get(); }
100
101#if defined(XGBOOST_USE_CUDA)
108 static DeviceCommunicator *GetDevice(int device_ordinal);
109#endif
110
111 virtual ~Communicator() = default;
112
114 int GetWorldSize() const { return world_size_; }
115
117 int GetRank() const { return rank_; }
118
120 virtual bool IsDistributed() const = 0;
121
123 virtual bool IsFederated() const = 0;
124
134 virtual void AllGather(void *send_receive_buffer, std::size_t size) = 0;
135
144 virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
145 Operation op) = 0;
146
155 virtual void Broadcast(void *send_receive_buffer, std::size_t size, int root) = 0;
156
160 virtual std::string GetProcessorName() = 0;
161
165 virtual void Print(std::string const &message) = 0;
166
168 static CommunicatorType GetTypeFromEnv() {
169 auto *env = std::getenv("XGBOOST_COMMUNICATOR");
170 if (env != nullptr) {
171 return StringToType(env);
172 } else {
173 return CommunicatorType::kUnknown;
174 }
175 }
176
178 static CommunicatorType GetTypeFromConfig(Json const &config) {
179 auto const &j_upper = config["XGBOOST_COMMUNICATOR"];
180 if (IsA<String const>(j_upper)) {
181 return StringToType(get<String const>(j_upper).c_str());
182 }
183 auto const &j_lower = config["xgboost_communicator"];
184 if (IsA<String const>(j_lower)) {
185 return StringToType(get<String const>(j_lower).c_str());
186 }
187 return CommunicatorType::kUnknown;
188 }
189
190 protected:
197 Communicator(int world_size, int rank) : world_size_(world_size), rank_(rank) {
198 if (world_size < 1) {
199 LOG(FATAL) << "World size " << world_size << " is less than 1.";
200 }
201 if (rank < 0) {
202 LOG(FATAL) << "Rank " << rank << " is less than 0.";
203 }
204 if (rank >= world_size) {
205 LOG(FATAL) << "Rank " << rank << " is greater than world_size - 1: " << world_size - 1 << ".";
206 }
207 }
208
212 virtual void Shutdown() = 0;
213
214 private:
215 static CommunicatorType StringToType(char const *str) {
216 CommunicatorType result = CommunicatorType::kUnknown;
217 if (!CompareStringsCaseInsensitive("rabit", str)) {
218 result = CommunicatorType::kRabit;
219 } else if (!CompareStringsCaseInsensitive("federated", str)) {
220 result = CommunicatorType::kFederated;
221 } else if (!CompareStringsCaseInsensitive("in-memory", str)) {
222 result = CommunicatorType::kInMemory;
223 } else if (!CompareStringsCaseInsensitive("in-memory-nccl", str)) {
224 result = CommunicatorType::kInMemoryNccl;
225 } else {
226 LOG(FATAL) << "Unknown communicator type " << str;
227 }
228 return result;
229 }
230
231 static thread_local std::unique_ptr<Communicator> communicator_;
232 static thread_local CommunicatorType type_;
233#if defined(XGBOOST_USE_CUDA)
234 static thread_local std::unique_ptr<DeviceCommunicator> device_communicator_;
235#endif
236
237 int const world_size_;
238 int const rank_;
239};
240
241} // namespace collective
242} // namespace xgboost
Data structure representing JSON format.
Definition json.h:357
A communicator class that handles collective communication.
Definition communicator.h:86
Communicator(int world_size, int rank)
Construct a new communicator.
Definition communicator.h:197
static CommunicatorType GetTypeFromConfig(Json const &config)
Get the communicator type from runtime configuration.
Definition communicator.h:178
virtual std::string GetProcessorName()=0
Gets the name of the processor.
static void Init(Json const &config)
Initialize the communicator.
Definition communicator.cc:20
static Communicator * Get()
Get the communicator instance.
Definition communicator.h:99
int GetRank() const
Get the rank of the current process.
Definition communicator.h:117
virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, Operation op)=0
Combines values from all processes and distributes the result back to all processes.
virtual void AllGather(void *send_receive_buffer, std::size_t size)=0
Gathers data from all processes and distributes it to all processes.
virtual bool IsDistributed() const =0
Whether the communicator is running in distributed mode.
static void Finalize()
Finalize the communicator.
Definition communicator.cc:55
virtual void Broadcast(void *send_receive_buffer, std::size_t size, int root)=0
Broadcasts a message from the process with rank root to all other processes of the group.
virtual void Print(std::string const &message)=0
Prints the message.
int GetWorldSize() const
Get the total number of processes.
Definition communicator.h:114
static CommunicatorType GetTypeFromEnv()
Get the communicator type from environment variables.
Definition communicator.h:168
virtual bool IsFederated() const =0
Whether the communicator is running in federated mode.
virtual void Shutdown()=0
Shuts down the communicator.
Defines console logging options for XGBoost; use it to enforce unified print behavior.
Operation
Defines the reduction operation.
Definition communicator.h:61
DataType
Defines the integral and floating-point data types.
Definition communicator.h:15
std::size_t GetTypeSize(DataType data_type)
Get the size of the data type.
Definition communicator.h:27
int CompareStringsCaseInsensitive(const char *s1, const char *s2)
Case-insensitive string comparison.
Definition communicator.h:75
Namespace of XGBoost.
Definition base.h:90