Medial Code Documentation
Loading...
Searching...
No Matches
threaded_input_split.h
Go to the documentation of this file.
1
7#ifndef DMLC_IO_THREADED_INPUT_SPLIT_H_
8#define DMLC_IO_THREADED_INPUT_SPLIT_H_
9
10#include <dmlc/base.h>
11// this code depends on c++11
12#if DMLC_ENABLE_STD_THREAD
#include <algorithm>
#include <atomic>
#include <dmlc/threadediter.h>
#include "./input_split_base.h"
16
17namespace dmlc {
18namespace io {
23class ThreadedInputSplit : public InputSplit {
24 public:
29 explicit ThreadedInputSplit(InputSplitBase *base, const size_t batch_size)
30 : buffer_size_(InputSplitBase::kBufferSize),
31 batch_size_(batch_size),
32 base_(base), tmp_chunk_(NULL) {
33 iter_.set_max_capacity(2);
34 // initalize the iterator
35 iter_.Init([this](InputSplitBase::Chunk **dptr) {
36 if (*dptr == NULL) {
37 *dptr = new InputSplitBase::Chunk(buffer_size_);
38 }
39 return base_->NextBatchEx(*dptr, batch_size_);
40 },
41 [base]() { base->BeforeFirst(); });
42 }
43 // destructor
44 virtual ~ThreadedInputSplit(void) {
45 iter_.Destroy();
46 delete tmp_chunk_;
47 delete base_;
48 }
49 virtual void BeforeFirst() {
50 iter_.BeforeFirst();
51 if (tmp_chunk_ != NULL) {
52 iter_.Recycle(&tmp_chunk_);
53 }
54 }
55 virtual void HintChunkSize(size_t chunk_size) {
56 buffer_size_ = std::max(chunk_size / sizeof(uint32_t), buffer_size_);
57 }
58 // implement next record
59 virtual bool NextRecord(Blob *out_rec) {
60 if (tmp_chunk_ == NULL) {
61 if (!iter_.Next(&tmp_chunk_)) return false;
62 }
63 while (!base_->ExtractNextRecord(out_rec, tmp_chunk_)) {
64 iter_.Recycle(&tmp_chunk_);
65 if (!iter_.Next(&tmp_chunk_)) return false;
66 }
67 return true;
68 }
69 // implement next chunk
70 virtual bool NextChunk(Blob *out_chunk) {
71 if (tmp_chunk_ == NULL) {
72 if (!iter_.Next(&tmp_chunk_)) return false;
73 }
74 while (!base_->ExtractNextChunk(out_chunk, tmp_chunk_)) {
75 iter_.Recycle(&tmp_chunk_);
76 if (!iter_.Next(&tmp_chunk_)) return false;
77 }
78 return true;
79 }
80
81 virtual size_t GetTotalSize(void) {
82 return base_->GetTotalSize();
83 }
84
85 virtual void ResetPartition(unsigned part_index, unsigned num_parts) {
86 base_->ResetPartition(part_index, num_parts);
87 this->BeforeFirst();
88 }
89
90 private:
92 size_t buffer_size_;
94 size_t batch_size_;
96 InputSplitBase *base_;
98 ThreadedIter<InputSplitBase::Chunk> iter_;
100 InputSplitBase::Chunk *tmp_chunk_;
101};
102} // namespace io
103} // namespace dmlc
#endif  // DMLC_ENABLE_STD_THREAD
105#endif // DMLC_IO_THREADED_INPUT_SPLIT_H_
defines configuration macros
base class to construct input split from multiple files
namespace for dmlc
Definition array_view.h:12
thread backed iterator that can be used to implement general thread-based pipeline such as prefetch a...