sparse_page_source.h
#ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
#define XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_

#include <algorithm>  // for min
#include <atomic>     // for atomic
#include <future>     // for async
#include <map>
#include <memory>
#include <mutex>      // for mutex
#include <numeric>    // for partial_sum
#include <string>
#include <thread>
#include <utility>    // for pair, move
#include <vector>

#include "../common/common.h"
#include "../common/io.h"        // for PrivateMmapConstStream
#include "../common/timer.h"     // for Monitor, Timer
#include "adapter.h"
#include "proxy_dmatrix.h"       // for DMatrixProxy
#include "sparse_page_writer.h"  // for SparsePageFormat
#include "xgboost/base.h"
#include "xgboost/data.h"

namespace xgboost::data {
inline void TryDeleteCacheFile(const std::string& file) {
  if (std::remove(file.c_str()) != 0) {
    // Don't throw, this is called in a destructor.
    LOG(WARNING) << "Couldn't remove external memory cache file " << file
                 << "; you may want to remove it manually";
  }
}

/**
 * \brief Information about the cache, including path and page offsets.
 */
struct Cache {
  // whether the write to the cache is complete
  bool written;
  std::string name;
  std::string format;
  // Offsets into the binary cache file.
  std::vector<std::uint64_t> offset;

  Cache(bool w, std::string n, std::string fmt)
      : written{w}, name{std::move(n)}, format{std::move(fmt)} {
    offset.push_back(0);
  }

  static std::string ShardName(std::string name, std::string format) {
    CHECK_EQ(format.front(), '.');
    return name + format;
  }

  [[nodiscard]] std::string ShardName() const {
    return ShardName(this->name, this->format);
  }
  /**
   * \brief Record a page with size of n_bytes.
   */
  void Push(std::size_t n_bytes) { offset.push_back(n_bytes); }
  /**
   * \brief Returns the view start and length for the i^th page.
   */
  [[nodiscard]] auto View(std::size_t i) const {
    std::uint64_t off = offset.at(i);
    std::uint64_t len = offset.at(i + 1) - offset[i];
    return std::pair{off, len};
  }
  /**
   * \brief Call this once the write for the cache is complete.
   */
  void Commit() {
    if (!written) {
      std::partial_sum(offset.begin(), offset.end(), offset.begin());
      written = true;
    }
  }
};
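
// Illustrative sketch, not part of the upstream header: how the Cache offsets are built.
// Each WriteCache() call records the page size via Push(); Commit() converts the sizes into
// absolute offsets with std::partial_sum, after which View(i) yields the (offset, length)
// pair used to mmap page i. The file name and format below are made up for the example.
inline void CacheOffsetsExampleSketch() {
  Cache cache{false, "train.cache", ".row.page"};
  cache.Push(100);  // page 0 occupies 100 bytes
  cache.Push(250);  // page 1 occupies 250 bytes
  cache.Commit();   // offset becomes {0, 100, 350}
  auto [off, len] = cache.View(1);
  CHECK_EQ(off, 100);
  CHECK_EQ(len, 250);
  CHECK_EQ(cache.ShardName(), "train.cache.row.page");
}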

// Prevents multi-threaded call to `GetBatches`.
class TryLockGuard {
  std::mutex& lock_;

 public:
  explicit TryLockGuard(std::mutex& lock) : lock_{lock} {  // NOLINT
    CHECK(lock_.try_lock()) << "Multiple threads attempting to use Sparse DMatrix.";
  }
  ~TryLockGuard() {
    lock_.unlock();
  }
};
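
// Illustrative note, not part of the upstream header: unlike std::lock_guard, TryLockGuard
// never blocks; if the mutex is already held the CHECK fails, turning concurrent use of the
// single-threaded iterator into a hard error:
//
//   std::mutex m;
//   TryLockGuard first{m};   // acquires the lock
//   TryLockGuard second{m};  // CHECK fails: "Multiple threads attempting to use Sparse DMatrix."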

// Similar to `dmlc::OMPException`, but doesn't need the threads to be joined before rethrow.
class ExceHandler {
  std::mutex mutex_;
  std::atomic<bool> flag_{false};
  std::exception_ptr curr_exce_{nullptr};

 public:
  template <typename Fn>
  decltype(auto) Run(Fn&& fn) noexcept(true) {
    try {
      return fn();
    } catch (dmlc::Error const& e) {
      std::lock_guard<std::mutex> guard{mutex_};
      if (!curr_exce_) {
        curr_exce_ = std::current_exception();
      }
      flag_ = true;
    } catch (std::exception const& e) {
      std::lock_guard<std::mutex> guard{mutex_};
      if (!curr_exce_) {
        curr_exce_ = std::current_exception();
      }
      flag_ = true;
    } catch (...) {
      std::lock_guard<std::mutex> guard{mutex_};
      if (!curr_exce_) {
        curr_exce_ = std::current_exception();
      }
      flag_ = true;
    }
    return std::invoke_result_t<Fn>();
  }

  void Rethrow() noexcept(false) {
    if (flag_) {
      CHECK(curr_exce_);
      std::rethrow_exception(curr_exce_);
    }
  }
};
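
// Illustrative sketch, not part of the upstream header: the intended use of ExceHandler. A
// pre-fetch thread wraps its work in Run(), which records any exception instead of letting it
// escape; the consuming thread later calls Rethrow() to surface it. The error message is made
// up for the example.
inline void ExceHandlerExampleSketch() {
  ExceHandler exce;
  auto fut = std::async(std::launch::async, [&] {
    return exce.Run([]() -> int {
      throw dmlc::Error("disk read failed");  // captured, not propagated through the future
    });
  });
  auto fallback = fut.get();  // 0, the value-initialized result returned after capture
  (void)fallback;
  exce.Rethrow();  // re-throws the recorded dmlc::Error on the calling thread
}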

/**
 * \brief Base class for all page sources.
 */
template <typename S>
class SparsePageSourceImpl : public BatchIteratorImpl<S> {
 protected:
  // Prevents calling this iterator from multiple places (or threads).
  std::mutex single_threaded_;
  // The current page.
  std::shared_ptr<S> page_;

  bool at_end_{false};
  float missing_;
  std::int32_t nthreads_;
  bst_feature_t n_features_;
  // Index to the current page.
  std::uint32_t count_{0};
  // Total number of batches.
  std::uint32_t n_batches_{0};

  std::shared_ptr<Cache> cache_info_;

  using Ring = std::vector<std::future<std::shared_ptr<S>>>;
  // A ring storing futures to data. Since the DMatrix iterator is forward-only, we can
  // pre-fetch data in a ring.
  std::unique_ptr<Ring> ring_{new Ring};
  // Catches exceptions thrown in pre-fetch threads to prevent a segfault. This doesn't
  // always work though; an OOM error can be delayed due to lazy commit. On the bright
  // side, if mmap is used then OOM errors should be rare.
  ExceHandler exce_;
  common::Monitor monitor_;

  bool ReadCache() {
    CHECK(!at_end_);
    if (!cache_info_->written) {
      return false;
    }
    if (ring_->empty()) {
      ring_->resize(n_batches_);
    }
    // A heuristic for the number of pre-fetched batches. We can make it part of
    // BatchParam to let the user adjust the number of pre-fetched batches when needed.
    uint32_t constexpr kPreFetch = 3;

    size_t n_prefetch_batches = std::min(kPreFetch, n_batches_);
    CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_;
    std::size_t fetch_it = count_;

    exce_.Rethrow();

    for (std::size_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
      fetch_it %= n_batches_;  // ring
      if (ring_->at(fetch_it).valid()) {
        continue;
      }
      auto const* self = this;  // make sure it's const
      CHECK_LT(fetch_it, cache_info_->offset.size());
      ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, this]() {
        auto page = std::make_shared<S>();
        this->exce_.Run([&] {
          std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
          auto name = self->cache_info_->ShardName();
          auto [offset, length] = self->cache_info_->View(fetch_it);
          auto fi = std::make_unique<common::PrivateMmapConstStream>(name, offset, length);
          CHECK(fmt->Read(page.get(), fi.get()));
        });
        return page;
      });
    }

    CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), [](auto const& f) { return f.valid(); }),
             n_prefetch_batches)
        << "Sparse DMatrix assumes forward iteration.";

    monitor_.Start("Wait");
    page_ = (*ring_)[count_].get();
    CHECK(!(*ring_)[count_].valid());
    monitor_.Stop("Wait");

    exce_.Rethrow();

    return true;
  }

  void WriteCache() {
    CHECK(!cache_info_->written);
    common::Timer timer;
    timer.Start();
    std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};

    auto name = cache_info_->ShardName();
    std::unique_ptr<common::AlignedFileWriteStream> fo;
    if (this->Iter() == 0) {
      fo = std::make_unique<common::AlignedFileWriteStream>(StringView{name}, "wb");
    } else {
      fo = std::make_unique<common::AlignedFileWriteStream>(StringView{name}, "ab");
    }

    auto bytes = fmt->Write(*page_, fo.get());

    timer.Stop();
    // Not entirely accurate, the kernel doesn't have to flush the data.
    LOG(INFO) << static_cast<double>(bytes) / 1024.0 / 1024.0 << " MB written in "
              << timer.ElapsedSeconds() << " seconds.";
    cache_info_->Push(bytes);
  }

  virtual void Fetch() = 0;

 public:
  SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
                       std::shared_ptr<Cache> cache)
      : missing_{missing},
        nthreads_{nthreads},
        n_features_{n_features},
        n_batches_{n_batches},
        cache_info_{std::move(cache)} {
    monitor_.Init(typeid(S).name());  // not pretty, but works for basic profiling
  }

  SparsePageSourceImpl(SparsePageSourceImpl const& that) = delete;

  ~SparsePageSourceImpl() override {
    // Don't orphan the threads.
    for (auto& fu : *ring_) {
      if (fu.valid()) {
        fu.get();
      }
    }
  }

  [[nodiscard]] uint32_t Iter() const { return count_; }

  const S& operator*() const override {
    CHECK(page_);
    return *page_;
  }

  [[nodiscard]] std::shared_ptr<S const> Page() const override {
    return page_;
  }

  [[nodiscard]] bool AtEnd() const override {
    return at_end_;
  }

  virtual void Reset() {
    TryLockGuard guard{single_threaded_};
    at_end_ = false;
    count_ = 0;
    // Pre-fetch for the next round of iterations.
    this->Fetch();
  }
};
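
// Illustrative sketch, not part of the upstream header: the forward-only pre-fetch used by
// ReadCache() above, reduced to its core. A ring holds one future per batch; at most kPreFetch
// of them are in flight at a time, and because iteration only moves forward a slot is refilled
// only after the index wraps around. `load_page` is a hypothetical stand-in for the
// SparsePageFormat::Read call on an mmap-ed view of the cache shard.
template <typename Page, typename Loader>
std::shared_ptr<Page> PrefetchRingSketch(std::vector<std::future<std::shared_ptr<Page>>>* ring,
                                         std::uint32_t count, std::uint32_t n_batches,
                                         Loader load_page) {
  std::uint32_t constexpr kPreFetch = 3;
  auto n_prefetch = std::min(kPreFetch, n_batches);
  std::size_t fetch_it = count;
  for (std::uint32_t i = 0; i < n_prefetch; ++i, ++fetch_it) {
    fetch_it %= n_batches;  // treat the vector as a ring
    if (!ring->at(fetch_it).valid()) {
      ring->at(fetch_it) =
          std::async(std::launch::async, [fetch_it, load_page] { return load_page(fetch_it); });
    }
  }
  return ring->at(count).get();  // blocks only if batch `count` hasn't finished loading
}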

#if defined(XGBOOST_USE_CUDA)
// Push data from CUDA.
void DevicePush(DMatrixProxy* proxy, float missing, SparsePage* page);
#else
inline void DevicePush(DMatrixProxy*, float, SparsePage*) { common::AssertGPUSupport(); }
#endif

class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
  // This is the source from the user.
  DataIterProxy<DataIterHandle, XGDMatrixCallbackNext, XGBCallbackSetData> iter_;
  DMatrixProxy* proxy_;
  std::size_t base_row_id_{0};

  void Fetch() final {
    page_ = std::make_shared<SparsePage>();
    if (!this->ReadCache()) {
      bool type_error{false};
      CHECK(proxy_);
      HostAdapterDispatch(proxy_, [&](auto const& adapter_batch) {
        page_->Push(adapter_batch, this->missing_, this->nthreads_);
      }, &type_error);
      if (type_error) {
        DevicePush(proxy_, missing_, page_.get());
      }
      page_->SetBaseRowId(base_row_id_);
      base_row_id_ += page_->Size();
      n_batches_++;
      this->WriteCache();
    }
  }

 public:
  SparsePageSource(DataIterProxy<DataIterHandle, XGDMatrixCallbackNext, XGBCallbackSetData> iter,
                   DMatrixProxy* proxy, float missing, int nthreads, bst_feature_t n_features,
                   uint32_t n_batches, std::shared_ptr<Cache> cache)
      : SparsePageSourceImpl(missing, nthreads, n_features, n_batches, cache),
        iter_{iter},
        proxy_{proxy} {
    if (!cache_info_->written) {
      iter_.Reset();
      CHECK(iter_.Next()) << "Must have at least 1 batch.";
    }
    this->Fetch();
  }

  SparsePageSource& operator++() final {
    TryLockGuard guard{single_threaded_};
    count_++;
    if (cache_info_->written) {
      at_end_ = (count_ == n_batches_);
    } else {
      at_end_ = !iter_.Next();
    }

    if (at_end_) {
      CHECK_EQ(cache_info_->offset.size(), n_batches_ + 1);
      cache_info_->Commit();
      if (n_batches_ != 0) {
        CHECK_EQ(count_, n_batches_);
      }
      CHECK_GE(count_, 1);
      proxy_ = nullptr;
    } else {
      this->Fetch();
    }
    return *this;
  }

  void Reset() override {
    if (proxy_) {
      TryLockGuard guard{single_threaded_};
      iter_.Reset();
    }
    SparsePageSourceImpl::Reset();

    TryLockGuard guard{single_threaded_};
    base_row_id_ = 0;
  }
};
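
// Illustrative sketch, not part of the upstream header: how a SparsePageSource is consumed.
// On the first full pass each page is parsed from the user iterator and appended to the cache
// shard (WriteCache); once Commit() has run at the end of the pass, Reset() starts a new pass
// that is served from the mmap-ed cache via ReadCache() instead.
inline std::size_t CountRowsExampleSketch(SparsePageSource* source) {
  std::size_t n_rows = 0;
  while (!source->AtEnd()) {
    SparsePage const& page = **source;  // current page, already fetched
    n_rows += page.Size();
    ++(*source);  // advance; on the first pass this also writes the page to the cache
  }
  source->Reset();  // rewind for the next pass, now reading from the cache
  return n_rows;
}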

// A mixin for advancing the iterator.
template <typename S>
class PageSourceIncMixIn : public SparsePageSourceImpl<S> {
 protected:
  std::shared_ptr<SparsePageSource> source_;
  using Super = SparsePageSourceImpl<S>;
  // Synchronize with the row page source; `hist` and `gpu_hist` don't need the original
  // sparse page, so we avoid fetching it.
  bool sync_{true};

 public:
  PageSourceIncMixIn(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
                     std::shared_ptr<Cache> cache, bool sync)
      : Super::SparsePageSourceImpl{missing, nthreads, n_features, n_batches, cache}, sync_{sync} {}

  PageSourceIncMixIn& operator++() final {
    TryLockGuard guard{this->single_threaded_};
    if (sync_) {
      ++(*source_);
    }

    ++this->count_;
    this->at_end_ = this->count_ == this->n_batches_;

    if (this->at_end_) {
      this->cache_info_->Commit();
      if (this->n_batches_ != 0) {
        CHECK_EQ(this->count_, this->n_batches_);
      }
      CHECK_GE(this->count_, 1);
    } else {
      this->Fetch();
    }

    if (sync_) {
      CHECK_EQ(source_->Iter(), this->count_);
    }
    return *this;
  }
};

class CSCPageSource : public PageSourceIncMixIn<CSCPage> {
 protected:
  void Fetch() final {
    if (!this->ReadCache()) {
      auto const& csr = source_->Page();
      this->page_.reset(new CSCPage{});
      // We might be able to optimize this by merging transpose and PushCSC.
      this->page_->PushCSC(csr->GetTranspose(n_features_, nthreads_));
      page_->SetBaseRowId(csr->base_rowid);
      this->WriteCache();
    }
  }

 public:
  CSCPageSource(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
                std::shared_ptr<Cache> cache, std::shared_ptr<SparsePageSource> source)
      : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, true) {
    this->source_ = source;
    this->Fetch();
  }
};

class SortedCSCPageSource : public PageSourceIncMixIn<SortedCSCPage> {
 protected:
  void Fetch() final {
    if (!this->ReadCache()) {
      auto const& csr = this->source_->Page();
      this->page_.reset(new SortedCSCPage{});
      // We might be able to optimize this by merging transpose and PushCSC.
      this->page_->PushCSC(csr->GetTranspose(n_features_, nthreads_));
      CHECK_EQ(this->page_->Size(), n_features_);
      CHECK_EQ(this->page_->data.Size(), csr->data.Size());
      this->page_->SortRows(this->nthreads_);
      page_->SetBaseRowId(csr->base_rowid);
      this->WriteCache();
    }
  }

 public:
  SortedCSCPageSource(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
                      std::shared_ptr<Cache> cache, std::shared_ptr<SparsePageSource> source)
      : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, true) {
    this->source_ = source;
    this->Fetch();
  }
};
}  // namespace xgboost::data
#endif  // XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_