1#ifndef LIGHTGBM_NETWORK_LINKERS_H_
2#define LIGHTGBM_NETWORK_LINKERS_H_
4#include <LightGBM/network.h>
5#include <LightGBM/meta.h>
6#include <LightGBM/config.h>
18#include "socket_wrapper.hpp"
19#include <LightGBM/utils/common.h>
/*!
* \brief Run an MPI call and CHECK that its return code equals MPI_SUCCESS.
*        The argument appears exactly once in the expansion.
* \param mpi_call_result Expression producing the MPI return code
*/
#define MPI_SAFE_CALL(mpi_call_result) \
  CHECK((mpi_call_result) == MPI_SUCCESS)
54 inline void Recv(
int rank,
char* data,
int len)
const;
56 inline void Recv(
int rank,
char* data, int64_t len)
const;
64 inline void Send(
int rank,
char* data,
int len)
const;
66 inline void Send(
int rank,
char* data, int64_t len)
const;
/*!
* \brief Send to one machine while receiving from another; blocking.
*        In the socket implementation the elapsed time is added to
*        network_time_. Sends shorter than SocketConfig::kSocketBufferSize
*        are issued before the receive; longer ones run on a helper thread
*        so that they overlap the receive.
* \param send_rank Rank of the machine to send to
* \param send_data Buffer holding the outgoing bytes
* \param send_len Number of bytes to send
* \param recv_rank Rank of the machine to receive from
* \param recv_data Destination buffer with room for at least recv_len bytes
* \param recv_len Number of bytes to receive
*/
inline void SendRecv(int send_rank, char* send_data, int send_len,
                     int recv_rank, char* recv_data, int recv_len);
/*!
* \brief 64-bit-length variant of SendRecv.
*        The send runs on a helper std::thread while the calling thread
*        performs the receive; the elapsed time is added to network_time_.
* \param send_rank Rank of the machine to send to
* \param send_data Buffer holding the outgoing bytes
* \param send_len Number of bytes to send (64-bit)
* \param recv_rank Rank of the machine to receive from
* \param recv_data Destination buffer with room for at least recv_len bytes
* \param recv_len Number of bytes to receive (64-bit)
*/
inline void SendRecv(int send_rank, char* send_data, int64_t send_len,
                     int recv_rank, char* recv_data, int64_t recv_len);
