5#include <xgboost/json.h>
31 size =
sizeof(std::int8_t);
33 case DataType::kUInt8:
34 size =
sizeof(std::uint8_t);
36 case DataType::kInt32:
37 size =
sizeof(std::int32_t);
39 case DataType::kUInt32:
40 size =
sizeof(std::uint32_t);
42 case DataType::kInt64:
43 size =
sizeof(std::int64_t);
45 case DataType::kUInt64:
46 size =
sizeof(std::uint64_t);
48 case DataType::kFloat:
51 case DataType::kDouble:
52 size =
sizeof(double);
55 LOG(FATAL) <<
"Unknown data type.";
/** @brief Device-side communicator; forward-declared so that Communicator::GetDevice() can return a pointer to it. */
class DeviceCommunicator;

/**
 * @brief The communicator backends that can be selected at runtime.
 *
 * The backend name is read either from the "XGBOOST_COMMUNICATOR" environment variable or
 * from the "XGBOOST_COMMUNICATOR" / "xgboost_communicator" key of the runtime configuration.
 * It is then converted to this enum by Communicator::StringToType().
 * kUnknown is what the lookup helpers return when no backend name is specified.
 */
enum class CommunicatorType { kUnknown, kRabit, kFederated, kInMemory, kInMemoryNccl };
77 return _stricmp(s1, s2);
79 return strcasecmp(s1, s2);
93 static void Init(
Json const &config);
101#if defined(XGBOOST_USE_CUDA)
108 static DeviceCommunicator *GetDevice(
int device_ordinal);
134 virtual void AllGather(
void *send_receive_buffer, std::size_t size) = 0;
155 virtual void Broadcast(
void *send_receive_buffer, std::size_t size,
int root) = 0;
165 virtual void Print(std::string
const &message) = 0;
169 auto *env = std::getenv(
"XGBOOST_COMMUNICATOR");
170 if (env !=
nullptr) {
171 return StringToType(env);
173 return CommunicatorType::kUnknown;
179 auto const &j_upper = config[
"XGBOOST_COMMUNICATOR"];
180 if (IsA<String const>(j_upper)) {
181 return StringToType(get<String const>(j_upper).c_str());
183 auto const &j_lower = config[
"xgboost_communicator"];
184 if (IsA<String const>(j_lower)) {
185 return StringToType(get<String const>(j_lower).c_str());
187 return CommunicatorType::kUnknown;
197 Communicator(
int world_size,
int rank) : world_size_(world_size), rank_(rank) {
198 if (world_size < 1) {
199 LOG(FATAL) <<
"World size " << world_size <<
" is less than 1.";
202 LOG(FATAL) <<
"Rank " << rank <<
" is less than 0.";
204 if (rank >= world_size) {
205 LOG(FATAL) <<
"Rank " << rank <<
" is greater than world_size - 1: " << world_size - 1 <<
".";
215 static CommunicatorType StringToType(
char const *str) {
216 CommunicatorType result = CommunicatorType::kUnknown;
218 result = CommunicatorType::kRabit;
220 result = CommunicatorType::kFederated;
222 result = CommunicatorType::kInMemory;
224 result = CommunicatorType::kInMemoryNccl;
226 LOG(FATAL) <<
"Unknown communicator type " << str;
231 static thread_local std::unique_ptr<Communicator> communicator_;
232 static thread_local CommunicatorType type_;
233#if defined(XGBOOST_USE_CUDA)
234 static thread_local std::unique_ptr<DeviceCommunicator> device_communicator_;
237 int const world_size_;
Data structure representing JSON format.
Definition json.h:357
A communicator class that handles collective communication.
Definition communicator.h:86
Communicator(int world_size, int rank)
Construct a new communicator.
Definition communicator.h:197
static CommunicatorType GetTypeFromConfig(Json const &config)
Get the communicator type from runtime configuration.
Definition communicator.h:178
virtual std::string GetProcessorName()=0
Gets the name of the processor.
static void Init(Json const &config)
Initialize the communicator.
Definition communicator.cc:20
static Communicator * Get()
Get the communicator instance.
Definition communicator.h:99
int GetRank() const
Get the rank of the current process.
Definition communicator.h:117
virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, Operation op)=0
Combines values from all processes and distributes the result back to all processes.
virtual void AllGather(void *send_receive_buffer, std::size_t size)=0
Gathers data from all processes and distributes it to all processes.
virtual bool IsDistributed() const =0
Whether the communicator is running in distributed mode.
static void Finalize()
Finalize the communicator.
Definition communicator.cc:55
virtual void Broadcast(void *send_receive_buffer, std::size_t size, int root)=0
Broadcasts a message from the process with rank root to all other processes of the group.
virtual void Print(std::string const &message)=0
Prints the message.
int GetWorldSize() const
Get the total number of processes.
Definition communicator.h:114
static CommunicatorType GetTypeFromEnv()
Get the communicator type from environment variables.
Definition communicator.h:168
virtual bool IsFederated() const =0
Whether the communicator is running in federated mode.
virtual void Shutdown()=0
Shuts down the communicator.
Defines console logging options for XGBoost. Use them to enforce unified print behavior.
Operation
Defines the reduction operation.
Definition communicator.h:61
DataType
Defines the integral and floating-point data types.
Definition communicator.h:15
std::size_t GetTypeSize(DataType data_type)
Get the size of the data type.
Definition communicator.h:27
int CompareStringsCaseInsensitive(const char *s1, const char *s2)
Case-insensitive string comparison.
Definition communicator.h:75
Namespace of xgboost.
Definition base.h:90