9#ifndef DMLC_THREADEDITER_H_
10#define DMLC_THREADEDITER_H_
14#if DMLC_ENABLE_STD_THREAD
15#include <condition_variable>
38 explicit ScopedThread(std::thread thread)
39 : thread_(
std::move(thread)) {
40 if (!thread_.joinable()) {
41 throw std::logic_error(
"No thread");
45 virtual ~ScopedThread() {
49 ScopedThread(ScopedThread
const&) =
delete;
50 ScopedThread& operator=(ScopedThread
const&) =
delete;
77template<
typename DType>
78class ThreadedIter :
public DataIter<DType> {
88 virtual ~Producer() =
default;
90 virtual void BeforeFirst() {
106 virtual bool Next(DType **inout_dptr) = 0;
112 explicit ThreadedIter(
size_t max_capacity = 8)
113 : producer_(nullptr),
114 producer_thread_(nullptr),
115 max_capacity_(max_capacity),
120 virtual ~ThreadedIter(
void) {
129 inline void Destroy(
void);
134 inline void set_max_capacity(
size_t max_capacity) {
135 max_capacity_ = max_capacity;
142 inline void Init(std::shared_ptr<Producer> producer);
151 inline void Init(std::function<
bool(DType **)> next,
152 std::function<
void()> beforefirst = NotImplemented);
162 inline bool Next(DType **out_dptr);
169 inline void Recycle(DType **inout_dptr);
174 inline void ThrowExceptionIfSet(
void);
179 inline void ClearException(
void);
187 virtual bool Next(
void) {
188 if (out_data_ != NULL) {
189 this->Recycle(&out_data_);
191 if (Next(&out_data_)) {
202 virtual const DType &Value(
void)
const {
203 CHECK(out_data_ != NULL) <<
"Calling Value at beginning or end?";
207 virtual void BeforeFirst(
void) {
208 ThrowExceptionIfSet();
209 std::unique_lock<std::mutex> lock(mutex_);
210 if (out_data_ != NULL) {
211 free_cells_.push(out_data_);
214 if (producer_sig_.load(std::memory_order_acquire) == kDestroy)
return;
216 producer_sig_.store(kBeforeFirst, std::memory_order_release);
217 CHECK(!producer_sig_processed_.load(std::memory_order_acquire));
218 if (nwait_producer_ != 0) {
219 producer_cond_.notify_one();
221 CHECK(!producer_sig_processed_.load(std::memory_order_acquire));
223 consumer_cond_.wait(lock, [
this]() {
224 return producer_sig_processed_.load(std::memory_order_acquire);
226 producer_sig_processed_.store(
false, std::memory_order_release);
227 bool notify = nwait_producer_ != 0 && !produce_end_;
230 if (notify) producer_cond_.notify_one();
231 ThrowExceptionIfSet();
236 inline static void NotImplemented(
void) {
237 LOG(FATAL) <<
"BeforeFirst is not supported";
247 std::shared_ptr<Producer> producer_;
250 std::atomic<Signal> producer_sig_;
252 std::atomic<bool> producer_sig_processed_;
254 std::unique_ptr<ScopedThread> producer_thread_;
256 std::atomic<bool> produce_end_;
258 size_t max_capacity_;
262 std::mutex mutex_exception_;
264 unsigned nwait_consumer_;
266 unsigned nwait_producer_;
268 std::condition_variable producer_cond_;
270 std::condition_variable consumer_cond_;
274 std::queue<DType*> queue_;
276 std::queue<DType*> free_cells_;
278 std::exception_ptr iter_exception_{
nullptr};
282template <
typename DType>
inline void ThreadedIter<DType>::Destroy(
void) {
283 if (producer_thread_) {
286 std::lock_guard<std::mutex> lock(mutex_);
288 producer_sig_.store(kDestroy, std::memory_order_release);
289 if (nwait_producer_ != 0) {
290 producer_cond_.notify_one();
293 producer_thread_.reset(
nullptr);
297 while (free_cells_.size() != 0) {
298 delete free_cells_.front();
301 while (queue_.size() != 0) {
302 delete queue_.front();
305 if (producer_ != NULL) {
308 if (out_data_ != NULL) {
314template<
typename DType>
315inline void ThreadedIter<DType>::
316Init(std::shared_ptr<Producer> producer) {
317 CHECK(producer_ == NULL) <<
"can only call Init once";
318 auto next = [producer](DType **dptr) {
319 return producer->Next(dptr);
321 auto beforefirst = [producer]() {
322 producer->BeforeFirst();
324 this->
Init(next, beforefirst);
327template <
typename DType>
328inline void ThreadedIter<DType>::Init(std::function<
bool(DType **)> next,
329 std::function<
void()> beforefirst) {
330 producer_sig_.store(kProduce, std::memory_order_release);
331 producer_sig_processed_.store(
false, std::memory_order_release);
332 produce_end_.store(
false, std::memory_order_release);
336 auto producer_fun = [
this, next, beforefirst]() {
342 std::unique_lock<std::mutex> lock(mutex_);
343 ++this->nwait_producer_;
344 producer_cond_.wait(lock, [
this]() {
345 if (producer_sig_.load(std::memory_order_acquire) == kProduce) {
346 bool ret = !produce_end_.load(std::memory_order_acquire)
347 && (queue_.size() < max_capacity_ ||
348 free_cells_.size() != 0);
354 --this->nwait_producer_;
355 if (producer_sig_.load(std::memory_order_acquire) == kProduce) {
356 if (free_cells_.size() != 0) {
357 cell = free_cells_.front();
360 }
else if (producer_sig_.load(std::memory_order_acquire) == kBeforeFirst) {
364 while (queue_.size() != 0) {
365 free_cells_.push(queue_.front());
369 produce_end_.store(
false, std::memory_order_release);
370 producer_sig_processed_.store(
true, std::memory_order_release);
371 producer_sig_.store(kProduce, std::memory_order_release);
374 consumer_cond_.notify_all();
378 DCHECK(producer_sig_.load(std::memory_order_acquire) == kDestroy);
379 producer_sig_processed_.store(
true, std::memory_order_release);
380 produce_end_.store(
true, std::memory_order_release);
382 consumer_cond_.notify_all();
387 produce_end_.store(!next(&cell), std::memory_order_release);
388 DCHECK(cell != NULL || produce_end_.load(std::memory_order_acquire));
392 std::lock_guard<std::mutex> lock(mutex_);
393 if (!produce_end_.load(std::memory_order_acquire)) {
397 free_cells_.push(cell);
400 notify = nwait_consumer_ != 0;
403 consumer_cond_.notify_all();
404 }
catch (std::exception &e) {
406 DCHECK(producer_sig_.load(std::memory_order_acquire) != kDestroy);
408 std::lock_guard<std::mutex> lock(mutex_exception_);
409 if (!iter_exception_) {
410 iter_exception_ = std::current_exception();
413 bool next_notify =
false;
415 std::unique_lock<std::mutex> lock(mutex_);
416 if (producer_sig_.load(std::memory_order_acquire) == kBeforeFirst) {
417 while (queue_.size() != 0) {
418 free_cells_.push(queue_.front());
421 produce_end_.store(
true, std::memory_order_release);
422 producer_sig_processed_.store(
true, std::memory_order_release);
424 consumer_cond_.notify_all();
425 }
else if (producer_sig_.load(std::memory_order_acquire) == kProduce) {
426 produce_end_.store(
true, std::memory_order_release);
427 next_notify = nwait_consumer_ != 0;
430 consumer_cond_.notify_all();
437 producer_thread_.reset(
new ScopedThread{std::thread(producer_fun)});
440template <
typename DType>
441inline bool ThreadedIter<DType>::Next(DType **out_dptr) {
442 if (producer_sig_.load(std::memory_order_acquire) == kDestroy)
444 ThrowExceptionIfSet();
445 std::unique_lock<std::mutex> lock(mutex_);
446 CHECK(producer_sig_.load(std::memory_order_acquire) == kProduce)
447 <<
"Make sure you call BeforeFirst not inconcurrent with Next!";
449 consumer_cond_.wait(lock,
450 [
this]() {
return queue_.size() != 0
451 || produce_end_.load(std::memory_order_acquire); });
453 if (queue_.size() != 0) {
454 *out_dptr = queue_.front();
456 bool notify = nwait_producer_ != 0
457 && !produce_end_.load(std::memory_order_acquire);
460 producer_cond_.notify_one();
462 ThrowExceptionIfSet();
465 CHECK(produce_end_.load(std::memory_order_acquire));
468 ThrowExceptionIfSet();
473template <
typename DType>
474inline void ThreadedIter<DType>::Recycle(DType **inout_dptr) {
476 ThrowExceptionIfSet();
478 std::lock_guard<std::mutex> lock(mutex_);
479 free_cells_.push(*inout_dptr);
481 notify = nwait_producer_ != 0 && !produce_end_.load(std::memory_order_acquire);
484 producer_cond_.notify_one();
485 ThrowExceptionIfSet();
488template <
typename DType>
inline void ThreadedIter<DType>::ThrowExceptionIfSet(
void) {
489 std::exception_ptr tmp_exception{
nullptr};
491 std::lock_guard<std::mutex> lock(mutex_exception_);
492 if (iter_exception_) {
493 tmp_exception = iter_exception_;
498 std::rethrow_exception(tmp_exception);
499 }
catch (std::exception& exc) {
500 LOG(FATAL) << exc.what();
505template <
typename DType>
inline void ThreadedIter<DType>::ClearException(
void) {
506 std::lock_guard<std::mutex> lock(mutex_exception_);
507 iter_exception_ =
nullptr;
namespace for dmlc
Definition array_view.h:12
bool Init(int argc, char *argv[])
initializes the engine module
Definition engine.cc:43
Macros common to all headers.
Definition external_memory.c:26