Medial Code Documentation
Loading...
Searching...
No Matches
input_split_shuffle.h
Go to the documentation of this file.
1
7#ifndef DMLC_INPUT_SPLIT_SHUFFLE_H_
8#define DMLC_INPUT_SPLIT_SHUFFLE_H_
9
10#include <dmlc/io.h>
11
12#include <algorithm>
13#include <cstdio>
14#include <cstring>
15#include <memory>
16#include <string>
17#include <vector>
18
19namespace dmlc {
22 public:
23 // destructor
24 virtual ~InputSplitShuffle(void) { source_.reset(); }
25 // implement BeforeFirst
26 virtual void BeforeFirst(void) {
27 if (num_shuffle_parts_ > 1) {
28 std::shuffle(shuffle_indexes_.begin(), shuffle_indexes_.end(), trnd_);
29 int idx = shuffle_indexes_[0] + part_index_ * num_shuffle_parts_;
30 source_->ResetPartition(idx, num_parts_ * num_shuffle_parts_);
31 cur_shuffle_idx_ = 0;
32 } else {
33 source_->BeforeFirst();
34 }
35 }
36 virtual void HintChunkSize(size_t chunk_size) {
37 source_->HintChunkSize(chunk_size);
38 }
39 virtual size_t GetTotalSize(void) {
40 return source_->GetTotalSize();
41 }
42 // implement next record
43 virtual bool NextRecord(Blob *out_rec) {
44 if (num_shuffle_parts_ > 1) {
45 if (!source_->NextRecord(out_rec)) {
46 if (cur_shuffle_idx_ == num_shuffle_parts_ - 1) {
47 return false;
48 }
49 ++cur_shuffle_idx_;
50 int idx =
51 shuffle_indexes_[cur_shuffle_idx_] + part_index_ * num_shuffle_parts_;
52 source_->ResetPartition(idx, num_parts_ * num_shuffle_parts_);
53 return NextRecord(out_rec);
54 } else {
55 return true;
56 }
57 } else {
58 return source_->NextRecord(out_rec);
59 }
60 }
61 // implement next chunk
62 virtual bool NextChunk(Blob* out_chunk) {
63 if (num_shuffle_parts_ > 1) {
64 if (!source_->NextChunk(out_chunk)) {
65 if (cur_shuffle_idx_ == num_shuffle_parts_ - 1) {
66 return false;
67 }
68 ++cur_shuffle_idx_;
69 int idx =
70 shuffle_indexes_[cur_shuffle_idx_] + part_index_ * num_shuffle_parts_;
71 source_->ResetPartition(idx, num_parts_ * num_shuffle_parts_);
72 return NextChunk(out_chunk);
73 } else {
74 return true;
75 }
76 } else {
77 return source_->NextChunk(out_chunk);
78 }
79 }
80 // implement ResetPartition.
81 virtual void ResetPartition(unsigned rank, unsigned nsplit) {
82 CHECK(nsplit == num_parts_) << "num_parts is not consistent!";
83 int idx = shuffle_indexes_[0] + rank * num_shuffle_parts_;
84 source_->ResetPartition(idx, nsplit * num_shuffle_parts_);
85 cur_shuffle_idx_ = 0;
86 }
102 InputSplitShuffle(const char* uri,
103 unsigned part_index,
104 unsigned num_parts,
105 const char* type,
106 unsigned num_shuffle_parts,
107 int shuffle_seed)
108 : part_index_(part_index),
109 num_parts_(num_parts),
110 num_shuffle_parts_(num_shuffle_parts),
111 cur_shuffle_idx_(0) {
112 for (unsigned i = 0; i < num_shuffle_parts_; i++) {
113 shuffle_indexes_.push_back(i);
114 }
115 trnd_.seed(kRandMagic_ + part_index_ + num_parts_ + num_shuffle_parts_ +
116 shuffle_seed);
117 std::shuffle(shuffle_indexes_.begin(), shuffle_indexes_.end(), trnd_);
118 int idx = shuffle_indexes_[cur_shuffle_idx_] + part_index_ * num_shuffle_parts_;
119 source_.reset(
120 InputSplit::Create(uri, idx , num_parts_ * num_shuffle_parts_, type));
121 }
140 static InputSplit* Create(const char* uri,
141 unsigned part_index,
142 unsigned num_parts,
143 const char* type,
144 unsigned num_shuffle_parts,
145 int shuffle_seed) {
146 CHECK(num_shuffle_parts > 0) << "number of shuffle parts should be greater than zero!";
147 return new InputSplitShuffle(
148 uri, part_index, num_parts, type, num_shuffle_parts, shuffle_seed);
149 }
150
151 private:
152 // magic number for seed
153 static const int kRandMagic_ = 666;
155 std::mt19937 trnd_;
157 std::unique_ptr<InputSplit> source_;
159 unsigned part_index_;
161 unsigned num_parts_;
163 unsigned num_shuffle_parts_;
165 unsigned cur_shuffle_idx_;
167 std::vector<int> shuffle_indexes_;
168};
169} // namespace dmlc
170#endif // DMLC_INPUT_SPLIT_SHUFFLE_H_
class to construct input split with global shuffling
Definition input_split_shuffle.h:21
InputSplitShuffle(const char *uri, unsigned part_index, unsigned num_parts, const char *type, unsigned num_shuffle_parts, int shuffle_seed)
constructor
Definition input_split_shuffle.h:102
virtual void ResetPartition(unsigned rank, unsigned nsplit)
reset the Input split to a certain part id, The InputSplit will be pointed to the head of the new spe...
Definition input_split_shuffle.h:81
virtual void HintChunkSize(size_t chunk_size)
hint the inputsplit how large the chunk size it should return when implementing NextChunk this is a h...
Definition input_split_shuffle.h:36
virtual size_t GetTotalSize(void)
get the total size of the InputSplit
Definition input_split_shuffle.h:39
virtual bool NextChunk(Blob *out_chunk)
get a chunk of memory that can contain multiple records, the caller needs to parse the content of the...
Definition input_split_shuffle.h:62
virtual bool NextRecord(Blob *out_rec)
get the next record, the returning value is valid until next call to NextRecord, NextChunk or NextBat...
Definition input_split_shuffle.h:43
static InputSplit * Create(const char *uri, unsigned part_index, unsigned num_parts, const char *type, unsigned num_shuffle_parts, int shuffle_seed)
factory function: create input split with chunk shuffling given a uri
Definition input_split_shuffle.h:140
virtual void BeforeFirst(void)
reset the position of InputSplit to beginning
Definition input_split_shuffle.h:26
input split creates that allows reading of records from split of data, independent part that covers a...
Definition io.h:155
static InputSplit * Create(const char *uri, unsigned part_index, unsigned num_parts, const char *type)
factory function: create input split given a uri
Definition io.cc:74
defines serializable interface of dmlc
namespace for dmlc
Definition array_view.h:12
a blob of memory region
Definition io.h:158