Medial Code Documentation
Loading...
Searching...
No Matches
helpers.h
1
4#pragma once
5
6#include <dmlc/omp.h>
7#include <grpcpp/server_builder.h>
8#include <gtest/gtest.h>
9#include <xgboost/json.h>
10
11#include <random>
12#include <thread> // for thread, sleep_for
13
14#include "../../../plugin/federated/federated_server.h"
15#include "../../../src/collective/communicator-inl.h"
16#include "../../../src/common/threading_utils.h"
17
18namespace xgboost {
19
21 std::string server_address_;
22 std::unique_ptr<std::thread> server_thread_;
23 std::unique_ptr<grpc::Server> server_;
24
25 public:
26 explicit ServerForTest(std::int32_t world_size) {
27 server_thread_.reset(new std::thread([this, world_size] {
28 grpc::ServerBuilder builder;
29 xgboost::federated::FederatedService service{world_size};
30 int selected_port;
31 builder.AddListeningPort("localhost:0", grpc::InsecureServerCredentials(), &selected_port);
32 builder.RegisterService(&service);
33 server_ = builder.BuildAndStart();
34 server_address_ = std::string("localhost:") + std::to_string(selected_port);
35 server_->Wait();
36 }));
37 }
38
40 using namespace std::chrono_literals;
41 while (!server_) {
42 std::this_thread::sleep_for(100ms);
43 }
44 server_->Shutdown();
45 while (!server_thread_) {
46 std::this_thread::sleep_for(100ms);
47 }
48 server_thread_->join();
49 }
50
51 auto Address() const {
52 using namespace std::chrono_literals;
53 while (server_address_.empty()) {
54 std::this_thread::sleep_for(100ms);
55 }
56 return server_address_;
57 }
58};
59
60class BaseFederatedTest : public ::testing::Test {
61 protected:
62 void SetUp() override { server_ = std::make_unique<ServerForTest>(kWorldSize); }
63
64 void TearDown() override { server_.reset(nullptr); }
65
66 static int constexpr kWorldSize{2};
67 std::unique_ptr<ServerForTest> server_;
68};
69
70template <typename Function, typename... Args>
71void RunWithFederatedCommunicator(int32_t world_size, std::string const& server_address,
72 Function&& function, Args&&... args) {
73 auto run = [&](auto rank) {
74 Json config{JsonObject()};
75 config["xgboost_communicator"] = String("federated");
76 config["federated_server_address"] = String(server_address);
77 config["federated_world_size"] = world_size;
78 config["federated_rank"] = rank;
80
81 std::forward<Function>(function)(std::forward<Args>(args)...);
82
84 };
85#if defined(_OPENMP)
86 common::ParallelFor(world_size, world_size, run);
87#else
88 std::vector<std::thread> threads;
89 for (auto rank = 0; rank < world_size; rank++) {
90 threads.emplace_back(run, rank);
91 }
92 for (auto& thread : threads) {
93 thread.join();
94 }
95#endif
96}
97
98} // namespace xgboost
Definition helpers.h:60
Definition json.h:190
Data structure representing JSON format.
Definition json.h:357
Definition helpers.h:20
Definition federated_server.h:13
void Init(Json const &config)
Initialize the collective communicator.
Definition communicator-inl.h:60
void Finalize()
Finalize the collective communicator.
Definition communicator-inl.h:69
namespace of xgboost
Definition base.h:90
header to handle OpenMP compatibility issues