Medial Code Documentation
Loading...
Searching...
No Matches
threadediter.h
Go to the documentation of this file.
1
9#ifndef DMLC_THREADEDITER_H_
10#define DMLC_THREADEDITER_H_
11// defines DMLC_USE_CXX11
12#include "./base.h"
13// this code depends on c++11
14#if DMLC_ENABLE_STD_THREAD
15#include <condition_variable>
16#include <functional>
17#include <mutex>
18#include <queue>
19#include <atomic>
20#include <thread>
21#include <utility>
22#include <memory>
23#include "./data.h"
24#include "./logging.h"
25
26namespace dmlc {
27
/*!
 * \brief Wrapper that owns a std::thread and joins it on destruction.
 *
 *  The wrapper cannot be copied, so exactly one object is responsible for
 *  joining the thread it was given.
 */
class ScopedThread {
 public:
  // copy construction and copy assignment are not allowed
  ScopedThread(const ScopedThread &) = delete;
  ScopedThread &operator=(const ScopedThread &) = delete;

  /*!
   * \brief Take ownership of a thread.
   * \param thread the thread to own; it has to be joinable
   * \throw std::logic_error when the given thread is not joinable
   */
  explicit ScopedThread(std::thread thread) : thread_(std::move(thread)) {
    if (thread_.joinable()) return;
    throw std::logic_error("No thread");
  }

  /*! \brief Wait for the owned thread to finish. */
  virtual ~ScopedThread() { thread_.join(); }

 private:
  /*! \brief the owned thread; verified to be joinable by the constructor */
  std::thread thread_;
};
55
77template<typename DType>
78class ThreadedIter : public DataIter<DType> {
79 public:
85 class Producer {
86 public:
87 // virtual destructor
88 virtual ~Producer() = default;
90 virtual void BeforeFirst() {
91 NotImplemented();
92 }
106 virtual bool Next(DType **inout_dptr) = 0;
107 };
112 explicit ThreadedIter(size_t max_capacity = 8)
113 : producer_(nullptr),
114 producer_thread_(nullptr),
115 max_capacity_(max_capacity),
116 nwait_consumer_(0),
117 nwait_producer_(0),
118 out_data_(NULL) {}
120 virtual ~ThreadedIter(void) {
121 this->Destroy();
122 }
129 inline void Destroy(void);
134 inline void set_max_capacity(size_t max_capacity) {
135 max_capacity_ = max_capacity;
136 }
142 inline void Init(std::shared_ptr<Producer> producer);
151 inline void Init(std::function<bool(DType **)> next,
152 std::function<void()> beforefirst = NotImplemented);
162 inline bool Next(DType **out_dptr);
169 inline void Recycle(DType **inout_dptr);
170
174 inline void ThrowExceptionIfSet(void);
175
179 inline void ClearException(void);
180
187 virtual bool Next(void) {
188 if (out_data_ != NULL) {
189 this->Recycle(&out_data_);
190 }
191 if (Next(&out_data_)) {
192 return true;
193 } else {
194 return false;
195 }
196 }
202 virtual const DType &Value(void) const {
203 CHECK(out_data_ != NULL) << "Calling Value at beginning or end?";
204 return *out_data_;
205 }
207 virtual void BeforeFirst(void) {
208 ThrowExceptionIfSet();
209 std::unique_lock<std::mutex> lock(mutex_);
210 if (out_data_ != NULL) {
211 free_cells_.push(out_data_);
212 out_data_ = NULL;
213 }
214 if (producer_sig_.load(std::memory_order_acquire) == kDestroy) return;
215
216 producer_sig_.store(kBeforeFirst, std::memory_order_release);
217 CHECK(!producer_sig_processed_.load(std::memory_order_acquire));
218 if (nwait_producer_ != 0) {
219 producer_cond_.notify_one();
220 }
221 CHECK(!producer_sig_processed_.load(std::memory_order_acquire));
222 // wait until the request has been processed
223 consumer_cond_.wait(lock, [this]() {
224 return producer_sig_processed_.load(std::memory_order_acquire);
225 });
226 producer_sig_processed_.store(false, std::memory_order_release);
227 bool notify = nwait_producer_ != 0 && !produce_end_;
228 lock.unlock();
229 // notify producer, in case they are waiting for the condition.
230 if (notify) producer_cond_.notify_one();
231 ThrowExceptionIfSet();
232 }
233
234 private:
236 inline static void NotImplemented(void) {
237 LOG(FATAL) << "BeforeFirst is not supported";
238 }
240 enum Signal {
241 kProduce,
242 kBeforeFirst,
243 kDestroy
244 };
246 // Producer *producer_owned_;
247 std::shared_ptr<Producer> producer_;
248
250 std::atomic<Signal> producer_sig_;
252 std::atomic<bool> producer_sig_processed_;
254 std::unique_ptr<ScopedThread> producer_thread_;
256 std::atomic<bool> produce_end_;
258 size_t max_capacity_;
260 std::mutex mutex_;
262 std::mutex mutex_exception_;
264 unsigned nwait_consumer_;
266 unsigned nwait_producer_;
268 std::condition_variable producer_cond_;
270 std::condition_variable consumer_cond_;
272 DType *out_data_;
274 std::queue<DType*> queue_;
276 std::queue<DType*> free_cells_;
278 std::exception_ptr iter_exception_{nullptr};
279};
280
281// implementation of functions
282template <typename DType> inline void ThreadedIter<DType>::Destroy(void) {
283 if (producer_thread_) {
284 {
285 // lock the mutex
286 std::lock_guard<std::mutex> lock(mutex_);
287 // send destroy signal
288 producer_sig_.store(kDestroy, std::memory_order_release);
289 if (nwait_producer_ != 0) {
290 producer_cond_.notify_one();
291 }
292 }
293 producer_thread_.reset(nullptr);
294 }
295 // end of critical region
296 // now the slave thread should exit
297 while (free_cells_.size() != 0) {
298 delete free_cells_.front();
299 free_cells_.pop();
300 }
301 while (queue_.size() != 0) {
302 delete queue_.front();
303 queue_.pop();
304 }
305 if (producer_ != NULL) {
306 producer_.reset();
307 }
308 if (out_data_ != NULL) {
309 delete out_data_;
310 out_data_ = NULL;
311 }
312}
313
314template<typename DType>
315inline void ThreadedIter<DType>::
316Init(std::shared_ptr<Producer> producer) {
317 CHECK(producer_ == NULL) << "can only call Init once";
318 auto next = [producer](DType **dptr) {
319 return producer->Next(dptr);
320 };
321 auto beforefirst = [producer]() {
322 producer->BeforeFirst();
323 };
324 this->Init(next, beforefirst);
325}
326
327template <typename DType>
328inline void ThreadedIter<DType>::Init(std::function<bool(DType **)> next,
329 std::function<void()> beforefirst) {
330 producer_sig_.store(kProduce, std::memory_order_release);
331 producer_sig_processed_.store(false, std::memory_order_release);
332 produce_end_.store(false, std::memory_order_release);
333 ClearException();
334 // procedure running in prodcuer
335 // run producer thread
336 auto producer_fun = [this, next, beforefirst]() {
337 while (true) {
338 try {
339 DType *cell = NULL;
340 {
341 // lockscope
342 std::unique_lock<std::mutex> lock(mutex_);
343 ++this->nwait_producer_;
344 producer_cond_.wait(lock, [this]() {
345 if (producer_sig_.load(std::memory_order_acquire) == kProduce) {
346 bool ret = !produce_end_.load(std::memory_order_acquire)
347 && (queue_.size() < max_capacity_ ||
348 free_cells_.size() != 0);
349 return ret;
350 } else {
351 return true;
352 }
353 });
354 --this->nwait_producer_;
355 if (producer_sig_.load(std::memory_order_acquire) == kProduce) {
356 if (free_cells_.size() != 0) {
357 cell = free_cells_.front();
358 free_cells_.pop();
359 }
360 } else if (producer_sig_.load(std::memory_order_acquire) == kBeforeFirst) {
361 // reset the producer
362 beforefirst();
363 // cleanup the queue
364 while (queue_.size() != 0) {
365 free_cells_.push(queue_.front());
366 queue_.pop();
367 }
368 // reset the state
369 produce_end_.store(false, std::memory_order_release);
370 producer_sig_processed_.store(true, std::memory_order_release);
371 producer_sig_.store(kProduce, std::memory_order_release);
372 // notify consumer that all the process as been done.
373 lock.unlock();
374 consumer_cond_.notify_all();
375 continue;
376 } else {
377 // destroy the thread
378 DCHECK(producer_sig_.load(std::memory_order_acquire) == kDestroy);
379 producer_sig_processed_.store(true, std::memory_order_release);
380 produce_end_.store(true, std::memory_order_release);
381 lock.unlock();
382 consumer_cond_.notify_all();
383 return;
384 }
385 } // end of lock scope
386 // now without lock
387 produce_end_.store(!next(&cell), std::memory_order_release);
388 DCHECK(cell != NULL || produce_end_.load(std::memory_order_acquire));
389 bool notify;
390 {
391 // lockscope
392 std::lock_guard<std::mutex> lock(mutex_);
393 if (!produce_end_.load(std::memory_order_acquire)) {
394 queue_.push(cell);
395 } else {
396 if (cell != NULL)
397 free_cells_.push(cell);
398 }
399 // put things into queue
400 notify = nwait_consumer_ != 0;
401 }
402 if (notify)
403 consumer_cond_.notify_all();
404 } catch (std::exception &e) {
405 // Shouldn't throw exception in destructor
406 DCHECK(producer_sig_.load(std::memory_order_acquire) != kDestroy);
407 {
408 std::lock_guard<std::mutex> lock(mutex_exception_);
409 if (!iter_exception_) {
410 iter_exception_ = std::current_exception();
411 }
412 }
413 bool next_notify = false;
414 {
415 std::unique_lock<std::mutex> lock(mutex_);
416 if (producer_sig_.load(std::memory_order_acquire) == kBeforeFirst) {
417 while (queue_.size() != 0) {
418 free_cells_.push(queue_.front());
419 queue_.pop();
420 }
421 produce_end_.store(true, std::memory_order_release);
422 producer_sig_processed_.store(true, std::memory_order_release);
423 lock.unlock();
424 consumer_cond_.notify_all();
425 } else if (producer_sig_.load(std::memory_order_acquire) == kProduce) {
426 produce_end_.store(true, std::memory_order_release);
427 next_notify = nwait_consumer_ != 0;
428 lock.unlock();
429 if (next_notify)
430 consumer_cond_.notify_all();
431 }
432 }
433 return;
434 }
435 }
436 };
437 producer_thread_.reset(new ScopedThread{std::thread(producer_fun)});
438}
439
440template <typename DType>
441inline bool ThreadedIter<DType>::Next(DType **out_dptr) {
442 if (producer_sig_.load(std::memory_order_acquire) == kDestroy)
443 return false;
444 ThrowExceptionIfSet();
445 std::unique_lock<std::mutex> lock(mutex_);
446 CHECK(producer_sig_.load(std::memory_order_acquire) == kProduce)
447 << "Make sure you call BeforeFirst not inconcurrent with Next!";
448 ++nwait_consumer_;
449 consumer_cond_.wait(lock,
450 [this]() { return queue_.size() != 0
451 || produce_end_.load(std::memory_order_acquire); });
452 --nwait_consumer_;
453 if (queue_.size() != 0) {
454 *out_dptr = queue_.front();
455 queue_.pop();
456 bool notify = nwait_producer_ != 0
457 && !produce_end_.load(std::memory_order_acquire);
458 lock.unlock();
459 if (notify)
460 producer_cond_.notify_one();
461
462 ThrowExceptionIfSet();
463 return true;
464 } else {
465 CHECK(produce_end_.load(std::memory_order_acquire));
466 lock.unlock();
467
468 ThrowExceptionIfSet();
469 return false;
470 }
471}
472
473template <typename DType>
474inline void ThreadedIter<DType>::Recycle(DType **inout_dptr) {
475 bool notify;
476 ThrowExceptionIfSet();
477 {
478 std::lock_guard<std::mutex> lock(mutex_);
479 free_cells_.push(*inout_dptr);
480 *inout_dptr = NULL;
481 notify = nwait_producer_ != 0 && !produce_end_.load(std::memory_order_acquire);
482 }
483 if (notify)
484 producer_cond_.notify_one();
485 ThrowExceptionIfSet();
486}
487
488template <typename DType> inline void ThreadedIter<DType>::ThrowExceptionIfSet(void) {
489 std::exception_ptr tmp_exception{nullptr};
490 {
491 std::lock_guard<std::mutex> lock(mutex_exception_);
492 if (iter_exception_) {
493 tmp_exception = iter_exception_;
494 }
495 }
496 if (tmp_exception) {
497 try {
498 std::rethrow_exception(tmp_exception);
499 } catch (std::exception& exc) {
500 LOG(FATAL) << exc.what();
501 }
502 }
503}
504
505template <typename DType> inline void ThreadedIter<DType>::ClearException(void) {
506 std::lock_guard<std::mutex> lock(mutex_exception_);
507 iter_exception_ = nullptr;
508}
509
510} // namespace dmlc
511#endif // DMLC_ENABLE_STD_THREAD
512#endif // DMLC_THREADEDITER_H_
namespace for dmlc
Definition array_view.h:12
bool Init(int argc, char *argv[])
initializes the engine module
Definition engine.cc:43
Definition StdDeque.h:58
Macros common to all headers.
Definition external_memory.c:26