Medial Code Documentation
Loading...
Searching...
No Matches
allreduce_mock.h
Go to the documentation of this file.
1
9#ifndef RABIT_ALLREDUCE_MOCK_H_
10#define RABIT_ALLREDUCE_MOCK_H_
11#include <vector>
12#include <map>
13#include <sstream>
14#include <dmlc/timer.h>
16#include "allreduce_base.h"
17
18namespace rabit {
19namespace engine {
21 public:
22 // constructor: starts the trial counter, both flags, and both timing
// accumulators at zero.
// NOTE(review): the constructor's signature (listing line 23) is missing from
// this extract, so only the body is visible; time_checkpoint_ is not assigned
// in the visible body.
24 num_trial_ = 0;
25 force_local_ = 0;
26 report_stats_ = 0;
27 tsum_allreduce_ = 0.0;
28 tsum_allgather_ = 0.0;
29 }
30 // destructor: compiler-generated (defaulted); overrides the base-class destructor
31 ~AllreduceMock() override = default;
32 void SetParam(const char *name, const char *val) override {
33 AllreduceBase::SetParam(name, val);
34 // additional parameters
35 if (!strcmp(name, "rabit_num_trial")) num_trial_ = atoi(val);
36 if (!strcmp(name, "DMLC_NUM_ATTEMPT")) num_trial_ = atoi(val);
37 if (!strcmp(name, "report_stats")) report_stats_ = atoi(val);
38 if (!strcmp(name, "force_local")) force_local_ = atoi(val);
39 if (!strcmp(name, "mock")) {
40 MockKey k;
41 utils::Check(sscanf(val, "%d,%d,%d,%d",
42 &k.rank, &k.version, &k.seqno, &k.ntrial) == 4,
43 "invalid mock parameter");
44 mock_map_[k] = 1;
45 }
46 }
47 void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
48 ReduceFunction reducer, PreprocFunction prepare_fun,
49 void *prepare_arg) override {
50 this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "AllReduce");
51 double tstart = dmlc::GetTime();
52 AllreduceBase::Allreduce(sendrecvbuf_, type_nbytes, count, reducer,
53 prepare_fun, prepare_arg);
54 tsum_allreduce_ += dmlc::GetTime() - tstart;
55 }
56 void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin,
57 size_t slice_end, size_t size_prev_slice) override {
58 this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Allgather");
59 double tstart = dmlc::GetTime();
60 AllreduceBase::Allgather(sendrecvbuf, total_size, slice_begin, slice_end,
61 size_prev_slice);
62 tsum_allgather_ += dmlc::GetTime() - tstart;
63 }
64 void Broadcast(void *sendrecvbuf_, size_t total_size, int root) override {
65 this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Broadcast");
66 AllreduceBase::Broadcast(sendrecvbuf_, total_size, root);
67 }
// Resets both timing accumulators and stamps time_checkpoint_ with the current
// time (dmlc::GetTime), then branches on force_local_.
// NOTE(review): the bodies of both branches (listing lines 73 and 75), and with
// them the return statement, are missing from this extract; confirm the
// returned value against the original header.
68 int LoadCheckPoint() override {
69 tsum_allreduce_ = 0.0;
70 tsum_allgather_ = 0.0;
71 time_checkpoint_ = dmlc::GetTime();
72 if (force_local_ == 0) {
74 } else {
76 }
77 }
// Checkpoint with fault injection and optional timing report.
// Verify() runs first and throws if a mock failure is registered for the
// current (rank, version_number, seq_counter, num_trial_) stage.
// NOTE(review): listing line 82 (between the tbet_chkpt computation and the
// time_checkpoint_ update) is missing from this extract; presumably it is the
// call into the base-class checkpoint. Confirm against the original header.
78 void CheckPoint() override {
79 this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "CheckPoint");
80 double tstart = dmlc::GetTime();
// time elapsed since time_checkpoint_ was last stamped
81 double tbet_chkpt = tstart - time_checkpoint_;
83 time_checkpoint_ = dmlc::GetTime();
// time spent in this call since tstart
84 double tcost = dmlc::GetTime() - tstart;
// the report is produced only when report_stats_ is nonzero and rank == 0
85 if (report_stats_ != 0 && rank == 0) {
86 std::stringstream ss;
// NOTE(review): "global_size=" is written with no value after it, and the
// final "sec\n" has no leading space unlike the other " sec" suffixes.
87 ss << "[v" << version_number << "] global_size="
88 << ",check_tcost="<< tcost <<" sec"
89 << ",allreduce_tcost=" << tsum_allreduce_ << " sec"
90 << ",allgather_tcost=" << tsum_allgather_ << " sec"
91 << ",between_chpt=" << tbet_chkpt << "sec\n";
92 this->TrackerPrint(ss.str());
93 }
// start a fresh accumulation window for the next checkpoint interval
94 tsum_allreduce_ = 0.0;
95 tsum_allgather_ = 0.0;
96 }
97
98 protected:
99 // force checkpoint to local; set from the "force_local" parameter and tested in LoadCheckPoint
100 int force_local_;
101 // whether to report statistics; set from the "report_stats" parameter, nonzero enables the CheckPoint report on rank 0
102 int report_stats_;
103 // time accumulated inside Allreduce calls; reset to 0 by LoadCheckPoint and CheckPoint
104 double tsum_allreduce_;
105 // time accumulated inside Allgather calls; reset to 0 by LoadCheckPoint and CheckPoint
106 double tsum_allgather_;
// dmlc::GetTime() value recorded by the most recent LoadCheckPoint or CheckPoint;
// CheckPoint uses it to compute the time between checkpoints
107 double time_checkpoint_;
108
109 private:
// Key identifying one mock stage: the (rank, version, seqno, ntrial) tuple at
// which an injected failure should fire. Used as the key type of an ordered
// map, so it provides equality and a strict weak ordering.
struct MockKey {
  // Default member initializers give a default-constructed key a well-defined
  // value (all zeros) instead of indeterminate ints.
  int rank{0};
  int version{0};
  int seqno{0};
  int ntrial{0};

