// Constructor taking the server address, this client's rank, and three strings
// that are fed to the gRPC SSL options below (server certificate, client
// private key, client certificate).
// NOTE(review): the embedded source numbering skips lines inside this block
// (24 and 31), so whatever sits between the parameter list and the statements
// below, and the declaration of `channel`, are not visible in this excerpt.
22 FederatedClient(std::string
const &server_address,
int rank, std::string
const &server_cert,
23 std::string
const &client_key, std::string
const &client_cert)
// SSL credential options: server_cert becomes the root certificates, client_key
// the private key, and client_cert the certificate chain.
25 grpc::SslCredentialsOptions options;
26 options.pem_root_certs = server_cert;
27 options.pem_private_key = client_key;
28 options.pem_cert_chain = client_cert;
// Raise the maximum receive message size to the largest int value.
29 grpc::ChannelArguments args;
30 args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
// Channel to server_address using the SSL credentials and channel arguments above.
32 grpc::CreateCustomChannel(server_address, grpc::SslCredentials(options), args);
// Wait for the channel to connect, with a deadline of "now + 60 seconds".
// NOTE(review): the result of WaitForConnected appears to be discarded in the
// visible code, so a timeout would not be reported here - confirm intent.
33 channel->WaitForConnected(
34 gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN)));
// Hand back a Federated stub bound to the channel.
35 return Federated::NewStub(channel);
// Creates a Federated stub over a channel that uses
// grpc::InsecureChannelCredentials(), with the maximum receive message size
// raised to the largest int value.
// NOTE(review): the signature that encloses these statements is not visible in
// this excerpt (the embedded numbering jumps from 35 to 42) - presumably a
// second FederatedClient constructor without certificates; TODO confirm.
42 grpc::ChannelArguments args;
43 args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
44 return Federated::NewStub(
45 grpc::CreateCustomChannel(server_address, grpc::InsecureChannelCredentials(), args));
// Sends send_buffer to the server through the Allgather RPC and returns the
// reply's receive_buffer as a std::string.
// Throws std::runtime_error("Allgather RPC failed") after printing the status
// code and message to std::cout.
// NOTE(review): the declaration of `reply` and the condition that selects
// between the return and the throw are not visible in this excerpt
// (presumably a check on `status`; TODO confirm).
49 std::string
Allgather(std::string
const &send_buffer) {
50 AllgatherRequest request;
// sequence_number_ is post-incremented, so each request made by this client
// carries the value from before the increment.
51 request.set_sequence_number(sequence_number_++);
52 request.set_rank(rank_);
53 request.set_send_buffer(send_buffer);
56 grpc::ClientContext context;
// Per the gRPC documentation, wait_for_ready makes the call wait for the
// channel to become ready instead of failing immediately.
57 context.set_wait_for_ready(
true);
// Synchronous call: the returned Status is available once the call completes.
58 grpc::Status status = stub_->Allgather(&context, request, &reply);
61 return reply.receive_buffer();
// Error reporting goes to standard output, followed by a generic exception.
63 std::cout << status.error_code() <<
": " << status.error_message() <<
'\n';
64 throw std::runtime_error(
"Allgather RPC failed");
// Sends send_buffer, together with its data_type and the reduce_operation to
// apply, to the server through the Allreduce RPC and returns the reply's
// receive_buffer as a std::string.
// Throws std::runtime_error("Allreduce RPC failed") after printing the status
// code and message to std::cout.
// NOTE(review): the declaration of `reply` and the condition that selects
// between the return and the throw are not visible in this excerpt
// (presumably a check on `status`; TODO confirm).
68 std::string Allreduce(std::string
const &send_buffer,
DataType data_type,
69 ReduceOperation reduce_operation) {
70 AllreduceRequest request;
// sequence_number_ is post-incremented, so each request made by this client
// carries the value from before the increment.
71 request.set_sequence_number(sequence_number_++);
72 request.set_rank(rank_);
73 request.set_send_buffer(send_buffer);
74 request.set_data_type(data_type);
75 request.set_reduce_operation(reduce_operation);
78 grpc::ClientContext context;
// Per the gRPC documentation, wait_for_ready makes the call wait for the
// channel to become ready instead of failing immediately.
79 context.set_wait_for_ready(
true);
// Synchronous call: the returned Status is available once the call completes.
80 grpc::Status status = stub_->Allreduce(&context, request, &reply);
83 return reply.receive_buffer();
// Error reporting goes to standard output, followed by a generic exception.
85 std::cout << status.error_code() <<
": " << status.error_message() <<
'\n';
86 throw std::runtime_error(
"Allreduce RPC failed");
// Sends send_buffer and the root value to the server through the Broadcast RPC
// and returns the reply's receive_buffer as a std::string.
// Throws std::runtime_error("Broadcast RPC failed") after printing the status
// code and message to std::cout.
// NOTE(review): `root` is presumably the rank whose buffer is broadcast -
// verify against the server implementation. The declaration of `reply` and the
// condition that selects between the return and the throw are not visible in
// this excerpt (presumably a check on `status`; TODO confirm).
90 std::string
Broadcast(std::string
const &send_buffer,
int root) {
91 BroadcastRequest request;
// sequence_number_ is post-incremented, so each request made by this client
// carries the value from before the increment.
92 request.set_sequence_number(sequence_number_++);
93 request.set_rank(rank_);
94 request.set_send_buffer(send_buffer);
95 request.set_root(root);
98 grpc::ClientContext context;
// Per the gRPC documentation, wait_for_ready makes the call wait for the
// channel to become ready instead of failing immediately.
99 context.set_wait_for_ready(
true);
// Synchronous call: the returned Status is available once the call completes.
100 grpc::Status status = stub_->Broadcast(&context, request, &reply);
103 return reply.receive_buffer();
// Error reporting goes to standard output, followed by a generic exception.
105 std::cout << status.error_code() <<
": " << status.error_message() <<
'\n';
106 throw std::runtime_error(
"Broadcast RPC failed");
// gRPC stub through which the Allgather/Allreduce/Broadcast calls are issued.
// The unique_ptr itself is const, so it cannot be reseated after construction.
111 std::unique_ptr<Federated::Stub>
const stub_;
// NOTE(review): the declaration of rank_, which the RPC methods read, is not
// visible in this excerpt (the embedded numbering skips 112).
// Counter attached to every outgoing request and post-incremented per call;
// value-initialized, so it starts at 0.
113 uint64_t sequence_number_{};