46 std::vector<std::uint64_t> offset;
48 Cache(
bool w, std::string n, std::string fmt)
49 : written{w}, name{std::move(n)}, format{std::move(fmt)} {
53 static std::string ShardName(std::string name, std::string format) {
54 CHECK_EQ(format.front(),
'.');
58 [[nodiscard]] std::string ShardName()
const {
59 return ShardName(this->name, this->format);
// Append the byte size of a page that was just written to the cache shard.
// The writer calls this with the value returned by fmt->Write(...). `offset`
// is later turned into running totals via std::partial_sum (presumably in
// Commit() — confirm), after which View(i) reads offset[i] as the start and
// offset[i + 1] - offset[i] as the length of page i.
64 void Push(std::size_t n_bytes) { offset.push_back(n_bytes); }
68 [[nodiscard]]
auto View(std::size_t i)
const {
69 std::uint64_t off = offset.at(i);
70 std::uint64_t len = offset.at(i + 1) - offset[i];
71 return std::pair{off, len};
78 std::partial_sum(offset.begin(), offset.end(), offset.begin());
100 std::atomic<bool> flag_{
false};
101 std::exception_ptr curr_exce_{
nullptr};
104 template <
typename Fn>
105 decltype(
auto) Run(Fn&& fn)
noexcept(
true) {
109 std::lock_guard<std::mutex> guard{mutex_};
111 curr_exce_ = std::current_exception();
114 }
catch (std::exception
const& e) {
115 std::lock_guard<std::mutex> guard{mutex_};
117 curr_exce_ = std::current_exception();
121 std::lock_guard<std::mutex> guard{mutex_};
123 curr_exce_ = std::current_exception();
127 return std::invoke_result_t<Fn>();
130 void Rethrow()
noexcept(
false) {
133 std::rethrow_exception(curr_exce_);
145 std::mutex single_threaded_;
147 std::shared_ptr<S> page_;
149 bool at_end_ {
false};
151 std::int32_t nthreads_;
154 std::uint32_t count_{0};
156 std::uint32_t n_batches_{0};
158 std::shared_ptr<Cache> cache_info_;
160 using Ring = std::vector<std::future<std::shared_ptr<S>>>;
163 std::unique_ptr<Ring> ring_{
new Ring};
172 if (!cache_info_->written) {
175 if (ring_->empty()) {
176 ring_->resize(n_batches_);
180 uint32_t
constexpr kPreFetch = 3;
182 size_t n_prefetch_batches = std::min(kPreFetch, n_batches_);
183 CHECK_GT(n_prefetch_batches, 0) <<
"total batches:" << n_batches_;
184 std::size_t fetch_it = count_;
188 for (std::size_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
189 fetch_it %= n_batches_;
190 if (ring_->at(fetch_it).valid()) {
193 auto const* self =
this;
194 CHECK_LT(fetch_it, cache_info_->offset.size());
195 ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self,
this]() {
196 auto page = std::make_shared<S>();
197 this->exce_.Run([&] {
198 std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>(
"raw")};
199 auto name = self->cache_info_->ShardName();
200 auto [offset, length] = self->cache_info_->View(fetch_it);
201 auto fi = std::make_unique<common::PrivateMmapConstStream>(name, offset, length);
202 CHECK(fmt->Read(page.get(), fi.get()));
208 CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), [](
auto const& f) { return f.valid(); }),
210 <<
"Sparse DMatrix assumes forward iteration.";
212 monitor_.Start(
"Wait");
213 page_ = (*ring_)[count_].get();
214 CHECK(!(*ring_)[count_].valid());
215 monitor_.Stop(
"Wait");
223 CHECK(!cache_info_->written);
226 std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>(
"raw")};
228 auto name = cache_info_->ShardName();
229 std::unique_ptr<common::AlignedFileWriteStream> fo;
230 if (this->Iter() == 0) {
231 fo = std::make_unique<common::AlignedFileWriteStream>(
StringView{name},
"wb");
233 fo = std::make_unique<common::AlignedFileWriteStream>(
StringView{name},
"ab");
236 auto bytes = fmt->Write(*page_, fo.get());
240 LOG(INFO) <<
static_cast<double>(bytes) / 1024.0 / 1024.0 <<
" MB written in "
241 << timer.ElapsedSeconds() <<
" seconds.";
242 cache_info_->Push(bytes);
245 virtual void Fetch() = 0;
249 std::shared_ptr<Cache> cache)
252 n_features_{n_features},
253 n_batches_{n_batches},
254 cache_info_{std::move(cache)} {
255 monitor_.Init(
typeid(S).name());
262 for (
auto& fu : *ring_) {
// Index of the current batch (the `count_` member).
// The cache-writing path uses Iter() == 0 to open the shard with "wb"
// (truncate) and any later batch to open it with "ab" (append).
269 [[nodiscard]] uint32_t Iter()
const {
return count_; }
271 const S &operator*()
const override {
276 [[nodiscard]] std::shared_ptr<S const> Page()
const override {
280 [[nodiscard]]
bool AtEnd()
const override {
284 virtual void Reset() {
304 std::size_t base_row_id_{0};
307 page_ = std::make_shared<SparsePage>();
308 if (!this->ReadCache()) {
309 bool type_error {
false };
312 page_->Push(adapter_batch, this->missing_, this->nthreads_);
315 DevicePush(proxy_, missing_, page_.get());
317 page_->SetBaseRowId(base_row_id_);
318 base_row_id_ += page_->Size();
328 bst_feature_t n_features, uint32_t n_batches, std::shared_ptr<Cache> cache)
330 iter_{iter}, proxy_{proxy} {
331 if (!cache_info_->written) {
333 CHECK(iter_.Next()) <<
"Must have at least 1 batch.";
341 if (cache_info_->written) {
342 at_end_ = (count_ == n_batches_);
344 at_end_ = !iter_.Next();
348 CHECK_EQ(cache_info_->offset.size(), n_batches_ + 1);
349 cache_info_->Commit();
350 if (n_batches_ != 0) {
351 CHECK_EQ(count_, n_batches_);
361 void Reset()
override {
366 SparsePageSourceImpl::Reset();
377 std::shared_ptr<SparsePageSource> source_;
385 std::shared_ptr<Cache> cache,
bool sync)
386 : Super::SparsePageSourceImpl{missing, nthreads, n_features, n_batches, cache}, sync_{sync} {}
395 this->at_end_ = this->count_ == this->n_batches_;
398 this->cache_info_->Commit();
399 if (this->n_batches_ != 0) {
400 CHECK_EQ(this->count_, this->n_batches_);
402 CHECK_GE(this->count_, 1);
408 CHECK_EQ(source_->Iter(), this->count_);
417 if (!this->ReadCache()) {
418 auto const &csr = source_->Page();
419 this->page_.reset(
new CSCPage{});
421 this->page_->PushCSC(csr->GetTranspose(n_features_, nthreads_));
422 page_->SetBaseRowId(csr->base_rowid);
429 std::shared_ptr<Cache> cache, std::shared_ptr<SparsePageSource> source)
431 this->source_ = source;
439 if (!this->ReadCache()) {
440 auto const &csr = this->source_->Page();
443 this->page_->PushCSC(csr->GetTranspose(n_features_, nthreads_));
444 CHECK_EQ(this->page_->Size(), n_features_);
445 CHECK_EQ(this->page_->data.Size(), csr->data.Size());
446 this->page_->SortRows(this->nthreads_);
447 page_->SetBaseRowId(csr->base_rowid);
454 uint32_t n_batches, std::shared_ptr<Cache> cache,
455 std::shared_ptr<SparsePageSource> source)
457 this->source_ = source;
In-memory storage unit of sparse batch, stored in CSR format.
Definition data.h:328
Definition sparse_page_source.h:414
Definition sparse_page_source.h:375
Definition sparse_page_source.h:436
Base class for all page sources.
Definition sparse_page_source.h:142
Definition sparse_page_source.h:300
decltype(auto) HostAdapterDispatch(DMatrixProxy const *proxy, Fn fn, bool *type_error=nullptr)
Dispatch function call based on input type.
Definition proxy_dmatrix.h:131
Copyright 2014-2023, XGBoost Contributors.