Medial Code Documentation
Loading...
Searching...
No Matches
adapter.h
Go to the documentation of this file.
1
5#ifndef XGBOOST_DATA_ADAPTER_H_
6#define XGBOOST_DATA_ADAPTER_H_
7#include <dmlc/data.h>
8
9#include <algorithm>
10#include <cstddef> // for size_t
11#include <functional>
12#include <limits>
13#include <map>
14#include <memory>
15#include <string>
16#include <utility> // std::move
17#include <vector>
18
19#include "../c_api/c_api_error.h"
20#include "../common/error_msg.h" // for MaxFeatureSize
21#include "../common/math.h"
22#include "array_interface.h"
23#include "arrow-cdi.h"
24#include "xgboost/base.h"
25#include "xgboost/data.h"
26#include "xgboost/logging.h"
27#include "xgboost/span.h"
28#include "xgboost/string_view.h"
29
30namespace xgboost {
31namespace data {
32
76constexpr size_t kAdapterUnknownSize = std::numeric_limits<size_t >::max();
77
78struct COOTuple {
79 COOTuple() = default;
80 XGBOOST_DEVICE COOTuple(size_t row_idx, size_t column_idx, float value)
81 : row_idx(row_idx), column_idx(column_idx), value(value) {}
82
83 size_t row_idx{0};
84 size_t column_idx{0};
85 float value{0};
86};
87
89 float missing;
90
91 XGBOOST_DEVICE explicit IsValidFunctor(float missing) : missing(missing) {}
92
93 XGBOOST_DEVICE bool operator()(float value) const {
94 return !(common::CheckNAN(value) || value == missing);
95 }
96
97 XGBOOST_DEVICE bool operator()(const data::COOTuple& e) const {
98 return !(common::CheckNAN(e.value) || e.value == missing);
99 }
100
101 XGBOOST_DEVICE bool operator()(const Entry& e) const {
102 return !(common::CheckNAN(e.fvalue) || e.fvalue == missing);
103 }
104};
105
106namespace detail {
107
111template <typename DType>
113 public:
114 void BeforeFirst() override { counter_ = 0; }
115 bool Next() override {
116 if (counter_ == 0) {
117 counter_++;
118 return true;
119 }
120 return false;
121 }
122
123 private:
124 int counter_{0};
125};
126
130 public:
131 const float* Labels() const { return nullptr; }
132 const float* Weights() const { return nullptr; }
133 const uint64_t* Qid() const { return nullptr; }
134 const float* BaseMargin() const { return nullptr; }
135};
136
137}; // namespace detail
138
140 public:
141 class Line {
142 public:
143 Line(size_t row_idx, size_t size, const unsigned* feature_idx,
144 const float* values)
145 : row_idx_(row_idx),
146 size_(size),
147 feature_idx_(feature_idx),
148 values_(values) {}
149
150 size_t Size() const { return size_; }
151 COOTuple GetElement(size_t idx) const {
152 return COOTuple{row_idx_, feature_idx_[idx], values_[idx]};
153 }
154
155 private:
156 size_t row_idx_;
157 size_t size_;
158 const unsigned* feature_idx_;
159 const float* values_;
160 };
161 CSRAdapterBatch(const size_t* row_ptr, const unsigned* feature_idx,
162 const float* values, size_t num_rows, size_t, size_t)
163 : row_ptr_(row_ptr),
164 feature_idx_(feature_idx),
165 values_(values),
166 num_rows_(num_rows) {}
167 const Line GetLine(size_t idx) const {
168 size_t begin_offset = row_ptr_[idx];
169 size_t end_offset = row_ptr_[idx + 1];
170 return Line(idx, end_offset - begin_offset, &feature_idx_[begin_offset],
171 &values_[begin_offset]);
172 }
173 size_t Size() const { return num_rows_; }
174 static constexpr bool kIsRowMajor = true;
175
176 private:
177 const size_t* row_ptr_;
178 const unsigned* feature_idx_;
179 const float* values_;
180 size_t num_rows_;
181};
182
183class CSRAdapter : public detail::SingleBatchDataIter<CSRAdapterBatch> {
184 public:
185 CSRAdapter(const size_t* row_ptr, const unsigned* feature_idx,
186 const float* values, size_t num_rows, size_t num_elements,
187 size_t num_features)
188 : batch_(row_ptr, feature_idx, values, num_rows, num_elements,
189 num_features),
190 num_rows_(num_rows),
191 num_columns_(num_features) {}
192 const CSRAdapterBatch& Value() const override { return batch_; }
193 size_t NumRows() const { return num_rows_; }
194 size_t NumColumns() const { return num_columns_; }
195
196 private:
197 CSRAdapterBatch batch_;
198 size_t num_rows_;
199 size_t num_columns_;
200};
201
203 public:
204 DenseAdapterBatch(const float* values, size_t num_rows, size_t num_features)
205 : values_(values),
206 num_rows_(num_rows),
207 num_features_(num_features) {}
208
209 private:
210 class Line {
211 public:
212 Line(const float* values, size_t size, size_t row_idx)
213 : row_idx_(row_idx), size_(size), values_(values) {}
214
215 size_t Size() const { return size_; }
216 COOTuple GetElement(size_t idx) const {
217 return COOTuple{row_idx_, idx, values_[idx]};
218 }
219
220 private:
221 size_t row_idx_;
222 size_t size_;
223 const float* values_;
224 };
225
226 public:
227 size_t Size() const { return num_rows_; }
228 const Line GetLine(size_t idx) const {
229 return Line(values_ + idx * num_features_, num_features_, idx);
230 }
231 static constexpr bool kIsRowMajor = true;
232
233 private:
234 const float* values_;
235 size_t num_rows_;
236 size_t num_features_;
237};
238
239class DenseAdapter : public detail::SingleBatchDataIter<DenseAdapterBatch> {
240 public:
241 DenseAdapter(const float* values, size_t num_rows, size_t num_features)
242 : batch_(values, num_rows, num_features),
243 num_rows_(num_rows),
244 num_columns_(num_features) {}
245 const DenseAdapterBatch& Value() const override { return batch_; }
246
247 size_t NumRows() const { return num_rows_; }
248 size_t NumColumns() const { return num_columns_; }
249
250 private:
251 DenseAdapterBatch batch_;
252 size_t num_rows_;
253 size_t num_columns_;
254};
255
257 public:
258 static constexpr bool kIsRowMajor = true;
259
260 private:
261 ArrayInterface<2> array_interface_;
262
263 class Line {
264 ArrayInterface<2> array_interface_;
265 size_t ridx_;
266
267 public:
268 Line(ArrayInterface<2> array_interface, size_t ridx)
269 : array_interface_{std::move(array_interface)}, ridx_{ridx} {}
270
271 size_t Size() const { return array_interface_.Shape(1); }
272
273 COOTuple GetElement(size_t idx) const {
274 return {ridx_, idx, array_interface_(ridx_, idx)};
275 }
276 };
277
278 public:
279 ArrayAdapterBatch() = default;
280 Line const GetLine(size_t idx) const {
281 return Line{array_interface_, idx};
282 }
283
284 size_t NumRows() const { return array_interface_.Shape(0); }
285 size_t NumCols() const { return array_interface_.Shape(1); }
286 size_t Size() const { return this->NumRows(); }
287
288 explicit ArrayAdapterBatch(ArrayInterface<2> array_interface)
289 : array_interface_{std::move(array_interface)} {}
290};
291
297class ArrayAdapter : public detail::SingleBatchDataIter<ArrayAdapterBatch> {
298 public:
299 explicit ArrayAdapter(StringView array_interface) {
300 auto j = Json::Load(array_interface);
301 array_interface_ = ArrayInterface<2>(get<Object const>(j));
302 batch_ = ArrayAdapterBatch{array_interface_};
303 }
304 [[nodiscard]] ArrayAdapterBatch const& Value() const override { return batch_; }
305 [[nodiscard]] std::size_t NumRows() const { return array_interface_.Shape(0); }
306 [[nodiscard]] std::size_t NumColumns() const { return array_interface_.Shape(1); }
307
308 private:
309 ArrayAdapterBatch batch_;
310 ArrayInterface<2> array_interface_;
311};
312
314 ArrayInterface<1> indptr_;
315 ArrayInterface<1> indices_;
316 ArrayInterface<1> values_;
317 bst_feature_t n_features_;
318
319 class Line {
320 ArrayInterface<1> indices_;
321 ArrayInterface<1> values_;
322 size_t ridx_;
323 size_t offset_;
324
325 public:
326 Line(ArrayInterface<1> indices, ArrayInterface<1> values, size_t ridx,
327 size_t offset)
328 : indices_{std::move(indices)}, values_{std::move(values)}, ridx_{ridx},
329 offset_{offset} {}
330
331 COOTuple GetElement(std::size_t idx) const {
332 return {ridx_, TypedIndex<std::size_t, 1>{indices_}(offset_ + idx), values_(offset_ + idx)};
333 }
334
335 size_t Size() const {
336 return values_.Shape(0);
337 }
338 };
339
340 public:
341 static constexpr bool kIsRowMajor = true;
342
343 public:
344 CSRArrayAdapterBatch() = default;
346 ArrayInterface<1> values, bst_feature_t n_features)
347 : indptr_{std::move(indptr)},
348 indices_{std::move(indices)},
349 values_{std::move(values)},
350 n_features_{n_features} {
351 }
352
353 size_t NumRows() const {
354 size_t size = indptr_.Shape(0);
355 size = size == 0 ? 0 : size - 1;
356 return size;
357 }
358 size_t NumCols() const { return n_features_; }
359 size_t Size() const { return this->NumRows(); }
360
361 Line const GetLine(size_t idx) const {
362 auto begin_no_stride = TypedIndex<size_t, 1>{indptr_}(idx);
363 auto end_no_stride = TypedIndex<size_t, 1>{indptr_}(idx + 1);
364
365 auto indices = indices_;
366 auto values = values_;
367 // Slice indices and values, stride remains unchanged since this is slicing by
368 // specific index.
369 auto offset = indices.strides[0] * begin_no_stride;
370
371 indices.shape[0] = end_no_stride - begin_no_stride;
372 values.shape[0] = end_no_stride - begin_no_stride;
373
374 return Line{indices, values, idx, offset};
375 }
376};
377
383class CSRArrayAdapter : public detail::SingleBatchDataIter<CSRArrayAdapterBatch> {
384 public:
385 CSRArrayAdapter(StringView indptr, StringView indices, StringView values,
386 size_t num_cols)
387 : indptr_{indptr}, indices_{indices}, values_{values}, num_cols_{num_cols} {
388 batch_ = CSRArrayAdapterBatch{indptr_, indices_, values_,
389 static_cast<bst_feature_t>(num_cols_)};
390 }
391
392 CSRArrayAdapterBatch const& Value() const override {
393 return batch_;
394 }
395 size_t NumRows() const {
396 size_t size = indptr_.Shape(0);
397 size = size == 0 ? 0 : size - 1;
398 return size;
399 }
400 size_t NumColumns() const { return num_cols_; }
401
402 private:
403 CSRArrayAdapterBatch batch_;
404 ArrayInterface<1> indptr_;
405 ArrayInterface<1> indices_;
406 ArrayInterface<1> values_;
407 size_t num_cols_;
408};
409
411 public:
412 CSCAdapterBatch(const size_t* col_ptr, const unsigned* row_idx,
413 const float* values, size_t num_features)
414 : col_ptr_(col_ptr),
415 row_idx_(row_idx),
416 values_(values),
417 num_features_(num_features) {}
418
419 private:
420 class Line {
421 public:
422 Line(size_t col_idx, size_t size, const unsigned* row_idx,
423 const float* values)
424 : col_idx_(col_idx), size_(size), row_idx_(row_idx), values_(values) {}
425
426 size_t Size() const { return size_; }
427 COOTuple GetElement(size_t idx) const {
428 return COOTuple{row_idx_[idx], col_idx_, values_[idx]};
429 }
430
431 private:
432 size_t col_idx_;
433 size_t size_;
434 const unsigned* row_idx_;
435 const float* values_;
436 };
437
438 public:
439 size_t Size() const { return num_features_; }
440 const Line GetLine(size_t idx) const {
441 size_t begin_offset = col_ptr_[idx];
442 size_t end_offset = col_ptr_[idx + 1];
443 return Line(idx, end_offset - begin_offset, &row_idx_[begin_offset],
444 &values_[begin_offset]);
445 }
446 static constexpr bool kIsRowMajor = false;
447
448 private:
449 const size_t* col_ptr_;
450 const unsigned* row_idx_;
451 const float* values_;
452 size_t num_features_;
453};
454
455class CSCAdapter : public detail::SingleBatchDataIter<CSCAdapterBatch> {
456 public:
457 CSCAdapter(const size_t* col_ptr, const unsigned* row_idx,
458 const float* values, size_t num_features, size_t num_rows)
459 : batch_(col_ptr, row_idx, values, num_features),
460 num_rows_(num_rows),
461 num_columns_(num_features) {}
462 const CSCAdapterBatch& Value() const override { return batch_; }
463
464 // JVM package sends 0 as unknown
465 size_t NumRows() const {
466 return num_rows_ == 0 ? kAdapterUnknownSize : num_rows_;
467 }
468 size_t NumColumns() const { return num_columns_; }
469
470 private:
471 CSCAdapterBatch batch_;
472 size_t num_rows_;
473 size_t num_columns_;
474};
475
477 ArrayInterface<1> indptr_;
478 ArrayInterface<1> indices_;
479 ArrayInterface<1> values_;
480
481 class Line {
482 std::size_t column_idx_;
483 ArrayInterface<1> row_idx_;
484 ArrayInterface<1> values_;
485 std::size_t offset_;
486
487 public:
488 Line(std::size_t idx, ArrayInterface<1> row_idx, ArrayInterface<1> values, std::size_t offset)
489 : column_idx_{idx},
490 row_idx_{std::move(row_idx)},
491 values_{std::move(values)},
492 offset_{offset} {}
493
494 std::size_t Size() const { return values_.Shape(0); }
495 COOTuple GetElement(std::size_t idx) const {
496 return {TypedIndex<std::size_t, 1>{row_idx_}(offset_ + idx), column_idx_,
497 values_(offset_ + idx)};
498 }
499 };
500
501 public:
502 static constexpr bool kIsRowMajor = false;
503
505 ArrayInterface<1> values)
506 : indptr_{std::move(indptr)}, indices_{std::move(indices)}, values_{std::move(values)} {}
507
508 std::size_t Size() const { return indptr_.n - 1; }
509 Line GetLine(std::size_t idx) const {
510 auto begin_no_stride = TypedIndex<std::size_t, 1>{indptr_}(idx);
511 auto end_no_stride = TypedIndex<std::size_t, 1>{indptr_}(idx + 1);
512
513 auto indices = indices_;
514 auto values = values_;
515 // Slice indices and values, stride remains unchanged since this is slicing by
516 // specific index.
517 auto offset = indices.strides[0] * begin_no_stride;
518 indices.shape[0] = end_no_stride - begin_no_stride;
519 values.shape[0] = end_no_stride - begin_no_stride;
520
521 return Line{idx, indices, values, offset};
522 }
523};
524
528class CSCArrayAdapter : public detail::SingleBatchDataIter<CSCArrayAdapterBatch> {
529 ArrayInterface<1> indptr_;
530 ArrayInterface<1> indices_;
531 ArrayInterface<1> values_;
532 size_t num_rows_;
534
535 public:
536 CSCArrayAdapter(StringView indptr, StringView indices, StringView values, std::size_t num_rows)
537 : indptr_{indptr},
538 indices_{indices},
539 values_{values},
540 num_rows_{num_rows},
541 batch_{CSCArrayAdapterBatch{indptr_, indices_, values_}} {}
542
543 // JVM package sends 0 as unknown
544 size_t NumRows() const { return num_rows_ == 0 ? kAdapterUnknownSize : num_rows_; }
545 size_t NumColumns() const { return indptr_.n - 1; }
546 const CSCArrayAdapterBatch& Value() const override { return batch_; }
547};
548
550 enum class DTType : std::uint8_t {
551 kFloat32 = 0,
552 kFloat64 = 1,
553 kBool8 = 2,
554 kInt32 = 3,
555 kInt8 = 4,
556 kInt16 = 5,
557 kInt64 = 6,
558 kUnknown = 7
559 };
560
561 static DTType DTGetType(std::string type_string) {
562 if (type_string == "float32") {
563 return DTType::kFloat32;
564 } else if (type_string == "float64") {
565 return DTType::kFloat64;
566 } else if (type_string == "bool8") {
567 return DTType::kBool8;
568 } else if (type_string == "int32") {
569 return DTType::kInt32;
570 } else if (type_string == "int8") {
571 return DTType::kInt8;
572 } else if (type_string == "int16") {
573 return DTType::kInt16;
574 } else if (type_string == "int64") {
575 return DTType::kInt64;
576 } else {
577 LOG(FATAL) << "Unknown data table type.";
578 return DTType::kUnknown;
579 }
580 }
581
582 public:
583 DataTableAdapterBatch(void const* const* const data, char const* const* feature_stypes,
584 std::size_t num_rows, std::size_t num_features)
585 : data_(data), num_rows_(num_rows) {
586 CHECK(feature_types_.empty());
587 std::transform(feature_stypes, feature_stypes + num_features,
588 std::back_inserter(feature_types_),
589 [](char const* stype) { return DTGetType(stype); });
590 }
591
592 private:
593 class Line {
594 std::size_t row_idx_;
595 void const* const* const data_;
596 std::vector<DTType> const& feature_types_;
597
598 float DTGetValue(void const* column, DTType dt_type, std::size_t ridx) const {
599 float missing = std::numeric_limits<float>::quiet_NaN();
600 switch (dt_type) {
601 case DTType::kFloat32: {
602 float val = reinterpret_cast<const float*>(column)[ridx];
603 return std::isfinite(val) ? val : missing;
604 }
605 case DTType::kFloat64: {
606 double val = reinterpret_cast<const double*>(column)[ridx];
607 return std::isfinite(val) ? static_cast<float>(val) : missing;
608 }
609 case DTType::kBool8: {
610 bool val = reinterpret_cast<const bool*>(column)[ridx];
611 return static_cast<float>(val);
612 }
613 case DTType::kInt32: {
614 int32_t val = reinterpret_cast<const int32_t*>(column)[ridx];
615 return val != (-2147483647 - 1) ? static_cast<float>(val) : missing;
616 }
617 case DTType::kInt8: {
618 int8_t val = reinterpret_cast<const int8_t*>(column)[ridx];
619 return val != -128 ? static_cast<float>(val) : missing;
620 }
621 case DTType::kInt16: {
622 int16_t val = reinterpret_cast<const int16_t*>(column)[ridx];
623 return val != -32768 ? static_cast<float>(val) : missing;
624 }
625 case DTType::kInt64: {
626 int64_t val = reinterpret_cast<const int64_t*>(column)[ridx];
627 return val != -9223372036854775807 - 1 ? static_cast<float>(val) : missing;
628 }
629 default: {
630 LOG(FATAL) << "Unknown data table type.";
631 return 0.0f;
632 }
633 }
634 }
635
636 public:
637 Line(std::size_t ridx, void const* const* const data, std::vector<DTType> const& ft)
638 : row_idx_{ridx}, data_{data}, feature_types_{ft} {}
639 std::size_t Size() const { return feature_types_.size(); }
640 COOTuple GetElement(std::size_t idx) const {
641 return COOTuple{row_idx_, idx, DTGetValue(data_[idx], feature_types_[idx], row_idx_)};
642 }
643 };
644
645 public:
646 size_t Size() const { return num_rows_; }
647 const Line GetLine(std::size_t ridx) const { return {ridx, data_, feature_types_}; }
648 static constexpr bool kIsRowMajor = true;
649
650 private:
651 void const* const* const data_;
652
653 std::vector<DTType> feature_types_;
654 std::size_t num_rows_;
655};
656
657class DataTableAdapter : public detail::SingleBatchDataIter<DataTableAdapterBatch> {
658 public:
659 DataTableAdapter(void** data, const char** feature_stypes, std::size_t num_rows,
660 std::size_t num_features)
661 : batch_(data, feature_stypes, num_rows, num_features),
662 num_rows_(num_rows),
663 num_columns_(num_features) {}
664 const DataTableAdapterBatch& Value() const override { return batch_; }
665 std::size_t NumRows() const { return num_rows_; }
666 std::size_t NumColumns() const { return num_columns_; }
667
668 private:
669 DataTableAdapterBatch batch_;
670 std::size_t num_rows_;
671 std::size_t num_columns_;
672};
673
675 public:
676 class Line {
677 public:
678 Line(size_t row_idx, const uint32_t *feature_idx, const float *value,
679 size_t size)
680 : row_idx_(row_idx),
681 feature_idx_(feature_idx),
682 value_(value),
683 size_(size) {}
684
685 size_t Size() { return size_; }
686 COOTuple GetElement(size_t idx) {
687 float fvalue = value_ == nullptr ? 1.0f : value_[idx];
688 return COOTuple{row_idx_, feature_idx_[idx], fvalue};
689 }
690
691 private:
692 size_t row_idx_;
693 const uint32_t* feature_idx_;
694 const float* value_;
695 size_t size_;
696 };
697 FileAdapterBatch(const dmlc::RowBlock<uint32_t>* block, size_t row_offset)
698 : block_(block), row_offset_(row_offset) {}
699 Line GetLine(size_t idx) const {
700 auto begin = block_->offset[idx];
701 auto end = block_->offset[idx + 1];
702 return Line{idx + row_offset_, &block_->index[begin], &block_->value[begin],
703 end - begin};
704 }
705 const float* Labels() const { return block_->label; }
706 const float* Weights() const { return block_->weight; }
707 const uint64_t* Qid() const { return block_->qid; }
708 const float* BaseMargin() const { return nullptr; }
709
710 size_t Size() const { return block_->size; }
711 static constexpr bool kIsRowMajor = true;
712
713 private:
714 const dmlc::RowBlock<uint32_t>* block_;
715 size_t row_offset_;
716};
717
720class FileAdapter : dmlc::DataIter<FileAdapterBatch> {
721 public:
722 explicit FileAdapter(dmlc::Parser<uint32_t>* parser) : parser_(parser) {}
723
724 const FileAdapterBatch& Value() const override { return *batch_.get(); }
725 void BeforeFirst() override {
726 batch_.reset();
727 parser_->BeforeFirst();
728 row_offset_ = 0;
729 }
730 bool Next() override {
731 bool next = parser_->Next();
732 batch_.reset(new FileAdapterBatch(&parser_->Value(), row_offset_));
733 row_offset_ += parser_->Value().size;
734 return next;
735 }
736 // Indicates a number of rows/columns must be inferred
737 size_t NumRows() const { return kAdapterUnknownSize; }
738 size_t NumColumns() const { return kAdapterUnknownSize; }
739
740 private:
741 size_t row_offset_{0};
742 std::unique_ptr<FileAdapterBatch> batch_;
743 dmlc::Parser<uint32_t>* parser_;
744};
745
748template <typename DataIterHandle, typename XGBCallbackDataIterNext, typename XGBoostBatchCSR>
749class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
750 public:
751 IteratorAdapter(DataIterHandle data_handle, XGBCallbackDataIterNext* next_callback)
752 : columns_{data::kAdapterUnknownSize},
753 data_handle_(data_handle),
754 next_callback_(next_callback) {}
755
756 // override functions
757 void BeforeFirst() override {
758 CHECK(at_first_) << "Cannot reset IteratorAdapter";
759 }
760
761 bool Next() override {
762 if ((*next_callback_)(
763 data_handle_,
764 [](void *handle, XGBoostBatchCSR batch) -> int {
765 API_BEGIN();
766 static_cast<IteratorAdapter *>(handle)->SetData(batch);
767 API_END();
768 },
769 this) != 0) {
770 at_first_ = false;
771 return true;
772 } else {
773 return false;
774 }
775 }
776
777 FileAdapterBatch const& Value() const override {
778 return *batch_.get();
779 }
780
781 // callback to set the data
782 void SetData(const XGBoostBatchCSR& batch) {
783 offset_.clear();
784 label_.clear();
785 weight_.clear();
786 index_.clear();
787 value_.clear();
788 offset_.insert(offset_.end(), batch.offset, batch.offset + batch.size + 1);
789
790 if (batch.label != nullptr) {
791 label_.insert(label_.end(), batch.label, batch.label + batch.size);
792 }
793 if (batch.weight != nullptr) {
794 weight_.insert(weight_.end(), batch.weight, batch.weight + batch.size);
795 }
796 if (batch.index != nullptr) {
797 index_.insert(index_.end(), batch.index + offset_[0],
798 batch.index + offset_.back());
799 }
800 if (batch.value != nullptr) {
801 value_.insert(value_.end(), batch.value + offset_[0],
802 batch.value + offset_.back());
803 }
804 if (offset_[0] != 0) {
805 size_t base = offset_[0];
806 for (size_t &item : offset_) {
807 item -= base;
808 }
809 }
810 CHECK(columns_ == data::kAdapterUnknownSize || columns_ == batch.columns)
811 << "Number of columns between batches changed from " << columns_
812 << " to " << batch.columns;
813
814 columns_ = batch.columns;
815 block_.size = batch.size;
816
817 block_.offset = dmlc::BeginPtr(offset_);
818 block_.label = dmlc::BeginPtr(label_);
819 block_.weight = dmlc::BeginPtr(weight_);
820 block_.qid = nullptr;
821 block_.field = nullptr;
822 block_.index = dmlc::BeginPtr(index_);
823 block_.value = dmlc::BeginPtr(value_);
824
825 batch_.reset(new FileAdapterBatch(&block_, row_offset_));
826 row_offset_ += offset_.size() - 1;
827 }
828
829 size_t NumColumns() const { return columns_; }
830 size_t NumRows() const { return kAdapterUnknownSize; }
831
832 private:
833 std::vector<size_t> offset_;
834 std::vector<dmlc::real_t> label_;
835 std::vector<dmlc::real_t> weight_;
836 std::vector<uint32_t> index_;
837 std::vector<dmlc::real_t> value_;
838
839 size_t columns_;
840 size_t row_offset_{0};
841 // at the beginning.
842 bool at_first_{true};
843 // handle to the iterator,
844 DataIterHandle data_handle_;
845 // call back to get the data.
846 XGBCallbackDataIterNext *next_callback_;
847 // internal Rowblock
849 std::unique_ptr<FileAdapterBatch> batch_;
850};
851
// Element types of columns; the underlying values are consecutive from 0.
enum ColumnDType : uint8_t {
  kUnknown = 0,
  kInt8 = 1,
  kUInt8 = 2,
  kInt16 = 3,
  kUInt16 = 4,
  kInt32 = 5,
  kUInt32 = 6,
  kInt64 = 7,
  kUInt64 = 8,
  kFloat = 9,
  kDouble = 10
};
865
866class Column {
867 public:
868 Column() = default;
869
870 Column(size_t col_idx, size_t length, size_t null_count, const uint8_t* bitmap)
871 : col_idx_{col_idx}, length_{length}, null_count_{null_count}, bitmap_{bitmap} {}
872
873 virtual ~Column() = default;
874
875 Column(const Column&) = delete;
876 Column& operator=(const Column&) = delete;
877 Column(Column&&) = delete;
878 Column& operator=(Column&&) = delete;
879
880 // whether the valid bit is set for this element
881 bool IsValid(size_t row_idx) const {
882 return (!bitmap_ || (bitmap_[row_idx/8] & (1 << (row_idx%8))));
883 }
884
885 virtual COOTuple GetElement(size_t row_idx) const = 0;
886
887 virtual bool IsValidElement(size_t row_idx) const = 0;
888
889 virtual std::vector<float> AsFloatVector() const = 0;
890
891 virtual std::vector<uint64_t> AsUint64Vector() const = 0;
892
893 size_t Length() const { return length_; }
894
895 protected:
896 size_t col_idx_;
897 size_t length_;
898 size_t null_count_;
899 const uint8_t* bitmap_;
900};
901
902// Only columns of primitive types are supported. An ArrowColumnarBatch is a
903// collection of std::shared_ptr<PrimitiveColumn>. These columns can be of different data types.
904// Hence, PrimitiveColumn is a class template; and all concrete PrimitiveColumns
905// derive from the abstract class Column.
906template <typename T>
907class PrimitiveColumn : public Column {
908 static constexpr float kNaN = std::numeric_limits<float>::quiet_NaN();
909
910 public:
911 PrimitiveColumn(size_t idx, size_t length, size_t null_count,
912 const uint8_t* bitmap, const T* data, float missing)
913 : Column{idx, length, null_count, bitmap}, data_{data}, missing_{missing} {}
914
915 COOTuple GetElement(size_t row_idx) const override {
916 CHECK(data_ && row_idx < length_) << "Column is empty or out-of-bound index of the column";
917 return { row_idx, col_idx_, IsValidElement(row_idx) ?
918 static_cast<float>(data_[row_idx]) : kNaN };
919 }
920
921 bool IsValidElement(size_t row_idx) const override {
922 // std::isfinite needs to cast to double to prevent msvc report error
923 return IsValid(row_idx)
924 && std::isfinite(static_cast<double>(data_[row_idx]))
925 && static_cast<float>(data_[row_idx]) != missing_;
926 }
927
928 std::vector<float> AsFloatVector() const override {
929 CHECK(data_) << "Column is empty";
930 std::vector<float> fv(length_);
931 std::transform(data_, data_ + length_, fv.begin(),
932 [](T v) { return static_cast<float>(v); });
933 return fv;
934 }
935
936 std::vector<uint64_t> AsUint64Vector() const override {
937 CHECK(data_) << "Column is empty";
938 std::vector<uint64_t> iv(length_);
939 std::transform(data_, data_ + length_, iv.begin(),
940 [](T v) { return static_cast<uint64_t>(v); });
941 return iv;
942 }
943
944 private:
945 const T* data_;
946 float missing_; // user specified missing value
947};
948
950 // data type of the column
951 ColumnDType type{ColumnDType::kUnknown};
952 // location of the column in an Arrow record batch
953 int64_t loc{-1};
954};
955
957 std::vector<ColumnarMetaInfo> columns;
958
959 // map Arrow format strings to types
960 static ColumnDType FormatMap(char const* format_str) {
961 CHECK(format_str) << "Format string cannot be empty";
962 switch (format_str[0]) {
963 case 'c':
964 return ColumnDType::kInt8;
965 case 'C':
966 return ColumnDType::kUInt8;
967 case 's':
968 return ColumnDType::kInt16;
969 case 'S':
970 return ColumnDType::kUInt16;
971 case 'i':
972 return ColumnDType::kInt32;
973 case 'I':
974 return ColumnDType::kUInt32;
975 case 'l':
976 return ColumnDType::kInt64;
977 case 'L':
978 return ColumnDType::kUInt64;
979 case 'f':
980 return ColumnDType::kFloat;
981 case 'g':
982 return ColumnDType::kDouble;
983 default:
984 CHECK(false) << "Column data type not supported by XGBoost";
985 return ColumnDType::kUnknown;
986 }
987 }
988
989 void Import(struct ArrowSchema *schema) {
990 if (schema) {
991 CHECK(std::string(schema->format) == "+s"); // NOLINT
992 CHECK(columns.empty());
993 for (auto i = 0; i < schema->n_children; ++i) {
994 std::string name{schema->children[i]->name};
995 ColumnDType type = FormatMap(schema->children[i]->format);
996 ColumnarMetaInfo col_info{type, i};
997 columns.push_back(col_info);
998 }
999 if (schema->release) {
1000 schema->release(schema);
1001 }
1002 }
1003 }
1004};
1005
1007 public:
1008 ArrowColumnarBatch(struct ArrowArray *rb, struct ArrowSchemaImporter* schema)
1009 : rb_{rb}, schema_{schema} {
1010 CHECK(rb_) << "Cannot import non-existent record batch";
1011 CHECK(!schema_->columns.empty()) << "Cannot import record batch without a schema";
1012 }
1013
1014 size_t Import(float missing) {
1015 auto& infov = schema_->columns;
1016 for (size_t i = 0; i < infov.size(); ++i) {
1017 columns_.push_back(CreateColumn(i, infov[i], missing));
1018 }
1019
1020 // Compute the starting location for every row in this batch
1021 auto batch_size = rb_->length;
1022 auto num_columns = columns_.size();
1023 row_offsets_.resize(batch_size + 1, 0);
1024 for (auto i = 0; i < batch_size; ++i) {
1025 row_offsets_[i+1] = row_offsets_[i];
1026 for (size_t j = 0; j < num_columns; ++j) {
1027 if (GetColumn(j).IsValidElement(i)) {
1028 row_offsets_[i+1]++;
1029 }
1030 }
1031 }
1032 // return number of elements in the batch
1033 return row_offsets_.back();
1034 }
1035
1036 ArrowColumnarBatch(const ArrowColumnarBatch&) = delete;
1037 ArrowColumnarBatch& operator=(const ArrowColumnarBatch&) = delete;
1039 ArrowColumnarBatch& operator=(ArrowColumnarBatch&&) = delete;
1040
1041 virtual ~ArrowColumnarBatch() {
1042 if (rb_ && rb_->release) {
1043 rb_->release(rb_);
1044 rb_ = nullptr;
1045 }
1046 columns_.clear();
1047 }
1048
1049 size_t Size() const { return rb_ ? rb_->length : 0; }
1050
1051 size_t NumColumns() const { return columns_.size(); }
1052
1053 size_t NumElements() const { return row_offsets_.back(); }
1054
1055 const Column& GetColumn(size_t col_idx) const {
1056 return *columns_[col_idx];
1057 }
1058
1059 void ShiftRowOffsets(size_t batch_offset) {
1060 std::transform(row_offsets_.begin(), row_offsets_.end(), row_offsets_.begin(),
1061 [=](size_t c) { return c + batch_offset; });
1062 }
1063
1064 const std::vector<size_t>& RowOffsets() const { return row_offsets_; }
1065
1066 private:
1067 std::shared_ptr<Column> CreateColumn(size_t idx,
1068 ColumnarMetaInfo info,
1069 float missing) const {
1070 if (info.loc < 0) {
1071 return nullptr;
1072 }
1073
1074 auto loc_in_batch = info.loc;
1075 auto length = rb_->length;
1076 auto null_count = rb_->null_count;
1077 auto buffers0 = rb_->children[loc_in_batch]->buffers[0];
1078 auto buffers1 = rb_->children[loc_in_batch]->buffers[1];
1079 const uint8_t* bitmap = buffers0 ? reinterpret_cast<const uint8_t*>(buffers0) : nullptr;
1080 const uint8_t* data = buffers1 ? reinterpret_cast<const uint8_t*>(buffers1) : nullptr;
1081
1082 // if null_count is not computed, compute it here
1083 if (null_count < 0) {
1084 if (!bitmap) {
1085 null_count = 0;
1086 } else {
1087 null_count = length;
1088 for (auto i = 0; i < length; ++i) {
1089 if (bitmap[i/8] & (1 << (i%8))) {
1090 null_count--;
1091 }
1092 }
1093 }
1094 }
1095
1096 switch (info.type) {
1097 case ColumnDType::kInt8:
1098 return std::make_shared<PrimitiveColumn<int8_t>>(
1099 idx, length, null_count, bitmap,
1100 reinterpret_cast<const int8_t*>(data), missing);
1101 case ColumnDType::kUInt8:
1102 return std::make_shared<PrimitiveColumn<uint8_t>>(
1103 idx, length, null_count, bitmap, data, missing);
1104 case ColumnDType::kInt16:
1105 return std::make_shared<PrimitiveColumn<int16_t>>(
1106 idx, length, null_count, bitmap,
1107 reinterpret_cast<const int16_t*>(data), missing);
1108 case ColumnDType::kUInt16:
1109 return std::make_shared<PrimitiveColumn<uint16_t>>(
1110 idx, length, null_count, bitmap,
1111 reinterpret_cast<const uint16_t*>(data), missing);
1112 case ColumnDType::kInt32:
1113 return std::make_shared<PrimitiveColumn<int32_t>>(
1114 idx, length, null_count, bitmap,
1115 reinterpret_cast<const int32_t*>(data), missing);
1116 case ColumnDType::kUInt32:
1117 return std::make_shared<PrimitiveColumn<uint32_t>>(
1118 idx, length, null_count, bitmap,
1119 reinterpret_cast<const uint32_t*>(data), missing);
1120 case ColumnDType::kInt64:
1121 return std::make_shared<PrimitiveColumn<int64_t>>(
1122 idx, length, null_count, bitmap,
1123 reinterpret_cast<const int64_t*>(data), missing);
1124 case ColumnDType::kUInt64:
1125 return std::make_shared<PrimitiveColumn<uint64_t>>(
1126 idx, length, null_count, bitmap,
1127 reinterpret_cast<const uint64_t*>(data), missing);
1128 case ColumnDType::kFloat:
1129 return std::make_shared<PrimitiveColumn<float>>(
1130 idx, length, null_count, bitmap,
1131 reinterpret_cast<const float*>(data), missing);
1132 case ColumnDType::kDouble:
1133 return std::make_shared<PrimitiveColumn<double>>(
1134 idx, length, null_count, bitmap,
1135 reinterpret_cast<const double*>(data), missing);
1136 default:
1137 return nullptr;
1138 }
1139 }
1140
1141 struct ArrowArray* rb_;
1142 struct ArrowSchemaImporter* schema_;
1143 std::vector<std::shared_ptr<Column>> columns_;
1144 std::vector<size_t> row_offsets_;
1145};
1146
1147using ArrowColumnarBatchVec = std::vector<std::unique_ptr<ArrowColumnarBatch>>;
1148class RecordBatchesIterAdapter: public dmlc::DataIter<ArrowColumnarBatchVec> {
1149 public:
1150 RecordBatchesIterAdapter(XGDMatrixCallbackNext* next_callback, int nbatch)
1151 : next_callback_{next_callback}, nbatches_{nbatch} {}
1152
1153 void BeforeFirst() override {
1154 CHECK(at_first_) << "Cannot reset RecordBatchesIterAdapter";
1155 }
1156
1157 bool Next() override {
1158 batches_.clear();
1159 while (batches_.size() < static_cast<size_t>(nbatches_) && (*next_callback_)(this) != 0) {
1160 at_first_ = false;
1161 }
1162
1163 if (batches_.size() > 0) {
1164 return true;
1165 } else {
1166 return false;
1167 }
1168 }
1169
1170 void SetData(struct ArrowArray* rb, struct ArrowSchema* schema) {
1171 // Schema is only imported once at the beginning, regardless how many
1172 // baches are comming.
1173 // But even schema is not imported we still need to release its C data
1174 // exported from Arrow.
1175 if (at_first_ && schema) {
1176 schema_.Import(schema);
1177 } else {
1178 if (schema && schema->release) {
1179 schema->release(schema);
1180 }
1181 }
1182 if (rb) {
1183 batches_.push_back(std::make_unique<ArrowColumnarBatch>(rb, &schema_));
1184 }
1185 }
1186
1187 const ArrowColumnarBatchVec& Value() const override {
1188 return batches_;
1189 }
1190
1191 size_t NumColumns() const { return schema_.columns.size(); }
1192 size_t NumRows() const { return kAdapterUnknownSize; }
1193
1194 private:
1195 XGDMatrixCallbackNext *next_callback_;
1196 bool at_first_{true};
1197 int nbatches_;
1198 struct ArrowSchemaImporter schema_;
1199 ArrowColumnarBatchVec batches_;
1200};
1201
1203 HostSparsePageView page_;
1204
1205 public:
1206 struct Line {
1207 Entry const* inst;
1208 size_t n;
1209 bst_row_t ridx;
1210 COOTuple GetElement(size_t idx) const { return {ridx, inst[idx].index, inst[idx].fvalue}; }
1211 size_t Size() const { return n; }
1212 };
1213
1214 explicit SparsePageAdapterBatch(HostSparsePageView page) : page_{std::move(page)} {}
1215 Line GetLine(size_t ridx) const { return Line{page_[ridx].data(), page_[ridx].size(), ridx}; }
1216 size_t Size() const { return page_.Size(); }
1217};
1218}; // namespace data
1219} // namespace xgboost
1220#endif // XGBOOST_DATA_ADAPTER_H_
Copyright 2019-2023 by XGBoost Contributors.
Data iterator interface. This is not a C++-style iterator, but it is convenient for pulling data. This interface ...
Definition data.h:56
virtual bool Next(void)=0
move to next item
virtual void BeforeFirst(void)=0
set before first of the item
virtual const DType & Value(void) const =0
get current data
Parser interface that parses input data; it is used to load the dmlc data format into your own data format. Diffe...
Definition data.h:293
A type erased view over array_interface protocol defined by numpy.
Definition array_interface.h:388
static Json Load(StringView str, std::ios::openmode mode=std::ios::in)
Decode the JSON object.
Definition json.cc:652
Definition adapter.h:256
Adapter for dense array on host, in Python that's numpy.ndarray.
Definition adapter.h:297
ArrayAdapterBatch const & Value() const override
get current data
Definition adapter.h:304
Definition adapter.h:1006
Definition adapter.h:410
Definition adapter.h:455
const CSCAdapterBatch & Value() const override
get current data
Definition adapter.h:462
Definition adapter.h:476
CSC adapter with support for array interface.
Definition adapter.h:528
const CSCArrayAdapterBatch & Value() const override
get current data
Definition adapter.h:546
Definition adapter.h:141
Definition adapter.h:139
Definition adapter.h:183
const CSRAdapterBatch & Value() const override
get current data
Definition adapter.h:192
Definition adapter.h:313
Adapter for CSR array on host, in Python that's scipy.sparse.csr_matrix.
Definition adapter.h:383
CSRArrayAdapterBatch const & Value() const override
get current data
Definition adapter.h:392
Definition adapter.h:866
Definition adapter.h:549
Definition adapter.h:657
const DataTableAdapterBatch & Value() const override
get current data
Definition adapter.h:664
Definition adapter.h:202
Definition adapter.h:239
const DenseAdapterBatch & Value() const override
get current data
Definition adapter.h:245
Definition adapter.h:676
Definition adapter.h:674
FileAdapter wraps dmlc::parser to read files and provide access in a common interface.
Definition adapter.h:720
const FileAdapterBatch & Value() const override
get current data
Definition adapter.h:724
bool Next() override
move to next item
Definition adapter.h:730
void BeforeFirst() override
set before first of the item
Definition adapter.h:725
Data iterator that takes callback to return data, used in JVM package for accepting data iterator.
Definition adapter.h:749
bool Next() override
move to next item
Definition adapter.h:761
FileAdapterBatch const & Value() const override
get current data
Definition adapter.h:777
void BeforeFirst() override
set before first of the item
Definition adapter.h:757
Definition adapter.h:907
const ArrowColumnarBatchVec & Value() const override
get current data
Definition adapter.h:1187
bool Next() override
move to next item
Definition adapter.h:1157
void BeforeFirst() override
set before first of the item
Definition adapter.h:1153
Definition adapter.h:1202
Indicates this data source cannot contain meta-info such as labels, weights or qid.
Definition adapter.h:129
Simplifies the use of DataIter when there is only one batch.
Definition adapter.h:112
void BeforeFirst() override
set before first of the item
Definition adapter.h:114
bool Next() override
move to next item
Definition adapter.h:115
defines common input data structure, and interface for handling the input data
XGB_EXTERN_C typedef int XGDMatrixCallbackNext(DataIterHandle iter)
Callback function prototype for getting next batch of data.
void * DataIterHandle
handle to an external data iterator
Definition c_api.h:334
XGB_EXTERN_C typedef int XGBCallbackDataIterNext(DataIterHandle data_handle, XGBCallbackSetData *set_function, DataHolderHandle set_function_handle)
The data reading callback function. The iterator will be able to give subset of batch in the data.
Copyright 2015-2023 by XGBoost Contributors.
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition base.h:64
Copyright 2015-2023 by XGBoost Contributors.
defines console logging options for xgboost. Use to enforce unified print behavior.
detail namespace with internal helper functions
Definition json.hpp:249
T * BeginPtr(std::vector< T > &vec)
safely get the beginning address of a vector
Definition base.h:284
Definition StdDeque.h:58
constexpr size_t kAdapterUnknownSize
External data formats should implement an adapter as below.
Definition adapter.h:76
namespace of xgboost
Definition base.h:90
uint32_t bst_feature_t
Type for data column (feature) index.
Definition base.h:101
std::size_t bst_row_t
Type for data row index.
Definition base.h:110
Definition arrow-cdi.h:47
Definition arrow-cdi.h:31
Mini batch used in XGBoost Data Iteration.
Definition c_api.h:340
float * label
labels of each instance
Definition c_api.h:354
int64_t * offset
row pointer to the rows in the data
Definition c_api.h:351
int * index
feature index
Definition c_api.h:358
float * value
feature values
Definition c_api.h:360
float * weight
weight of each instance, can be NULL
Definition c_api.h:356
size_t size
number of rows in the minibatch
Definition c_api.h:342
a block of data, containing several rows in sparse matrix. This is useful for (streaming-style) algor...
Definition data.h:175
const DType * label
array[size] label of each instance
Definition data.h:181
size_t size
batch size
Definition data.h:177
const real_t * weight
With weight: array[size] weight of each instance, otherwise nullptr.
Definition data.h:183
const IndexType * index
feature index
Definition data.h:189
const DType * value
feature value, can be NULL, indicating all values are 1
Definition data.h:191
const uint64_t * qid
With qid: array[size] session id of each instance, otherwise nullptr.
Definition data.h:185
const IndexType * field
field id
Definition data.h:187
const size_t * offset
array[size+1], row pointer to beginning of each rows
Definition data.h:179
Element from a sparse vector.
Definition data.h:216
Definition data.h:310
Definition string_view.h:15
Helper for type casting.
Definition array_interface.h:665
Definition adapter.h:956
Definition adapter.h:78
Definition adapter.h:949
Definition adapter.h:88