7#ifndef DMLC_INPUT_SPLIT_SHUFFLE_H_
8#define DMLC_INPUT_SPLIT_SHUFFLE_H_
27 if (num_shuffle_parts_ > 1) {
28 std::shuffle(shuffle_indexes_.begin(), shuffle_indexes_.end(), trnd_);
29 int idx = shuffle_indexes_[0] + part_index_ * num_shuffle_parts_;
30 source_->ResetPartition(idx, num_parts_ * num_shuffle_parts_);
33 source_->BeforeFirst();
37 source_->HintChunkSize(chunk_size);
40 return source_->GetTotalSize();
44 if (num_shuffle_parts_ > 1) {
45 if (!source_->NextRecord(out_rec)) {
46 if (cur_shuffle_idx_ == num_shuffle_parts_ - 1) {
51 shuffle_indexes_[cur_shuffle_idx_] + part_index_ * num_shuffle_parts_;
52 source_->ResetPartition(idx, num_parts_ * num_shuffle_parts_);
58 return source_->NextRecord(out_rec);
63 if (num_shuffle_parts_ > 1) {
64 if (!source_->NextChunk(out_chunk)) {
65 if (cur_shuffle_idx_ == num_shuffle_parts_ - 1) {
70 shuffle_indexes_[cur_shuffle_idx_] + part_index_ * num_shuffle_parts_;
71 source_->ResetPartition(idx, num_parts_ * num_shuffle_parts_);
77 return source_->NextChunk(out_chunk);
82 CHECK(nsplit == num_parts_) <<
"num_parts is not consistent!";
83 int idx = shuffle_indexes_[0] + rank * num_shuffle_parts_;
84 source_->ResetPartition(idx, nsplit * num_shuffle_parts_);
106 unsigned num_shuffle_parts,
108 : part_index_(part_index),
109 num_parts_(num_parts),
110 num_shuffle_parts_(num_shuffle_parts),
111 cur_shuffle_idx_(0) {
112 for (
unsigned i = 0; i < num_shuffle_parts_; i++) {
113 shuffle_indexes_.push_back(i);
115 trnd_.seed(kRandMagic_ + part_index_ + num_parts_ + num_shuffle_parts_ +
117 std::shuffle(shuffle_indexes_.begin(), shuffle_indexes_.end(), trnd_);
118 int idx = shuffle_indexes_[cur_shuffle_idx_] + part_index_ * num_shuffle_parts_;
144 unsigned num_shuffle_parts,
146 CHECK(num_shuffle_parts > 0) <<
"number of shuffle parts should be greater than zero!";
148 uri, part_index, num_parts, type, num_shuffle_parts, shuffle_seed);
153 static const int kRandMagic_ = 666;
157 std::unique_ptr<InputSplit> source_;
159 unsigned part_index_;
163 unsigned num_shuffle_parts_;
165 unsigned cur_shuffle_idx_;
167 std::vector<int> shuffle_indexes_;
defines serializable interface of dmlc
namespace for dmlc
Definition array_view.h:12