Medial Code Documentation
Loading...
Searching...
No Matches
federated_communicator.h
1
4#pragma once
5#include <xgboost/json.h>
6
7#include "../../src/c_api/c_api_utils.h"
8#include "../../src/collective/communicator.h"
9#include "../../src/common/io.h"
10#include "federated_client.h"
11
12namespace xgboost {
13namespace collective {
14
19 public:
25 static Communicator *Create(Json const &config) {
26 std::string server_address{};
27 int world_size{0};
28 int rank{-1};
29 std::string server_cert{};
30 std::string client_key{};
31 std::string client_cert{};
32
33 // Parse environment variables first.
34 auto *value = getenv("FEDERATED_SERVER_ADDRESS");
35 if (value != nullptr) {
36 server_address = value;
37 }
38 value = getenv("FEDERATED_WORLD_SIZE");
39 if (value != nullptr) {
40 world_size = std::stoi(value);
41 }
42 value = getenv("FEDERATED_RANK");
43 if (value != nullptr) {
44 rank = std::stoi(value);
45 }
46 value = getenv("FEDERATED_SERVER_CERT");
47 if (value != nullptr) {
48 server_cert = value;
49 }
50 value = getenv("FEDERATED_CLIENT_KEY");
51 if (value != nullptr) {
52 client_key = value;
53 }
54 value = getenv("FEDERATED_CLIENT_CERT");
55 if (value != nullptr) {
56 client_cert = value;
57 }
58
59 // Runtime configuration overrides, optional as users can specify them as env vars.
60 server_address = OptionalArg<String>(config, "federated_server_address", server_address);
61 world_size =
62 OptionalArg<Integer>(config, "federated_world_size", static_cast<Integer::Int>(world_size));
63 rank = OptionalArg<Integer>(config, "federated_rank", static_cast<Integer::Int>(rank));
64 server_cert = OptionalArg<String>(config, "federated_server_cert", server_cert);
65 client_key = OptionalArg<String>(config, "federated_client_key", client_key);
66 client_cert = OptionalArg<String>(config, "federated_client_cert", client_cert);
67
68 if (server_address.empty()) {
69 LOG(FATAL) << "Federated server address must be set.";
70 }
71 if (world_size == 0) {
72 LOG(FATAL) << "Federated world size must be set.";
73 }
74 if (rank == -1) {
75 LOG(FATAL) << "Federated rank must be set.";
76 }
77 return new FederatedCommunicator(world_size, rank, server_address, server_cert, client_key,
78 client_cert);
79 }
80
91 FederatedCommunicator(int world_size, int rank, std::string const &server_address,
92 std::string const &server_cert_path, std::string const &client_key_path,
93 std::string const &client_cert_path)
94 : Communicator{world_size, rank} {
95 if (server_cert_path.empty() || client_key_path.empty() || client_cert_path.empty()) {
96 client_.reset(new xgboost::federated::FederatedClient(server_address, rank));
97 } else {
99 server_address, rank, xgboost::common::ReadAll(server_cert_path),
100 xgboost::common::ReadAll(client_key_path), xgboost::common::ReadAll(client_cert_path)));
101 }
102 }
103
110 FederatedCommunicator(int world_size, int rank, std::string const &server_address)
111 : Communicator{world_size, rank} {
112 client_.reset(new xgboost::federated::FederatedClient(server_address, rank));
113 }
114
115 ~FederatedCommunicator() override { client_.reset(); }
116
121 bool IsDistributed() const override { return true; }
122
127 bool IsFederated() const override { return true; }
128
134 void AllGather(void *send_receive_buffer, std::size_t size) override {
135 std::string const send_buffer(reinterpret_cast<char const *>(send_receive_buffer), size);
136 auto const received = client_->Allgather(send_buffer);
137 received.copy(reinterpret_cast<char *>(send_receive_buffer), size);
138 }
139
147 void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
148 Operation op) override {
149 std::string const send_buffer(reinterpret_cast<char const *>(send_receive_buffer),
150 count * GetTypeSize(data_type));
151 auto const received =
152 client_->Allreduce(send_buffer, static_cast<xgboost::federated::DataType>(data_type),
153 static_cast<xgboost::federated::ReduceOperation>(op));
154 received.copy(reinterpret_cast<char *>(send_receive_buffer), count * GetTypeSize(data_type));
155 }
156
163 void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {
164 if (GetWorldSize() == 1) return;
165 if (GetRank() == root) {
166 std::string const send_buffer(reinterpret_cast<char const *>(send_receive_buffer), size);
167 client_->Broadcast(send_buffer, root);
168 } else {
169 auto const received = client_->Broadcast("", root);
170 received.copy(reinterpret_cast<char *>(send_receive_buffer), size);
171 }
172 }
173
178 std::string GetProcessorName() override { return "rank" + std::to_string(GetRank()); }
179
184 void Print(const std::string &message) override { LOG(CONSOLE) << message; }
185
186 protected:
187 void Shutdown() override {}
188
189 private:
190 std::unique_ptr<xgboost::federated::FederatedClient> client_{};
191};
192} // namespace collective
193} // namespace xgboost
Data structure representing JSON format.
Definition json.h:357
A communicator class that handles collective communication.
Definition communicator.h:86
int GetRank() const
Get the rank of the current processes.
Definition communicator.h:117
int GetWorldSize() const
Get the total number of processes.
Definition communicator.h:114
A Federated Learning communicator class that handles collective communication.
Definition federated_communicator.h:18
void AllGather(void *send_receive_buffer, std::size_t size) override
Perform in-place allgather.
Definition federated_communicator.h:134
bool IsFederated() const override
Get if the communicator is federated.
Definition federated_communicator.h:127
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, Operation op) override
Perform in-place allreduce.
Definition federated_communicator.h:147
void Print(const std::string &message) override
Print the message to the communicator.
Definition federated_communicator.h:184
FederatedCommunicator(int world_size, int rank, std::string const &server_address, std::string const &server_cert_path, std::string const &client_key_path, std::string const &client_cert_path)
Construct a new federated communicator.
Definition federated_communicator.h:91
static Communicator * Create(Json const &config)
Create a new communicator based on JSON configuration.
Definition federated_communicator.h:25
FederatedCommunicator(int world_size, int rank, std::string const &server_address)
Construct an insecure federated communicator without using SSL.
Definition federated_communicator.h:110
bool IsDistributed() const override
Get if the communicator is distributed.
Definition federated_communicator.h:121
void Shutdown() override
Shuts down the communicator.
Definition federated_communicator.h:187
std::string GetProcessorName() override
Get the name of the processor.
Definition federated_communicator.h:178
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override
Broadcast a memory region to all others from root.
Definition federated_communicator.h:163
A wrapper around the gRPC client.
Definition federated_client.h:20
Operation
Defines the reduction operation.
Definition communicator.h:61
DataType
Defines the integral and floating data types.
Definition communicator.h:15
std::size_t GetTypeSize(DataType data_type)
Get the size of the data type.
Definition communicator.h:27
std::string ReadAll(dmlc::Stream *fi, PeekableInStream *fp)
Read the whole buffer from a dmlc stream.
Definition io.h:110
Namespace of XGBoost.
Definition base.h:90