  // Builds the all-zero key.
  MockKey() = default;

  // Builds a key from its four components.
  MockKey(int rank, int version, int seqno, int ntrial)
      : rank(rank), version(version), seqno(seqno), ntrial(ntrial) {}

  // Two keys are equal when all four components match.
  bool operator==(const MockKey &b) const {
    return rank == b.rank &&
           version == b.version &&
           seqno == b.seqno &&
           ntrial == b.ntrial;
  }

  // Lexicographic order over (rank, version, seqno, ntrial).
  bool operator<(const MockKey &b) const {
    if (rank != b.rank) return rank < b.rank;
    if (version != b.version) return version < b.version;
    if (seqno != b.seqno) return seqno < b.seqno;
    return ntrial < b.ntrial;
  }
};
132 // number of failure trials; set from "rabit_num_trial" / "DMLC_NUM_ATTEMPT" and incremented by Verify each time a mock failure fires
133 int num_trial_;
134 // record all mock actions; keys are added by SetParam("mock", ...) with value 1, and Verify only checks whether a key is present
135 std::map<MockKey, int> mock_map_;
136 // used to generate all kinds of exceptions
137 inline void Verify(const MockKey &key, const char *name) {
138 if (mock_map_.count(key) != 0) {
139 num_trial_ += 1;
140 // data processing frameworks runs on shared process
141 throw dmlc::Error(std::to_string(rank) + "@@@Hit Mock Error: " + name);
142 }
143 }
144};
145} // namespace engine
146} // namespace rabit
147#endif // RABIT_ALLREDUCE_MOCK_H_
Basic implementation of AllReduce using non-blocking TCP sockets and tree-shaped reduction.
implementation of basic Allreduce engine
Definition allreduce_base.h:42
void TrackerPrint(const std::string &msg) override
print the msg in the tracker, this function can be used to communicate the information of the progres...
Definition allreduce_base.cc:145
void CheckPoint() override
Increase internal version number. Deprecated.
Definition allreduce_base.h:153
virtual void SetParam(const char *name, const char *val)
set parameters to the engine
Definition allreduce_base.cc:182
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count, ReduceFunction reducer, PreprocFunction prepare_fun=nullptr, void *prepare_arg=nullptr) override
perform in-place allreduce, on sendrecvbuf this function is NOT thread-safe
Definition allreduce_base.h:123
void Broadcast(void *sendrecvbuf_, size_t total_size, int root) override
broadcast data from root to all nodes
Definition allreduce_base.h:141
int LoadCheckPoint() override
deprecated
Definition allreduce_base.h:150
void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin, size_t slice_end, size_t size_prev_slice) override
internal Allgather function; each node has a segment of data in the ring of sendrecvbuf,...
Definition allreduce_base.h:102
Definition allreduce_mock.h:20
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count, ReduceFunction reducer, PreprocFunction prepare_fun, void *prepare_arg) override
perform in-place allreduce, on sendrecvbuf this function is NOT thread-safe
Definition allreduce_mock.h:47
void CheckPoint() override
Increase internal version number. Deprecated.
Definition allreduce_mock.h:78
void Broadcast(void *sendrecvbuf_, size_t total_size, int root) override
broadcast data from root to all nodes
Definition allreduce_mock.h:64
void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin, size_t slice_end, size_t size_prev_slice) override
internal Allgather function; each node has a segment of data in the ring of sendrecvbuf,...
Definition allreduce_mock.h:56
int LoadCheckPoint() override
deprecated
Definition allreduce_mock.h:68
void SetParam(const char *name, const char *val) override
set parameters to the engine
Definition allreduce_mock.h:32
void() PreprocFunction(void *arg)
Preprocessing function, that is called before AllReduce, used to prepare the data used by AllReduce.
Definition engine.h:29
void() ReduceFunction(const void *src, void *dst, int count, const MPI::Datatype &dtype)
reduce function; the same form as the MPI reduce function is used, to be compatible with the MPI interface. In...
Definition engine.h:41
cross platform timer for timing
This file defines the core interface of rabit library.
double GetTime(void)
return time in seconds
Definition timer.h:27
void Check(bool exp, const char *fmt,...)
same as assert, but this is intended to be used as a message for users
Definition utils.h:91
namespace of rabit
Definition engine.h:18
exception class that will be thrown by default logger if DMLC_LOG_FATAL_THROW == 1
Definition logging.h:29