Medial Code Documentation
Loading...
Searching...
No Matches
network.h
1#ifndef LIGHTGBM_NETWORK_H_
2#define LIGHTGBM_NETWORK_H_
3
4#include <LightGBM/utils/log.h>
5
6#include <LightGBM/meta.h>
7#include <LightGBM/config.h>
8
9#include <functional>
10#include <vector>
11#include <memory>
12
13namespace LightGBM {
14
/*! \brief Forward declaration of Linkers; in this header it is only used as the pointee of a std::unique_ptr member of Network. */
16class Linkers;
17
/*! \brief The network structure for all_gather. */
19class BruckMap {
20public:
 /*! \brief The communication times for one all gather operation. */
22 int k;
 /*! \brief in_ranks[i] means the incoming rank on the i-th communication. */
24 std::vector<int> in_ranks;
 /*! \brief out_ranks[i] means the out rank on the i-th communication. */
26 std::vector<int> out_ranks;
 /*! \brief Default constructor (declared only; the body is not part of this header). */
27 BruckMap();
 /*!
  * \brief Constructor from a single int; marked explicit so an int never converts implicitly to a BruckMap.
  * NOTE(review): presumably n sizes in_ranks/out_ranks -- confirm against the definition.
  */
28 explicit BruckMap(int n);
 /*!
  * \brief Create the object of bruck map.
  * \param rank NOTE(review): presumably the rank of this machine -- confirm against linker_topo.cpp
  * \param num_machines NOTE(review): presumably the total number of machines -- confirm against linker_topo.cpp
  * \return The constructed BruckMap, returned by value
  */
35 static BruckMap Construct(int rank, int num_machines);
36};
37
// NOTE(review): the opening line of this enum is not present in this listing;
// only the enumerators and the closing brace are visible here.
47 Normal, // normal node: its group has only 1 machine
48 GroupLeader, // leader of a group when the number of machines in this group is 2
49 Other // non-leader machine in a group
50};
51
// NOTE(review): the class header line for this block is not present in this listing;
// the members below belong to RecursiveHalvingMap (see the constructor and Construct()).
54public:
 /*! \brief Communication times for one recursive halving algorithm. */
56 int k;
 // NOTE(review): undocumented in this listing; presumably true when the number of machines is a power of 2 -- confirm.
59 bool is_power_of_2;
 // NOTE(review): undocumented in this listing; presumably the rank of a paired machine -- confirm against linker_topo.cpp.
60 int neighbor;
 /*! \brief ranks[i] means the machine that this node communicates with on the i-th communication. */
62 std::vector<int> ranks;
 /*! \brief send_block_start[i] means the send block start index at the i-th communication. */
64 std::vector<int> send_block_start;
 /*! \brief send_block_len[i] means the send block size at the i-th communication. */
66 std::vector<int> send_block_len;
 /*! \brief recv_block_start[i] means the recv block start index at the i-th communication. */
68 std::vector<int> recv_block_start;
 /*! \brief recv_block_len[i] means the recv block size at the i-th communication. */
70 std::vector<int> recv_block_len;
71

73
 /*!
  * \brief Constructor taking the communication count, the node type and the power-of-2 flag.
  * NOTE(review): presumably stores the arguments into the same-named members -- confirm against the definition.
  */
74 RecursiveHalvingMap(int k, RecursiveHalvingNodeType _type, bool _is_power_of_2);
75
 /*!
  * \brief Create the object of recursive halving map.
  * \param rank NOTE(review): presumably the rank of this machine -- confirm against linker_topo.cpp
  * \param num_machines NOTE(review): presumably the total number of machines -- confirm against linker_topo.cpp
  * \return The constructed RecursiveHalvingMap, returned by value
  */
