5#include <xgboost/json.h>
7#include "../../src/c_api/c_api_utils.h"
8#include "../../src/collective/communicator.h"
9#include "../../src/common/io.h"
10#include "federated_client.h"
26 std::string server_address{};
29 std::string server_cert{};
30 std::string client_key{};
31 std::string client_cert{};
34 auto *value = getenv(
"FEDERATED_SERVER_ADDRESS");
35 if (value !=
nullptr) {
36 server_address = value;
38 value = getenv(
"FEDERATED_WORLD_SIZE");
39 if (value !=
nullptr) {
40 world_size = std::stoi(value);
42 value = getenv(
"FEDERATED_RANK");
43 if (value !=
nullptr) {
44 rank = std::stoi(value);
46 value = getenv(
"FEDERATED_SERVER_CERT");
47 if (value !=
nullptr) {
50 value = getenv(
"FEDERATED_CLIENT_KEY");
51 if (value !=
nullptr) {
54 value = getenv(
"FEDERATED_CLIENT_CERT");
55 if (value !=
nullptr) {
60 server_address = OptionalArg<String>(config,
"federated_server_address", server_address);
62 OptionalArg<Integer>(config,
"federated_world_size",
static_cast<Integer::Int
>(world_size));
63 rank = OptionalArg<Integer>(config,
"federated_rank",
static_cast<Integer::Int
>(rank));
64 server_cert = OptionalArg<String>(config,
"federated_server_cert", server_cert);
65 client_key = OptionalArg<String>(config,
"federated_client_key", client_key);
66 client_cert = OptionalArg<String>(config,
"federated_client_cert", client_cert);
68 if (server_address.empty()) {
69 LOG(FATAL) <<
"Federated server address must be set.";
71 if (world_size == 0) {
72 LOG(FATAL) <<
"Federated world size must be set.";
75 LOG(FATAL) <<
"Federated rank must be set.";
92 std::string
const &server_cert_path, std::string
const &client_key_path,
93 std::string
const &client_cert_path)
95 if (server_cert_path.empty() || client_key_path.empty() || client_cert_path.empty()) {
134 void AllGather(
void *send_receive_buffer, std::size_t size)
override {
135 std::string
const send_buffer(
reinterpret_cast<char const *
>(send_receive_buffer), size);
136 auto const received = client_->Allgather(send_buffer);
137 received.copy(
reinterpret_cast<char *
>(send_receive_buffer), size);
149 std::string
const send_buffer(
reinterpret_cast<char const *
>(send_receive_buffer),
151 auto const received =
152 client_->Allreduce(send_buffer,
static_cast<xgboost::federated::DataType
>(data_type),
153 static_cast<xgboost::federated::ReduceOperation
>(op));
154 received.copy(
reinterpret_cast<char *
>(send_receive_buffer), count *
GetTypeSize(data_type));
163 void Broadcast(
void *send_receive_buffer, std::size_t size,
int root)
override {
166 std::string
const send_buffer(
reinterpret_cast<char const *
>(send_receive_buffer), size);
167 client_->Broadcast(send_buffer, root);
169 auto const received = client_->Broadcast(
"", root);
170 received.copy(
reinterpret_cast<char *
>(send_receive_buffer), size);
184 void Print(
const std::string &message)
override { LOG(CONSOLE) << message; }
190 std::unique_ptr<xgboost::federated::FederatedClient> client_{};
Data structure representing JSON format.
Definition json.h:357
A communicator class that handles collective communication.
Definition communicator.h:86
int GetRank() const
Get the rank of the current process.
Definition communicator.h:117
int GetWorldSize() const
Get the total number of processes.
Definition communicator.h:114
A Federated Learning communicator class that handles collective communication.
Definition federated_communicator.h:18
void AllGather(void *send_receive_buffer, std::size_t size) override
Perform in-place allgather.
Definition federated_communicator.h:134
bool IsFederated() const override
Check whether the communicator is federated.
Definition federated_communicator.h:127
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, Operation op) override
Perform in-place allreduce.
Definition federated_communicator.h:147
void Print(const std::string &message) override
Print the message to the communicator.
Definition federated_communicator.h:184
FederatedCommunicator(int world_size, int rank, std::string const &server_address, std::string const &server_cert_path, std::string const &client_key_path, std::string const &client_cert_path)
Construct a new federated communicator.
Definition federated_communicator.h:91
static Communicator * Create(Json const &config)
Create a new communicator based on JSON configuration.
Definition federated_communicator.h:25
FederatedCommunicator(int world_size, int rank, std::string const &server_address)
Construct an insecure federated communicator without using SSL.
Definition federated_communicator.h:110
bool IsDistributed() const override
Check whether the communicator is distributed.
Definition federated_communicator.h:121
void Shutdown() override
Shuts down the communicator.
Definition federated_communicator.h:187
std::string GetProcessorName() override
Get the name of the processor.
Definition federated_communicator.h:178
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override
Broadcast a memory region to all others from root.
Definition federated_communicator.h:163
A wrapper around the gRPC client.
Definition federated_client.h:20
Operation
Defines the reduction operation.
Definition communicator.h:61
DataType
Defines the integral and floating data types.
Definition communicator.h:15
std::size_t GetTypeSize(DataType data_type)
Get the size of the data type.
Definition communicator.h:27
std::string ReadAll(dmlc::Stream *fi, PeekableInStream *fp)
Read the whole buffer from dmlc stream.
Definition io.h:110
namespace of xgboost
Definition base.h:90