Medial Code Documentation
Loading...
Searching...
No Matches
linkers.h
1#ifndef LIGHTGBM_NETWORK_LINKERS_H_
2#define LIGHTGBM_NETWORK_LINKERS_H_
3
4#include <LightGBM/network.h>
5#include <LightGBM/meta.h>
6#include <LightGBM/config.h>
7
8
9#include <algorithm>
10#include <chrono>
11#include <ctime>
12#include <thread>
13#include <vector>
14#include <string>
15#include <memory>
16
17#ifdef USE_SOCKET
18#include "socket_wrapper.hpp"
19#include <LightGBM/utils/common.h>
20#endif
21
22#ifdef USE_MPI
23#include <mpi.h>
24#define MPI_SAFE_CALL(mpi_return) CHECK((mpi_return) == MPI_SUCCESS)
25#endif
26
27namespace LightGBM {
28
34class Linkers {
35public:
36 Linkers() {
37 is_init_ = false;
38 }
43 explicit Linkers(Config config);
54 inline void Recv(int rank, char* data, int len) const;
55
56 inline void Recv(int rank, char* data, int64_t len) const;
57
64 inline void Send(int rank, char* data, int len) const;
65
66 inline void Send(int rank, char* data, int64_t len) const;
76 inline void SendRecv(int send_rank, char* send_data, int send_len,
77 int recv_rank, char* recv_data, int recv_len);
78
79 inline void SendRecv(int send_rank, char* send_data, int64_t send_len,
80 int recv_rank, char* recv_data, int64_t recv_len);
84 inline int rank();
88 inline int num_machines();
92 inline const BruckMap& bruck_map();
97
98 #ifdef USE_SOCKET
103 void TryBind(int port);
109 void SetLinker(int rank, const TcpSocket& socket);
114 void ListenThread(int incoming_cnt);
118 void Construct();
124 void ParseMachineList(const std::string& machines, const std::string& filename);
130 bool CheckLinker(int rank);
134 void PrintLinkers();
135
136 #endif // USE_SOCKET
137
138
139private:
141 int rank_;
143 int num_machines_;
145 BruckMap bruck_map_;
147 RecursiveHalvingMap recursive_halving_map_;
148
149 std::chrono::duration<double, std::milli> network_time_;
150
151 bool is_init_;
152
153 #ifdef USE_SOCKET
155 std::vector<std::string> client_ips_;
157 std::vector<int> client_ports_;
159 int socket_timeout_;
161 int local_listen_port_;
163 std::vector<std::unique_ptr<TcpSocket>> linkers_;
165 std::unique_ptr<TcpSocket> listener_;
166 #endif // USE_SOCKET
167};
168
169
170inline int Linkers::rank() {
171 return rank_;
172}
173
175 return num_machines_;
176}
177
179 return bruck_map_;
180}
181
183 return recursive_halving_map_;
184}
185
186inline void Linkers::Recv(int rank, char* data, int64_t len) const {
187 int64_t used = 0;
188 do {
189 int cur_size = static_cast<int>(std::min<int64_t>(len - used, INT32_MAX));
190 Recv(rank, data + used, cur_size);
191 used += cur_size;
192 } while (used < len);
193}
194
195inline void Linkers::Send(int rank, char* data, int64_t len) const {
196 int64_t used = 0;
197 do {
198 int cur_size = static_cast<int>(std::min<int64_t>(len - used, INT32_MAX));
199 Send(rank, data + used, cur_size);
200 used += cur_size;
201 } while (used < len);
202}
203
204inline void Linkers::SendRecv(int send_rank, char* send_data, int64_t send_len,
205 int recv_rank, char* recv_data, int64_t recv_len) {
206 auto start_time = std::chrono::high_resolution_clock::now();
207 std::thread send_worker(
208 [this, send_rank, send_data, send_len]() {
209 Send(send_rank, send_data, send_len);
210 });
211 Recv(recv_rank, recv_data, recv_len);
212 send_worker.join();
213 // wait for send complete
214 auto end_time = std::chrono::high_resolution_clock::now();
215 // output used time on each iteration
216 network_time_ += std::chrono::duration<double, std::milli>(end_time - start_time);
217}
218
219#ifdef USE_SOCKET
220
221inline void Linkers::Recv(int rank, char* data, int len) const {
222 int recv_cnt = 0;
223 while (recv_cnt < len) {
224 recv_cnt += linkers_[rank]->Recv(data + recv_cnt,
225 // len - recv_cnt
226 std::min(len - recv_cnt, SocketConfig::kMaxReceiveSize));
227 }
228}
229
230inline void Linkers::Send(int rank, char* data, int len) const {
231 if (len <= 0) {
232 return;
233 }
234 int send_cnt = 0;
235 while (send_cnt < len) {
236 send_cnt += linkers_[rank]->Send(data + send_cnt, len - send_cnt);
237 }
238}
239
240inline void Linkers::SendRecv(int send_rank, char* send_data, int send_len,
241 int recv_rank, char* recv_data, int recv_len) {
242 auto start_time = std::chrono::high_resolution_clock::now();
243 if (send_len < SocketConfig::kSocketBufferSize) {
244 // if buffer is enough, send will non-blocking
245 Send(send_rank, send_data, send_len);
246 Recv(recv_rank, recv_data, recv_len);
247 } else {
248 // if buffer is not enough, use another thread to send, since send will be blocking
249 std::thread send_worker(
250 [this, send_rank, send_data, send_len]() {
251 Send(send_rank, send_data, send_len);
252 });
253 Recv(recv_rank, recv_data, recv_len);
254 send_worker.join();
255 }
256 // wait for send complete
257 auto end_time = std::chrono::high_resolution_clock::now();
258 // output used time on each iteration
259 network_time_ += std::chrono::duration<double, std::milli>(end_time - start_time);
260}
261
262#endif // USE_SOCKET
263
264#ifdef USE_MPI
265
266inline void Linkers::Recv(int rank, char* data, int len) const {
267 MPI_Status status;
268 int read_cnt = 0;
269 while (read_cnt < len) {
270 MPI_SAFE_CALL(MPI_Recv(data + read_cnt, len - read_cnt, MPI_BYTE, rank, MPI_ANY_TAG, MPI_COMM_WORLD, &status));
271 int cur_cnt;
272 MPI_SAFE_CALL(MPI_Get_count(&status, MPI_BYTE, &cur_cnt));
273 read_cnt += cur_cnt;
274 }
275}
276
277inline void Linkers::Send(int rank, char* data, int len) const {
278 if (len <= 0) {
279 return;
280 }
281 MPI_Status status;
282 MPI_Request send_request;
283 MPI_SAFE_CALL(MPI_Isend(data, len, MPI_BYTE, rank, 0, MPI_COMM_WORLD, &send_request));
284 MPI_SAFE_CALL(MPI_Wait(&send_request, &status));
285}
286
287inline void Linkers::SendRecv(int send_rank, char* send_data, int send_len,
288 int recv_rank, char* recv_data, int recv_len) {
289 MPI_Request send_request;
290 // send first, non-blocking
291 MPI_SAFE_CALL(MPI_Isend(send_data, send_len, MPI_BYTE, send_rank, 0, MPI_COMM_WORLD, &send_request));
292 // then receive, blocking
293 MPI_Status status;
294 int read_cnt = 0;
295 while (read_cnt < recv_len) {
296 MPI_SAFE_CALL(MPI_Recv(recv_data + read_cnt, recv_len - read_cnt, MPI_BYTE, recv_rank, 0, MPI_COMM_WORLD, &status));
297 int cur_cnt;
298 MPI_SAFE_CALL(MPI_Get_count(&status, MPI_BYTE, &cur_cnt));
299 read_cnt += cur_cnt;
300 }
301 // wait for send complete
302 MPI_SAFE_CALL(MPI_Wait(&send_request, &status));
303}
304
305#endif // USE_MPI
306} // namespace LightGBM
307#endif // LightGBM_NETWORK_LINKERS_H_
The network structure for all_gather.
Definition network.h:19
A basic network communication wrapper. Wraps low-level communication methods,...
Definition linkers.h:34
void SendRecv(int send_rank, char *send_data, int send_len, int recv_rank, char *recv_data, int recv_len)
Send and receive at the same time; blocking.
Linkers(Config config)
Constructor.
~Linkers()
Destructor.
void Send(int rank, char *data, int len) const
Send data, blocking.
const RecursiveHalvingMap & recursive_halving_map()
Get Recursive Halving map of this network.
Definition linkers.h:182
void Recv(int rank, char *data, int len) const
Recv data, blocking.
const BruckMap & bruck_map()
Get Bruck map of this network.
Definition linkers.h:178
int num_machines()
Get total number of machines.
Definition linkers.h:174
int rank()
Get rank of local machine.
Definition linkers.h:170
Network structure for recursive halving algorithm.
Definition network.h:53
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10
Definition config.h:27