9#ifndef RABIT_ALLREDUCE_MOCK_H_
10#define RABIT_ALLREDUCE_MOCK_H_
27 tsum_allreduce_ = 0.0;
28 tsum_allgather_ = 0.0;
32 void SetParam(
const char *name,
const char *val)
override {
35 if (!strcmp(name,
"rabit_num_trial")) num_trial_ = atoi(val);
36 if (!strcmp(name,
"DMLC_NUM_ATTEMPT")) num_trial_ = atoi(val);
37 if (!strcmp(name,
"report_stats")) report_stats_ = atoi(val);
38 if (!strcmp(name,
"force_local")) force_local_ = atoi(val);
39 if (!strcmp(name,
"mock")) {
42 &k.rank, &k.version, &k.seqno, &k.ntrial) == 4,
43 "invalid mock parameter");
47 void Allreduce(
void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
49 void *prepare_arg)
override {
50 this->Verify(MockKey(rank, version_number, seq_counter, num_trial_),
"AllReduce");
53 prepare_fun, prepare_arg);
56 void Allgather(
void *sendrecvbuf,
size_t total_size,
size_t slice_begin,
57 size_t slice_end,
size_t size_prev_slice)
override {
58 this->Verify(MockKey(rank, version_number, seq_counter, num_trial_),
"Allgather");
64 void Broadcast(
void *sendrecvbuf_,
size_t total_size,
int root)
override {
65 this->Verify(MockKey(rank, version_number, seq_counter, num_trial_),
"Broadcast");
69 tsum_allreduce_ = 0.0;
70 tsum_allgather_ = 0.0;
72 if (force_local_ == 0) {
79 this->Verify(MockKey(rank, version_number, seq_counter, num_trial_),
"CheckPoint");
81 double tbet_chkpt = tstart - time_checkpoint_;
85 if (report_stats_ != 0 && rank == 0) {
87 ss <<
"[v" << version_number <<
"] global_size="
88 <<
",check_tcost="<< tcost <<
" sec"
89 <<
",allreduce_tcost=" << tsum_allreduce_ <<
" sec"
90 <<
",allgather_tcost=" << tsum_allgather_ <<
" sec"
91 <<
",between_chpt=" << tbet_chkpt <<
"sec\n";
94 tsum_allreduce_ = 0.0;
95 tsum_allgather_ = 0.0;
104 double tsum_allreduce_;
106 double tsum_allgather_;
107 double time_checkpoint_;
// Constructs a MockKey by copying each argument into the member of the same
// name (rank, version, seqno, ntrial); the constructor body itself is empty.
117 MockKey(
int rank,
int version,
int seqno,
int ntrial)
118 : rank(rank), version(version), seqno(seqno), ntrial(ntrial) {}
119 inline bool operator==(
const MockKey &b)
const {
120 return rank == b.rank &&
121 version == b.version &&
125 inline bool operator<(
const MockKey &b)
const {
126 if (rank != b.rank)
return rank < b.rank;
127 if (version != b.version)
return version < b.version;
128 if (seqno != b.seqno)
return seqno < b.seqno;
129 return ntrial < b.ntrial;
135 std::map<MockKey, int> mock_map_;
137 inline void Verify(
const MockKey &key,
const char *name) {
138 if (mock_map_.count(key) != 0) {
141 throw dmlc::Error(std::to_string(rank) +
"@@@Hit Mock Error: " + name);
Basic implementation of AllReduce using non-blocking TCP sockets and tree-shaped reduction.
implementation of basic Allreduce engine
Definition allreduce_base.h:42
void TrackerPrint(const std::string &msg) override
print the msg in the tracker; this function can be used to communicate information about the progress...
Definition allreduce_base.cc:145
void CheckPoint() override
Increase internal version number. Deprecated.
Definition allreduce_base.h:153
virtual void SetParam(const char *name, const char *val)
set parameters to the engine
Definition allreduce_base.cc:182
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count, ReduceFunction reducer, PreprocFunction prepare_fun=nullptr, void *prepare_arg=nullptr) override
perform in-place allreduce on sendrecvbuf; this function is NOT thread-safe
Definition allreduce_base.h:123
void Broadcast(void *sendrecvbuf_, size_t total_size, int root) override
broadcast data from root to all nodes
Definition allreduce_base.h:141
int LoadCheckPoint() override
deprecated
Definition allreduce_base.h:150
void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin, size_t slice_end, size_t size_prev_slice) override
internal Allgather function; each node has a segment of data in the ring of sendrecvbuf,...
Definition allreduce_base.h:102
Definition allreduce_mock.h:20
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count, ReduceFunction reducer, PreprocFunction prepare_fun, void *prepare_arg) override
perform in-place allreduce on sendrecvbuf; this function is NOT thread-safe
Definition allreduce_mock.h:47
void CheckPoint() override
Increase internal version number. Deprecated.
Definition allreduce_mock.h:78
void Broadcast(void *sendrecvbuf_, size_t total_size, int root) override
broadcast data from root to all nodes
Definition allreduce_mock.h:64
void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin, size_t slice_end, size_t size_prev_slice) override
internal Allgather function; each node has a segment of data in the ring of sendrecvbuf,...
Definition allreduce_mock.h:56
int LoadCheckPoint() override
deprecated
Definition allreduce_mock.h:68
void SetParam(const char *name, const char *val) override
set parameters to the engine
Definition allreduce_mock.h:32
void() PreprocFunction(void *arg)
Preprocessing function that is called before AllReduce to prepare the data used by AllReduce.
Definition engine.h:29
void() ReduceFunction(const void *src, void *dst, int count, const MPI::Datatype &dtype)
reduce function; the same form as the MPI reduce function is used, to be compatible with the MPI interface. In...
Definition engine.h:41
cross platform timer for timing
This file defines the core interface of rabit library.
double GetTime(void)
return time in seconds
Definition timer.h:27
void Check(bool exp, const char *fmt,...)
same as assert, but this is intended to be used as a message for users
Definition utils.h:91
namespace of rabit
Definition engine.h:18
exception class that will be thrown by default logger if DMLC_LOG_FATAL_THROW == 1
Definition logging.h:29