96 static void Init(
int num_machines,
int rank, ReduceScatterFunction reduce_scatter_ext_fun, AllgatherFunction allgather_ext_fun);
100 static inline int rank();
113 static void Allreduce(
char* input, comm_size_t input_size,
int type_size,
114 char* output,
const ReduceFunction& reducer);
124 static void AllreduceByAllGather(
char* input, comm_size_t input_size,
int type_size,
char* output,
125 const ReduceFunction& reducer);
135 static void Allgather(
char* input, comm_size_t send_size,
char* output);
147 static void Allgather(
char* input,
const comm_size_t* block_start,
const comm_size_t* block_len,
char* output, comm_size_t all_size);
161 static void ReduceScatter(
char* input, comm_size_t input_size,
int type_size,
162 const comm_size_t* block_start,
const comm_size_t* block_len,
char* output, comm_size_t output_size,
163 const ReduceFunction& reducer);
166 static T GlobalSyncUpByMin(T& local) {
168 Allreduce(
reinterpret_cast<char*
>(&local),
169 sizeof(local),
sizeof(local),
170 reinterpret_cast<char*
>(&global),
171 [] (
const char* src,
char* dst,
int type_size, comm_size_t len) {
172 comm_size_t used_size = 0;
175 while (used_size < len) {
176 p1 =
reinterpret_cast<const T *
>(src);
177 p2 =
reinterpret_cast<T *
>(dst);
179 std::memcpy(dst, src, type_size);
183 used_size += type_size;
190 static T GlobalSyncUpByMax(T& local) {
192 Allreduce(
reinterpret_cast<char*
>(&local),
193 sizeof(local),
sizeof(local),
194 reinterpret_cast<char*
>(&global),
195 [] (
const char* src,
char* dst,
int type_size, comm_size_t len) {
196 comm_size_t used_size = 0;
199 while (used_size < len) {
200 p1 =
reinterpret_cast<const T *
>(src);
201 p2 =
reinterpret_cast<T *
>(dst);
203 std::memcpy(dst, src, type_size);
207 used_size += type_size;
214 static T GlobalSyncUpByMean(T& local) {
216 Allreduce(
reinterpret_cast<char*
>(&local),
217 sizeof(local),
sizeof(local),
218 reinterpret_cast<char*
>(&global),
219 [](
const char* src,
char* dst,
int type_size, comm_size_t len) {
220 comm_size_t used_size = 0;
223 while (used_size < len) {
224 p1 =
reinterpret_cast<const T *
>(src);
225 p2 =
reinterpret_cast<T *
>(dst);
229 used_size += type_size;
232 return static_cast<T
>(global / num_machines_);
236 static void GlobalSum(std::vector<T>& local) {
237 std::vector<T> global(local.size(), 0);
238 Allreduce(
reinterpret_cast<char*
>(local.data()),
239 static_cast<comm_size_t
>(
sizeof(T) * local.size()),
sizeof(T),
240 reinterpret_cast<char*
>(global.data()),
241 [](
const char* src,
char* dst,
int type_size, comm_size_t len) {
242 comm_size_t used_size = 0;
245 while (used_size < len) {
246 p1 = reinterpret_cast<const T *>(src);
247 p2 = reinterpret_cast<T *>(dst);
251 used_size += type_size;
254 for (
size_t i = 0; i < local.size(); ++i) {
255 local[i] = global[i];
260 static void AllgatherBruck(
char* input,
const comm_size_t* block_start,
const comm_size_t* block_len,
char* output, comm_size_t all_size);
262 static void AllgatherRecursiveDoubling(
char* input,
const comm_size_t* block_start,
const comm_size_t* block_len,
char* output, comm_size_t all_size);
264 static void AllgatherRing(
char* input,
const comm_size_t* block_start,
const comm_size_t* block_len,
char* output, comm_size_t all_size);
266 static void ReduceScatterRecursiveHalving(
char* input, comm_size_t input_size,
int type_size,
267 const comm_size_t* block_start,
const comm_size_t* block_len,
char* output, comm_size_t output_size,
268 const ReduceFunction& reducer);
270 static void ReduceScatterRing(
char* input, comm_size_t input_size,
int type_size,
271 const comm_size_t* block_start,
const comm_size_t* block_len,
char* output, comm_size_t output_size,
272 const ReduceFunction& reducer);
275 static THREAD_LOCAL
int num_machines_;
277 static THREAD_LOCAL
int rank_;
279 static THREAD_LOCAL std::unique_ptr<Linkers> linkers_;
281 static THREAD_LOCAL
BruckMap bruck_map_;
285 static THREAD_LOCAL std::vector<comm_size_t> block_start_;
287 static THREAD_LOCAL std::vector<comm_size_t> block_len_;
289 static THREAD_LOCAL std::vector<char> buffer_;
291 static THREAD_LOCAL comm_size_t buffer_size_;
293 static THREAD_LOCAL ReduceScatterFunction reduce_scatter_ext_fun_;
294 static THREAD_LOCAL AllgatherFunction allgather_ext_fun_;
static void Allreduce(char *input, comm_size_t input_size, int type_size, char *output, const ReduceFunction &reducer)
Perform all_reduce. if data size is small, will perform AllreduceByAllGather, else with call ReduceSc...
Definition network.cpp:64
static void ReduceScatter(char *input, comm_size_t input_size, int type_size, const comm_size_t *block_start, const comm_size_t *block_len, char *output, comm_size_t output_size, const ReduceFunction &reducer)
Perform reduce scatter by using recursive halving algorithm. Communication times is O(log(n)),...
Definition network.cpp:228