Medial Code Documentation
Loading...
Searching...
No Matches
federated_client.h
1
4#pragma once
#include <federated.grpc.pb.h>
#include <federated.pb.h>
#include <grpcpp/grpcpp.h>

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>
13
14namespace xgboost {
15namespace federated {
16
21 public:
22 FederatedClient(std::string const &server_address, int rank, std::string const &server_cert,
23 std::string const &client_key, std::string const &client_cert)
24 : stub_{[&] {
25 grpc::SslCredentialsOptions options;
26 options.pem_root_certs = server_cert;
27 options.pem_private_key = client_key;
28 options.pem_cert_chain = client_cert;
29 grpc::ChannelArguments args;
30 args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
31 auto channel =
32 grpc::CreateCustomChannel(server_address, grpc::SslCredentials(options), args);
33 channel->WaitForConnected(
34 gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN)));
35 return Federated::NewStub(channel);
36 }()},
37 rank_{rank} {}
38
40 FederatedClient(std::string const &server_address, int rank)
41 : stub_{[&] {
42 grpc::ChannelArguments args;
43 args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
44 return Federated::NewStub(
45 grpc::CreateCustomChannel(server_address, grpc::InsecureChannelCredentials(), args));
46 }()},
47 rank_{rank} {}
48
49 std::string Allgather(std::string const &send_buffer) {
50 AllgatherRequest request;
51 request.set_sequence_number(sequence_number_++);
52 request.set_rank(rank_);
53 request.set_send_buffer(send_buffer);
54
55 AllgatherReply reply;
56 grpc::ClientContext context;
57 context.set_wait_for_ready(true);
58 grpc::Status status = stub_->Allgather(&context, request, &reply);
59
60 if (status.ok()) {
61 return reply.receive_buffer();
62 } else {
63 std::cout << status.error_code() << ": " << status.error_message() << '\n';
64 throw std::runtime_error("Allgather RPC failed");
65 }
66 }
67
68 std::string Allreduce(std::string const &send_buffer, DataType data_type,
69 ReduceOperation reduce_operation) {
70 AllreduceRequest request;
71 request.set_sequence_number(sequence_number_++);
72 request.set_rank(rank_);
73 request.set_send_buffer(send_buffer);
74 request.set_data_type(data_type);
75 request.set_reduce_operation(reduce_operation);
76
77 AllreduceReply reply;
78 grpc::ClientContext context;
79 context.set_wait_for_ready(true);
80 grpc::Status status = stub_->Allreduce(&context, request, &reply);
81
82 if (status.ok()) {
83 return reply.receive_buffer();
84 } else {
85 std::cout << status.error_code() << ": " << status.error_message() << '\n';
86 throw std::runtime_error("Allreduce RPC failed");
87 }
88 }
89
90 std::string Broadcast(std::string const &send_buffer, int root) {
91 BroadcastRequest request;
92 request.set_sequence_number(sequence_number_++);
93 request.set_rank(rank_);
94 request.set_send_buffer(send_buffer);
95 request.set_root(root);
96
97 BroadcastReply reply;
98 grpc::ClientContext context;
99 context.set_wait_for_ready(true);
100 grpc::Status status = stub_->Broadcast(&context, request, &reply);
101
102 if (status.ok()) {
103 return reply.receive_buffer();
104 } else {
105 std::cout << status.error_code() << ": " << status.error_message() << '\n';
106 throw std::runtime_error("Broadcast RPC failed");
107 }
108 }
109
110 private:
111 std::unique_ptr<Federated::Stub> const stub_;
112 int const rank_;
113 uint64_t sequence_number_{};
114};
115
116} // namespace federated
117} // namespace xgboost
A wrapper around the gRPC client.
Definition federated_client.h:20
FederatedClient(std::string const &server_address, int rank)
Insecure client for connecting to localhost only.
Definition federated_client.h:40
void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin, size_t slice_end, size_t size_prev_slice)
Allgather function: each node has a segment of data in the ring of sendrecvbuf; the data provided by...
Definition engine.cc:85
void Broadcast(void *sendrecv_data, size_t size, int root)
broadcasts a memory region to every node from the root
Definition rabit-inl.h:148
namespace of xgboost
Definition base.h:90
DataType
data type accepted by the xgboost interface
Definition data.h:33