/*!
* \brief Try to bind the local listening socket to port.
*        NOTE(review): behavior inferred from the name; presumably acts on
*        listener_ — confirm against the definition.
* \param port Port number to bind
*/
void TryBind(int port);
109 void SetLinker(
int rank,
const TcpSocket& socket);
/*!
* \brief Listener loop for inbound connections.
*        NOTE(review): presumably accepts incoming_cnt connections on
*        listener_ and registers each one — confirm against the definition.
* \param incoming_cnt Expected number of inbound connections (inferred from
*        the name — confirm)
*/
void ListenThread(int incoming_cnt);
/*!
* \brief Parse the list of machines in the cluster.
*        NOTE(review): presumably takes the list either from machines or from
*        the file at filename and fills client_ips_ / client_ports_ — confirm.
* \param machines Machine list given as a string
* \param filename Path of a file that may contain the machine list
*/
void ParseMachineList(const std::string& machines, const std::string& filename);
/*!
* \brief Check the connection to machine rank.
*        NOTE(review): presumably returns true when a usable socket exists
*        for rank — confirm against the definition.
* \param rank Rank of the remote machine
* \return Outcome of the check
*/
bool CheckLinker(int rank);
149 std::chrono::duration<double, std::milli> network_time_;
155 std::vector<std::string> client_ips_;
157 std::vector<int> client_ports_;
161 int local_listen_port_;
163 std::vector<std::unique_ptr<TcpSocket>> linkers_;
165 std::unique_ptr<TcpSocket> listener_;
175 return num_machines_;
183 return recursive_halving_map_;
186inline void Linkers::Recv(
int rank,
char* data, int64_t len)
const {
189 int cur_size =
static_cast<int>(std::min<int64_t>(len - used, INT32_MAX));
192 }
while (used < len);
195inline void Linkers::Send(
int rank,
char* data, int64_t len)
const {
198 int cur_size =
static_cast<int>(std::min<int64_t>(len - used, INT32_MAX));
201 }
while (used < len);
205 int recv_rank,
char* recv_data, int64_t recv_len) {
206 auto start_time = std::chrono::high_resolution_clock::now();
207 std::thread send_worker(
208 [
this, send_rank, send_data, send_len]() {
209 Send(send_rank, send_data, send_len);
211 Recv(recv_rank, recv_data, recv_len);
214 auto end_time = std::chrono::high_resolution_clock::now();
216 network_time_ += std::chrono::duration<double, std::milli>(end_time - start_time);
221inline void Linkers::Recv(
int rank,
char* data,
int len)
const {
223 while (recv_cnt < len) {
224 recv_cnt += linkers_[
rank]->Recv(data + recv_cnt,
226 std::min(len - recv_cnt, SocketConfig::kMaxReceiveSize));
230inline void Linkers::Send(
int rank,
char* data,
int len)
const {
235 while (send_cnt < len) {
236 send_cnt += linkers_[
rank]->Send(data + send_cnt, len - send_cnt);
241 int recv_rank,
char* recv_data,
int recv_len) {
242 auto start_time = std::chrono::high_resolution_clock::now();
243 if (send_len < SocketConfig::kSocketBufferSize) {
245 Send(send_rank, send_data, send_len);
246 Recv(recv_rank, recv_data, recv_len);
249 std::thread send_worker(
250 [
this, send_rank, send_data, send_len]() {
251 Send(send_rank, send_data, send_len);
253 Recv(recv_rank, recv_data, recv_len);
257 auto end_time = std::chrono::high_resolution_clock::now();
259 network_time_ += std::chrono::duration<double, std::milli>(end_time - start_time);
266inline void Linkers::Recv(
int rank,
char* data,
int len)
const {
269 while (read_cnt < len) {
270 MPI_SAFE_CALL(MPI_Recv(data + read_cnt, len - read_cnt, MPI_BYTE,
rank, MPI_ANY_TAG, MPI_COMM_WORLD, &status));
272 MPI_SAFE_CALL(MPI_Get_count(&status, MPI_BYTE, &cur_cnt));
277inline void Linkers::Send(
int rank,
char* data,
int len)
const {
282 MPI_Request send_request;
283 MPI_SAFE_CALL(MPI_Isend(data, len, MPI_BYTE,
rank, 0, MPI_COMM_WORLD, &send_request));
284 MPI_SAFE_CALL(MPI_Wait(&send_request, &status));
288 int recv_rank,
char* recv_data,
int recv_len) {
289 MPI_Request send_request;
291 MPI_SAFE_CALL(MPI_Isend(send_data, send_len, MPI_BYTE, send_rank, 0, MPI_COMM_WORLD, &send_request));
295 while (read_cnt < recv_len) {
296 MPI_SAFE_CALL(MPI_Recv(recv_data + read_cnt, recv_len - read_cnt, MPI_BYTE, recv_rank, 0, MPI_COMM_WORLD, &status));
298 MPI_SAFE_CALL(MPI_Get_count(&status, MPI_BYTE, &cur_cnt));
302 MPI_SAFE_CALL(MPI_Wait(&send_request, &status));
The network structure for all_gather.
Definition network.h:19
A basic network communication wrapper. Wraps low-level communication methods,...
Definition linkers.h:34
void SendRecv(int send_rank, char *send_data, int send_len, int recv_rank, char *recv_data, int recv_len)
Send and receive at the same time; blocking.
Linkers(Config config)
Constructor.
void Send(int rank, char *data, int len) const
Send data, blocking.
const RecursiveHalvingMap & recursive_halving_map()
Get the recursive halving map of this network.
Definition linkers.h:182
void Recv(int rank, char *data, int len) const
Recv data, blocking.
const BruckMap & bruck_map()
Get the Bruck map of this network.
Definition linkers.h:178
int num_machines()
Get total number of machines.
Definition linkers.h:174
int rank()
Get rank of local machine.
Definition linkers.h:170
Network structure for the recursive halving algorithm.
Definition network.h:53
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10