82 static RecursiveHalvingMap Construct(int rank, int num_machines);
83};
84
/*! \brief A static class that contains some collective communication algorithm. */
86class Network {
87public:
 /*!
  * \brief Initialize.
  * \param config Configuration object, taken by value
  */
92 static void Init(Config config);
 /*!
  * \brief Initialize from explicit values and externally supplied collective functions.
  * NOTE(review): presumably the two function arguments are stored in reduce_scatter_ext_fun_ / allgather_ext_fun_ -- confirm in network.cpp.
  */
96 static void Init(int num_machines, int rank, ReduceScatterFunction reduce_scatter_ext_fun, AllgatherFunction allgather_ext_fun);
 /*! \brief Free this static class. */
98 static void Dispose();
 /*! \brief Get rank of this machine. */
100 static inline int rank();
 /*! \brief Get total number of machines. */
102 static inline int num_machines();
103
 /*!
  * \brief Perform all_reduce. If the data size is small, AllreduceByAllGather is performed instead
  *        (the remainder of the description is truncated in this listing).
  * \param input Input data, as raw bytes
  * \param input_size Size of the input (the template helpers below pass a byte count)
  * \param type_size Size in bytes of one element
  * \param output Output buffer
  * \param reducer Function used to combine data
  */
113 static void Allreduce(char* input, comm_size_t input_size, int type_size,
114 char* output, const ReduceFunction& reducer);
115
 /*!
  * \brief Perform all_reduce by using all_gather. It can be used to reduce communication time when data is small.
  *        Takes the same parameters as Allreduce.
  */
124 static void AllreduceByAllGather(char* input, comm_size_t input_size, int type_size, char* output,
125 const ReduceFunction& reducer);
126
 /*!
  * \brief Perform all_gather by using the bruck algorithm. Communication times is O(log(n))
  *        (the remainder of the description is truncated in this listing).
  */
135 static void Allgather(char* input, comm_size_t send_size, char* output);
136
 /*!
  * \brief All_gather overload taking explicit block layout arrays.
  * NOTE(review): presumably block_start[i] / block_len[i] give the offset and length of the i-th block and all_size the total size -- confirm in network.cpp.
  */
147 static void Allgather(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);
148
 /*!
  * \brief Perform reduce scatter by using the recursive halving algorithm. Communication times is O(log(n))
  *        (the remainder of the description is truncated in this listing).
  */
161 static void ReduceScatter(char* input, comm_size_t input_size, int type_size,
162 const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size,
163 const ReduceFunction& reducer);
164
 /*!
  * \brief Run Allreduce on a single value of type T with a "keep the smaller" reducer and return the result.
  * \param local The local value; passed to Allreduce as the input buffer
  * \return The reduced value (starts as a copy of local and is used as the Allreduce output buffer)
  * NOTE(review): std::memcpy is used below, but <cstring> is not among the includes visible in this listing -- presumably included transitively; confirm.
  */
165 template<class T>
166 static T GlobalSyncUpByMin(T& local) {
167 T global = local;
168 Allreduce(reinterpret_cast<char*>(&local),
 // input_size and type_size are both sizeof(local): exactly one element of T is reduced
169 sizeof(local), sizeof(local),
170 reinterpret_cast<char*>(&global),
 // reducer: walk len bytes in steps of type_size, overwriting dst with src wherever src < dst
171 [] (const char* src, char* dst, int type_size, comm_size_t len) {
172 comm_size_t used_size = 0;
173 const T *p1;
174 T *p2;
175 while (used_size < len) {
176 p1 = reinterpret_cast<const T *>(src);
177 p2 = reinterpret_cast<T *>(dst);
178 if (*p1 < *p2) {
179 std::memcpy(dst, src, type_size);
180 }
181 src += type_size;
182 dst += type_size;
183 used_size += type_size;
184 }
185 });
186 return global;
187 }
188
 /*!
  * \brief Run Allreduce on a single value of type T with a "keep the larger" reducer and return the result.
  * \param local The local value; passed to Allreduce as the input buffer
  * \return The reduced value (starts as a copy of local and is used as the Allreduce output buffer)
  */
189 template<class T>
190 static T GlobalSyncUpByMax(T& local) {
191 T global = local;
192 Allreduce(reinterpret_cast<char*>(&local),
193 sizeof(local), sizeof(local),
194 reinterpret_cast<char*>(&global),
 // reducer: same walk as in GlobalSyncUpByMin, but dst is overwritten wherever src > dst
195 [] (const char* src, char* dst, int type_size, comm_size_t len) {
196 comm_size_t used_size = 0;
197 const T *p1;
198 T *p2;
199 while (used_size < len) {
200 p1 = reinterpret_cast<const T *>(src);
201 p2 = reinterpret_cast<T *>(dst);
202 if (*p1 > *p2) {
203 std::memcpy(dst, src, type_size);
204 }
205 src += type_size;
206 dst += type_size;
207 used_size += type_size;
208 }
209 });
210 return global;
211 }
212
 /*!
  * \brief Run Allreduce on a single value of type T with an additive reducer, then divide by num_machines_.
  * \param local The local value; passed to Allreduce as the input buffer
  * \return The reduced sum divided by num_machines_, cast back to T
  *         (for integral T the division is an integer division, so the result is truncated)
  */
213 template<class T>
214 static T GlobalSyncUpByMean(T& local) {
 // the output buffer starts at zero here, unlike the Min/Max helpers which start from local
215 T global = (T)0;
216 Allreduce(reinterpret_cast<char*>(&local),
217 sizeof(local), sizeof(local),
218 reinterpret_cast<char*>(&global),
 // reducer: element-wise dst += src over len bytes in steps of type_size
219 [](const char* src, char* dst, int type_size, comm_size_t len) {
220 comm_size_t used_size = 0;
221 const T *p1;
222 T *p2;
223 while (used_size < len) {
224 p1 = reinterpret_cast<const T *>(src);
225 p2 = reinterpret_cast<T *>(dst);
226 *p2 += *p1;
227 src += type_size;
228 dst += type_size;
229 used_size += type_size;
230 }
231 });
232 return static_cast<T>(global / num_machines_);
233 }
234
 /*!
  * \brief Run Allreduce on a whole vector with an additive reducer and write the result back in place.
  * \param local In: the local values. Out: overwritten element by element with the reduced values
  */
235 template<class T>
236 static void GlobalSum(std::vector<T>& local) {
 // zero-initialized scratch vector of the same length, used as the Allreduce output buffer
237 std::vector<T> global(local.size(), 0);
238 Allreduce(reinterpret_cast<char*>(local.data()),
 // input_size is the total byte count of the vector, type_size the size of one element
239 static_cast<comm_size_t>(sizeof(T) * local.size()), sizeof(T),
240 reinterpret_cast<char*>(global.data()),
 // reducer: element-wise dst += src over len bytes in steps of type_size
241 [](const char* src, char* dst, int type_size, comm_size_t len) {
242 comm_size_t used_size = 0;
243 const T *p1;
244 T *p2;
245 while (used_size < len) {
246 p1 = reinterpret_cast<const T *>(src);
247 p2 = reinterpret_cast<T *>(dst);
248 *p2 += *p1;
249 src += type_size;
250 dst += type_size;
251 used_size += type_size;
252 }
253 });
 // copy the reduced values back into the caller's vector
254 for (size_t i = 0; i < local.size(); ++i) {
255 local[i] = global[i];
256 }
257 }
258
259private:
 // The private functions below share their parameter lists with the public Allgather / ReduceScatter overloads.
 // NOTE(review): presumably each is one algorithm-specific implementation (Bruck, recursive doubling,
 // ring, recursive halving) selected by the public entry points -- confirm in network.cpp.
260 static void AllgatherBruck(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);
261
262 static void AllgatherRecursiveDoubling(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);
263
264 static void AllgatherRing(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);
265
266 static void ReduceScatterRecursiveHalving(char* input, comm_size_t input_size, int type_size,
267 const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size,
268 const ReduceFunction& reducer);
269
270 static void ReduceScatterRing(char* input, comm_size_t input_size, int type_size,
271 const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size,
272 const ReduceFunction& reducer);
273
 // Static state of the class. THREAD_LOCAL is a macro defined outside this listing
 // (presumably a thread-local storage specifier -- confirm).
 /*! \brief Value returned by num_machines(); also the divisor in GlobalSyncUpByMean. */
275 static THREAD_LOCAL int num_machines_;
 /*! \brief Value returned by rank(). */
277 static THREAD_LOCAL int rank_;
 /*! \brief Owning pointer to the Linkers object (type only forward-declared in this header). */
279 static THREAD_LOCAL std::unique_ptr<Linkers> linkers_;
 /*! \brief BruckMap instance (the network structure for all_gather). */
281 static THREAD_LOCAL BruckMap bruck_map_;
 /*! \brief RecursiveHalvingMap instance (the network structure for the recursive halving algorithm). */
283 static THREAD_LOCAL RecursiveHalvingMap recursive_halving_map_;
 // NOTE(review): the four members below are undocumented in this listing; presumably block layout
 // arrays and a scratch byte buffer with its size -- confirm in network.cpp.
285 static THREAD_LOCAL std::vector<comm_size_t> block_start_;
287 static THREAD_LOCAL std::vector<comm_size_t> block_len_;
289 static THREAD_LOCAL std::vector<char> buffer_;
291 static THREAD_LOCAL comm_size_t buffer_size_;
 // Function objects with the same types as the last two parameters of the four-argument Init overload.
293 static THREAD_LOCAL ReduceScatterFunction reduce_scatter_ext_fun_;
294 static THREAD_LOCAL AllgatherFunction allgather_ext_fun_;
295};
296
/*! \brief Get rank of this machine; returns the static member rank_. */
297inline int Network::rank() {
298 return rank_;
299}
300
/*! \brief Get total number of machines; returns the static member num_machines_. */
301inline int Network::num_machines() {
302 return num_machines_;
303}
304
305} // namespace LightGBM
306
307#endif // LIGHTGBM_NETWORK_H_
The network structure for all_gather.
Definition network.h:19
std::vector< int > in_ranks
in_ranks[i] means the incoming rank on the i-th communication
Definition network.h:24
std::vector< int > out_ranks
out_ranks[i] means the out rank on i-th communication
Definition network.h:26
int k
The communication times for one all gather operation.
Definition network.h:22
static BruckMap Construct(int rank, int num_machines)
Create the object of bruck map.
Definition linker_topo.cpp:26
A static class that contains some collective communication algorithm.
Definition network.h:86
static void Init(Config config)
Initialize.
Definition network.cpp:26
static void AllreduceByAllGather(char *input, comm_size_t input_size, int type_size, char *output, const ReduceFunction &reducer)
Perform all_reduce by using all_gather. It can be used to reduce communication time when data is small...
Definition network.cpp:91
static void Allreduce(char *input, comm_size_t input_size, int type_size, char *output, const ReduceFunction &reducer)
Perform all_reduce. If the data size is small, it will perform AllreduceByAllGather; otherwise it will call ReduceSc...
Definition network.cpp:64
static int rank()
Get rank of this machine.
Definition network.h:297
static int num_machines()
Get total number of machines.
Definition network.h:301
static void Dispose()
Free this static class.
Definition network.cpp:56
static void ReduceScatter(char *input, comm_size_t input_size, int type_size, const comm_size_t *block_start, const comm_size_t *block_len, char *output, comm_size_t output_size, const ReduceFunction &reducer)
Perform reduce scatter by using recursive halving algorithm. Communication times is O(log(n)),...
Definition network.cpp:228
static void Allgather(char *input, comm_size_t send_size, char *output)
Performing all_gather by using bruck algorithm. Communication times is O(log(n)), and communication c...
Definition network.cpp:117
Network structure for recursive halving algorithm.
Definition network.h:53
std::vector< int > recv_block_start
recv_block_start[i] means the recv block start index at the i-th communication
Definition network.h:68
RecursiveHalvingNodeType type
Node type.
Definition network.h:58
std::vector< int > send_block_len
send_block_len[i] means the send block size at the i-th communication
Definition network.h:66
static RecursiveHalvingMap Construct(int rank, int num_machines)
Create the object of recursive halving map.
Definition linker_topo.cpp:65
std::vector< int > send_block_start
send_block_start[i] means send block start index at i-th communication
Definition network.h:64
std::vector< int > recv_block_len
recv_block_len[i] means the recv block size at the i-th communication
Definition network.h:70
int k
Communication times for one recursive halving algorithm
Definition network.h:56
std::vector< int > ranks
ranks[i] means the machine that this node will communicate with on the i-th communication
Definition network.h:62
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10
RecursiveHalvingNodeType
Node type in the recursive halving algorithm. When the number of machines is not a power of 2,...
Definition network.h:46
Definition config.h:27