Medial Code Documentation
Loading...
Searching...
No Matches
data.h
Go to the documentation of this file.
1
7#ifndef XGBOOST_DATA_H_
8#define XGBOOST_DATA_H_
9
10#include <dmlc/base.h>
11#include <dmlc/data.h>
12#include <dmlc/serializer.h>
13#include <xgboost/base.h>
15#include <xgboost/linalg.h>
16#include <xgboost/span.h>
17#include <xgboost/string_view.h>
18
19#include <algorithm>
20#include <limits>
21#include <memory>
22#include <numeric>
23#include <string>
24#include <utility>
25#include <vector>
26
27namespace xgboost {
28// forward declare dmatrix.
29class DMatrix;
30struct Context;
31
/*! \brief Data type accepted by the xgboost interface. Stored as uint8_t with explicit values 1..5. */
33enum class DataType : uint8_t {
34 kFloat32 = 1,
35 kDouble = 2,
36 kUInt32 = 3,
37 kUInt64 = 4,
38 kStr = 5
39};
40
/*! \brief Type of a feature: numerical (0) or categorical (1). Stored as uint8_t. */
41enum class FeatureType : uint8_t { kNumerical = 0, kCategorical = 1 };
42
/*! \brief Data split mode: kRow (0) for row-wise split, kCol (1) for column-wise split. Backed by int. */
43enum class DataSplitMode : int { kRow = 0, kCol = 1 };
44
/*!
 * \brief Meta information about the dataset, always sits in memory.
 *
 * The type is movable; copy assignment is deleted, and Copy() returns a MetaInfo by value.
 */
48class MetaInfo {
49 public:
/*! \brief number of data fields in MetaInfo */
51 static constexpr uint64_t kNumField = 12;
52
/*! \brief number of rows in the data */
54 uint64_t num_row_{0}; // NOLINT
/*! \brief number of columns in the data */
56 uint64_t num_col_{0}; // NOLINT
/*! \brief number of nonzero entries in the data */
58 uint64_t num_nonzero_{0}; // NOLINT
/*! \brief data split mode; defaults to row-wise split */
62 DataSplitMode data_split_mode{DataSplitMode::kRow};
/*! \brief the index of begin and end of a group, needed when the learning task is ranking */
67 std::vector<bst_group_t> group_ptr_; // NOLINT
84
/*! \brief Name of type for each feature provided by users. Eg. "int"/"float"/"i"/"q". */
88 std::vector<std::string> feature_type_names;
/*! \brief Name for each feature. */
92 std::vector<std::string> feature_names;
93 /*
94 * \brief Type of each feature. Automatically set when feature_type_names is specified.
95 */
97 /*
98 * \brief Weight of each feature, used to define the probability of each feature being
99 * selected when using column sampling.
100 */
102
/*! \brief default constructor */
104 MetaInfo() = default;
// Move construction and move assignment are defaulted; copy assignment is deleted.
105 MetaInfo(MetaInfo&& that) = default;
106 MetaInfo& operator=(MetaInfo&& that) = default;
107 MetaInfo& operator=(MetaInfo const& that) = delete;
108
/*! \brief Validate all metainfo. */
112 void Validate(int32_t device) const;
113
// Takes a span of row indices and returns a new MetaInfo by value.
114 MetaInfo Slice(common::Span<int32_t const> ridxs) const;
115
// Returns a MetaInfo by value.
116 MetaInfo Copy() const;
117
/*!
 * \brief Get the weight of instance i.
 * \param i index into the weights; not bounds-checked here.
 * \return weights_[i] when weights are present, otherwise 1.0f.
 */
123 inline bst_float GetWeight(size_t i) const {
124 return weights_.Size() != 0 ? weights_.HostVector()[i] : 1.0f;
125 }
/*! \brief get sorted indexes (argsort) of labels by absolute value (used by cox loss) */
127 const std::vector<size_t>& LabelAbsSort(Context const* ctx) const;
/*! \brief clear all the information */
129 void Clear();
/*! \brief Load the Meta info from binary stream. */
134 void LoadBinary(dmlc::Stream* fi);
/*! \brief Save the Meta info to binary stream. */
139 void SaveBinary(dmlc::Stream* fo) const;
/*! \brief Set information in the meta info. */
147 void SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype, size_t num);
// Overload taking the key and an interface string as StringView.
153 void SetInfo(Context const& ctx, StringView key, StringView interface_str);
154
155 void GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
156 const void** out_dptr) const;
157
158 void SetFeatureInfo(const char *key, const char **info, const bst_ulong size);
159 void GetFeatureInfo(const char *field, std::vector<std::string>* out_str_vecs) const;
160
161 /*
162 * \brief Extend with other MetaInfo.
163 *
164 * \param that The other MetaInfo object.
165 *
166 * \param accumulate_rows Whether rows need to be accumulated in this function. If
167 * client code knows number of rows in advance, set this
168 * parameter to false.
169 * \param check_column Whether the extend method should check the consistency of
170 * columns.
171 */
172 void Extend(MetaInfo const& that, bool accumulate_rows, bool check_column);
173
182
/*! \brief Whether the data is split row-wise. */
184 bool IsRowSplit() const {
185 return data_split_mode == DataSplitMode::kRow;
186 }
187
/*! \brief Whether the data is split column-wise. */
189 bool IsColumnSplit() const { return data_split_mode == DataSplitMode::kCol; }
/*! \brief Whether this is a learning to rank data, i.e. group_ptr_ is non-empty. */
191 bool IsRanking() const { return !group_ptr_.empty(); }
192
/*! \brief A convenient method to check if we are doing vertical federated learning. */
197 bool IsVerticalFederated() const;
198
/*! \brief A convenient method to check if the MetaInfo should contain labels. */
205 bool ShouldHaveLabels() const;
206
207 private:
208 void SetInfoFromHost(Context const& ctx, StringView key, Json arr);
209 void SetInfoFromCUDA(Context const& ctx, StringView key, Json arr);
210
// mutable so that const members can update it.
// NOTE(review): presumably caches the ordering returned by LabelAbsSort() - confirm in data.cc.
212 mutable std::vector<size_t> label_order_cache_;
213};
214
/*! \brief Element from a sparse vector: a feature index paired with a feature value. */
216struct Entry {
/*! \brief default constructor */
222 Entry() = default;
/*! \brief Strict less-than on feature value: true when a.fvalue < b.fvalue. */
230 inline static bool CmpValue(const Entry& a, const Entry& b) {
231 return a.fvalue < b.fvalue;
232 }
/*! \brief Strict less-than on feature index: true when a.index < b.index. */
233 static bool CmpIndex(Entry const& a, Entry const& b) {
234 return a.index < b.index;
235 }
/*! \brief Equal when both the index and the value compare equal (exact floating-point ==). */
236 inline bool operator==(const Entry& other) const {
237 return (this->index == other.index && this->fvalue == other.fvalue);
238 }
239};
240
/*! \brief Whether we should force DMatrix to regenerate the batch. */
257 bool regen{false};
/*! \brief Forbid regenerating the gradient index. */
261 bool forbid_regen{false};
/*! \brief Parameter used to generate column matrix for hist. Defaults to NaN. */
265 double sparse_thresh{std::numeric_limits<double>::quiet_NaN()};
266
/*! \brief Exact or others that don't need histogram. */
270 BatchParam() = default;
283 : max_bin{max_bin}, hess{hessian}, regen{regenerate} {}
284
/*!
 * \brief Whether this parameter set differs from `other`.
 *
 * Only max_bin and sparse_thresh are compared; hess, regen and forbid_regen are not.
 * Two NaN sparse_thresh values are treated as equal.
 *
 * NOTE(review): std::isnan is declared in <cmath>, which is not among the includes
 * visible in this listing - confirm it is reached through another header.
 *
 * \return true when max_bin or sparse_thresh differ.
 */
285 [[nodiscard]] bool ParamNotEqual(BatchParam const& other) const {
286 // Check non-floating parameters.
287 bool cond = max_bin != other.max_bin;
288 // Check sparse thresh.
289 bool l_nan = std::isnan(sparse_thresh);
290 bool r_nan = std::isnan(other.sparse_thresh);
// Changed if exactly one side is NaN, or both are numbers and they differ.
291 bool st_chg = (l_nan != r_nan) || (!l_nan && !r_nan && (sparse_thresh != other.sparse_thresh));
292 cond |= st_chg;
293
294 return cond;
295 }
/*! \brief True when max_bin is non-zero. */
296 [[nodiscard]] bool Initialized() const { return max_bin != 0; }
/*!
 * \brief Make a copy of self for DMatrix to describe how its existing index was generated.
 * \return A copy of *this with regen and forbid_regen reset to false.
 */
300 [[nodiscard]] BatchParam MakeCache() const {
301 auto p = *this;
302 // These parameters have nothing to do with how the gradient index was generated in the
303 // first place.
304 p.regen = false;
305 p.forbid_regen = false;
306 return p;
307 }
308};
309
312
315
/*!
 * \brief Return row i as an Inst.
 *
 * The Inst starts at data[offset[i]] and has length offset[i + 1] - offset[i].
 * No bounds check is performed on i here.
 */
316 Inst operator[](size_t i) const {
317 auto size = *(offset.data() + i + 1) - *(offset.data() + i);
318 return {data.data() + *(offset.data() + i),
319 static_cast<Inst::index_type>(size)};
320 }
321
/*! \brief Number of rows in the view: offset.size() - 1, or 0 when offset is empty. */
322 [[nodiscard]] size_t Size() const { return offset.size() == 0 ? 0 : offset.size() - 1; }
323};
324
329 public:
330 // Offset for each row.
334
335 size_t base_rowid {0};
336
339
/*! \brief Build a HostSparsePageView from the const host spans of offset and data. */
340 [[nodiscard]] HostSparsePageView GetView() const {
341 return {offset.ConstHostSpan(), data.ConstHostSpan()};
342 }
343
346 this->Clear();
347 }
348
// Move-only type: copy construction and copy assignment are deleted, moves are defaulted.
349 SparsePage(SparsePage const& that) = delete;
350 SparsePage(SparsePage&& that) = default;
351 SparsePage& operator=(SparsePage const& that) = delete;
352 SparsePage& operator=(SparsePage&& that) = default;
// Virtual destructor: SparsePage is used as a base class (e.g. CSCPage, SortedCSCPage).
353 virtual ~SparsePage() = default;
354
/*! \brief Number of offset segments in the page: offset.Size() - 1, or 0 when offset is empty. */
356 [[nodiscard]] size_t Size() const {
357 return offset.Size() == 0 ? 0 : offset.Size() - 1;
358 }
359
/*! \brief Byte estimate: offset.Size() * sizeof(size_t) + data.Size() * sizeof(Entry). */
361 [[nodiscard]] size_t MemCostBytes() const {
362 return offset.Size() * sizeof(size_t) + data.Size() * sizeof(Entry);
363 }
364
/*!
 * \brief clear the page
 *
 * Resets base_rowid to 0, empties data, and leaves offset holding the single value 0,
 * so Size() returns 0 afterwards.
 */
366 inline void Clear() {
367 base_rowid = 0;
368 auto& offset_vec = offset.HostVector();
369 offset_vec.clear();
370 offset_vec.push_back(0);
371 data.HostVector().clear();
372 }
373
/*! \brief Set the base row id for this page. */
375 inline void SetBaseRowId(size_t row_id) {
376 base_rowid = row_id;
377 }
378
// Returns a new SparsePage by value; takes the column count and a thread count.
379 [[nodiscard]] SparsePage GetTranspose(int num_columns, int32_t n_threads) const;
380
/*! \brief Sort the column index. */
384 void SortIndices(int32_t n_threads);
/*! \brief Check whether the column index is sorted. */
388 [[nodiscard]] bool IsIndicesSorted(int32_t n_threads) const;
/*! \brief Reindex the column index with an offset. */
392 void Reindex(uint64_t feature_offset, int32_t n_threads);
393
394 void SortRows(int32_t n_threads);
395
/*! \brief Pushes external data batch onto this page. */
406 template <typename AdapterBatchT>
407 uint64_t Push(const AdapterBatchT& batch, float missing, int nthread);
408
// Overload taking another SparsePage.
413 void Push(const SparsePage &batch);
/*! \brief Push a SparsePage stored in CSC format. */
418 void PushCSC(const SparsePage& batch);
419};
420
/*!
 * \brief A distinct page type derived from SparsePage; adds no members of its own.
 *
 * The converting constructor takes a SparsePage by value and moves it into the base.
 * SparsePage's copy constructor is deleted, so callers must pass an rvalue.
 */
421class CSCPage: public SparsePage {
422 public:
423 CSCPage() : SparsePage() {}
424 explicit CSCPage(SparsePage page) : SparsePage(std::move(page)) {}
425};
426
432 public:
433 std::shared_ptr<SparsePage const> page;
/*! \brief Store the given shared pointer to a const SparsePage in `page` (moved in). */
434 explicit ExtSparsePage(std::shared_ptr<SparsePage const> p) : page{std::move(p)} {}
435};
436
/*!
 * \brief A distinct page type derived from SparsePage; adds no members of its own.
 *
 * The converting constructor takes a SparsePage by value and moves it into the base.
 */
437class SortedCSCPage : public SparsePage {
438 public:
440 explicit SortedCSCPage(SparsePage page) : SparsePage(std::move(page)) {}
441};
442
443class EllpackPage;
444class GHistIndexMatrix;
445
446template<typename T>
448 public:
// Abstract interface: every operation below is pure virtual.
449 using iterator_category = std::forward_iterator_tag; // NOLINT
450 virtual ~BatchIteratorImpl() = default;
// Access the current batch by const reference.
451 virtual const T& operator*() const = 0;
// Pre-increment.
452 virtual BatchIteratorImpl& operator++() = 0;
// Whether iteration has finished.
453 [[nodiscard]] virtual bool AtEnd() const = 0;
// Shared pointer to a const page.
454 virtual std::shared_ptr<T const> Page() const = 0;
455};
456
457template<typename T>
459 public:
460 using iterator_category = std::forward_iterator_tag; // NOLINT
/*! \brief Construct from a raw pointer; impl_ (a shared_ptr) is reset to manage it. */
461 explicit BatchIterator(BatchIteratorImpl<T>* impl) { impl_.reset(impl); }
/*! \brief Construct from an existing shared_ptr; impl_ becomes a copy of it. */
462 explicit BatchIterator(std::shared_ptr<BatchIteratorImpl<T>> impl) { impl_ = impl; }
463
/*! \brief Pre-increment: CHECKs impl_ is non-null, increments it, and returns *this. */
464 BatchIterator &operator++() {
465 CHECK(impl_ != nullptr);
466 ++(*impl_);
467 return *this;
468 }
469
/*! \brief CHECKs impl_ is non-null and returns the implementation's current batch by const reference. */
470 const T& operator*() const {
471 CHECK(impl_ != nullptr);
472 return *(*impl_);
473 }
474
/*!
 * \brief True while this iterator is not at the end.
 *
 * The right-hand operand is ignored; the result depends only on this iterator's AtEnd().
 */
475 bool operator!=(const BatchIterator&) const {
476 CHECK(impl_ != nullptr);
477 return !impl_->AtEnd();
478 }
479
/*! \brief CHECKs impl_ is non-null and returns the implementation's AtEnd(). */
480 [[nodiscard]] bool AtEnd() const {
481 CHECK(impl_ != nullptr);
482 return impl_->AtEnd();
483 }
484
/*!
 * \brief Return the implementation's Page(), a shared pointer to a const page.
 *
 * NOTE(review): unlike operator++, operator*, operator!= and AtEnd(), this accessor
 * does not CHECK(impl_ != nullptr) first. Calling it on an iterator constructed from
 * nullptr dereferences a null impl_ - consider adding the same CHECK.
 */
485 [[nodiscard]] std::shared_ptr<T const> Page() const {
486 return impl_->Page();
487 }
488
489 private:
490 std::shared_ptr<BatchIteratorImpl<T>> impl_;
491};
492
/*!
 * \brief Holder of a begin iterator that exposes begin()/end().
 *
 * begin() returns a copy of the stored iterator. end() returns an iterator constructed
 * from nullptr.
 */
493template<typename T>
494class BatchSet {
495 public:
496 explicit BatchSet(BatchIterator<T> begin_iter) : begin_iter_(std::move(begin_iter)) {}
497 BatchIterator<T> begin() { return begin_iter_; } // NOLINT
498 BatchIterator<T> end() { return BatchIterator<T>(nullptr); } // NOLINT
499
500 private:
501 BatchIterator<T> begin_iter_;
502};
503
505
/*!
 * \brief Internal data structure used by XGBoost during training.
 *
 * Abstract base class: the page accessors in the protected section are pure virtual, and
 * the public GetBatches<T>/PageExists<T> templates are specialized outside the class to
 * forward to them.
 */
509class DMatrix {
510 public:
/*! \brief default constructor */
512 DMatrix() = default;
/*! \brief meta information of the dataset */
514 virtual MetaInfo& Info() = 0;
// Forwards to MetaInfo::SetInfo using this DMatrix's context.
515 virtual void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) {
516 auto const& ctx = *this->Ctx();
517 this->Info().SetInfo(ctx, key, dptr, dtype, num);
518 }
// Same as above, for an interface string (wrapped in a StringView).
519 virtual void SetInfo(const char* key, std::string const& interface_str) {
520 auto const& ctx = *this->Ctx();
521 this->Info().SetInfo(ctx, key, StringView{interface_str});
522 }
/*! \brief meta information of the dataset */
524 [[nodiscard]] virtual const MetaInfo& Info() const = 0;
525
/*! \brief Get thread local memory for returning data from DMatrix. */
527 [[nodiscard]] XGBAPIThreadLocalEntry& GetThreadLocal() const;
/*! \brief Get the context object of this DMatrix. */
532 [[nodiscard]] virtual Context const* Ctx() const = 0;
533
/*! \brief Gets batches. */
537 template <typename T>
539 template <typename T>
540 BatchSet<T> GetBatches(Context const* ctx);
541 template <typename T>
542 BatchSet<T> GetBatches(Context const* ctx, const BatchParam& param);
// Whether a page of type T exists; specialized outside the class per page type.
543 template <typename T>
544 [[nodiscard]] bool PageExists() const;
545
546 // the following are column meta data, should be able to answer them fast.
548 [[nodiscard]] virtual bool SingleColBlock() const = 0;
/*! \brief virtual destructor */
550 virtual ~DMatrix();
551
/*! \brief Whether the matrix is dense: num_nonzero_ equals num_row_ * num_col_. */
553 [[nodiscard]] bool IsDense() const {
554 return Info().num_nonzero_ == Info().num_row_ * Info().num_col_;
555 }
556
/*!
 * \brief Load DMatrix from URI.
 * NOTE(review): returns a raw pointer; ownership presumably passes to the caller - confirm.
 */
566 static DMatrix* Load(const std::string& uri, bool silent = true,
567 DataSplitMode data_split_mode = DataSplitMode::kRow);
568
/*! \brief Creates a new DMatrix from an external data adapter. */
581 template <typename AdapterT>
582 static DMatrix* Create(AdapterT* adapter, float missing, int nthread,
583 const std::string& cache_prefix = "",
584 DataSplitMode data_split_mode = DataSplitMode::kRow);
585
// Overload taking an iterator handle, a proxy handle, a reference DMatrix, reset/next
// callbacks and max_bin.
605 template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
606 typename XGDMatrixCallbackNext>
607 static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr<DMatrix> ref,
608 DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing,
609 int nthread, bst_bin_t max_bin);
610
// Overload taking an iterator handle, a proxy handle, a next callback and a `cache` string.
629 template <typename DataIterHandle, typename DMatrixHandle,
631 static DMatrix *Create(DataIterHandle iter, DMatrixHandle proxy,
633 XGDMatrixCallbackNext *next, float missing,
634 int32_t nthread, std::string cache);
635
// Takes a span of row indices and returns a DMatrix pointer.
636 virtual DMatrix *Slice(common::Span<int32_t const> ridxs) = 0;
637
/*! \brief Slice a DMatrix by columns. */
645 virtual DMatrix *SliceCol(int num_slices, int slice_id) = 0;
646
647 protected:
// Per-page-type accessors, called by the GetBatches<T> specializations.
648 virtual BatchSet<SparsePage> GetRowBatches() = 0;
649 virtual BatchSet<CSCPage> GetColumnBatches(Context const* ctx) = 0;
650 virtual BatchSet<SortedCSCPage> GetSortedColumnBatches(Context const* ctx) = 0;
651 virtual BatchSet<EllpackPage> GetEllpackBatches(Context const* ctx, BatchParam const& param) = 0;
652 virtual BatchSet<GHistIndexMatrix> GetGradientIndex(Context const* ctx,
653 BatchParam const& param) = 0;
654 virtual BatchSet<ExtSparsePage> GetExtBatches(Context const* ctx, BatchParam const& param) = 0;
655
// Per-page-type existence queries, called by the PageExists<T> specializations.
656 [[nodiscard]] virtual bool EllpackExists() const = 0;
657 [[nodiscard]] virtual bool GHistIndexExists() const = 0;
658 [[nodiscard]] virtual bool SparsePageExists() const = 0;
659};
660
/*! \brief SparsePage specialization (no arguments): forwards to GetRowBatches(). */
661template <>
662inline BatchSet<SparsePage> DMatrix::GetBatches() {
663 return GetRowBatches();
664}
665
/*! \brief EllpackPage specialization: forwards to EllpackExists(). */
666template <>
667inline bool DMatrix::PageExists<EllpackPage>() const {
668 return this->EllpackExists();
669}
670
/*! \brief GHistIndexMatrix specialization: forwards to GHistIndexExists(). */
671template <>
672inline bool DMatrix::PageExists<GHistIndexMatrix>() const {
673 return this->GHistIndexExists();
674}
675
/*! \brief SparsePage specialization: forwards to SparsePageExists(). */
676template <>
677inline bool DMatrix::PageExists<SparsePage>() const {
678 return this->SparsePageExists();
679}
680
/*! \brief SparsePage specialization: the context argument is unused; forwards to GetRowBatches(). */
681template <>
682inline BatchSet<SparsePage> DMatrix::GetBatches(Context const*) {
683 return GetRowBatches();
684}
685
/*! \brief CSCPage specialization: forwards to GetColumnBatches(ctx). */
686template <>
687inline BatchSet<CSCPage> DMatrix::GetBatches(Context const* ctx) {
688 return GetColumnBatches(ctx);
689}
690
/*! \brief SortedCSCPage specialization: forwards to GetSortedColumnBatches(ctx). */
691template <>
692inline BatchSet<SortedCSCPage> DMatrix::GetBatches(Context const* ctx) {
693 return GetSortedColumnBatches(ctx);
694}
695
/*! \brief EllpackPage specialization: forwards to GetEllpackBatches(ctx, param). */
696template <>
697inline BatchSet<EllpackPage> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
698 return GetEllpackBatches(ctx, param);
699}
700
/*! \brief GHistIndexMatrix specialization: forwards to GetGradientIndex(ctx, param). */
701template <>
702inline BatchSet<GHistIndexMatrix> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
703 return GetGradientIndex(ctx, param);
704}
705
/*! \brief ExtSparsePage specialization: forwards to GetExtBatches(ctx, param). */
706template <>
707inline BatchSet<ExtSparsePage> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
708 return GetExtBatches(ctx, param);
709}
710} // namespace xgboost
711
// Specialization of FieldEntry for enum class (backed by int), applied to DataSplitMode.
// Invoked at global scope, after namespace xgboost is closed.
712DECLARE_FIELD_ENUM_CLASS(xgboost::DataSplitMode);
713
714namespace dmlc {
716
717namespace serializer {
718
/*!
 * \brief Serialization handler specialization for xgboost::Entry.
 *
 * Fields are written and read in the same order: index first, then fvalue.
 */
719template <>
720struct Handler<xgboost::Entry> {
/*! \brief write data to stream: index, then fvalue */
721 inline static void Write(Stream* strm, const xgboost::Entry& data) {
722 strm->Write(data.index);
723 strm->Write(data.fvalue);
724 }
725
/*!
 * \brief read data from stream: index, then fvalue
 * \return true only when both reads succeed; fvalue is not read if reading index fails.
 */
726 inline static bool Read(Stream* strm, xgboost::Entry* data) {
727 return strm->Read(&data->index) && strm->Read(&data->fvalue);
728 }
729};
730
731} // namespace serializer
732} // namespace dmlc
733#endif // XGBOOST_DATA_H_
interface of stream I/O for serialization
Definition io.h:30
virtual void Write(const void *ptr, size_t size)=0
writes data to a stream
virtual size_t Read(void *ptr, size_t size)=0
reads data from a stream
Definition data.h:447
Definition data.h:458
Definition data.h:494
Definition data.h:421
Internal data structure used by XGBoost during training.
Definition data.h:509
virtual const MetaInfo & Info() const =0
meta information of the dataset
virtual Context const * Ctx() const =0
Get the context object of this DMatrix.
virtual MetaInfo & Info()=0
meta information of the dataset
static DMatrix * Create(AdapterT *adapter, float missing, int nthread, const std::string &cache_prefix="", DataSplitMode data_split_mode=DataSplitMode::kRow)
Creates a new DMatrix from an external data adapter.
Definition data.cc:975
virtual ~DMatrix()
virtual destructor
Definition data.cc:822
virtual DMatrix * SliceCol(int num_slices, int slice_id)=0
Slice a DMatrix by columns.
BatchSet< T > GetBatches()
Gets batches.
virtual bool SingleColBlock() const =0
static DMatrix * Load(const std::string &uri, bool silent=true, DataSplitMode data_split_mode=DataSplitMode::kRow)
Load DMatrix from URI.
Definition data.cc:853
XGBAPIThreadLocalEntry & GetThreadLocal() const
Get thread local memory for returning data from DMatrix.
Definition data.cc:818
bool IsDense() const
Whether the matrix is dense.
Definition data.h:553
DMatrix()=default
default constructor
A page stored in ELLPACK format.
Definition ellpack_page.h:21
Sparse page for exporting DMatrix.
Definition data.h:431
preprocessed global index matrix, in CSR format.
Definition gradient_index.h:38
Definition host_device_vector.h:87
Data structure representing JSON format.
Definition json.h:357
Meta information about the dataset; always sits in memory.
Definition data.h:48
linalg::Tensor< float, 2 > base_margin_
initialized margins, if specified, xgboost will start from this init margin can be used to specify in...
Definition data.h:75
std::vector< std::string > feature_names
Name for each feature.
Definition data.h:92
HostDeviceVector< bst_float > labels_upper_bound_
upper bound of the label, to be used for survival analysis (censored regression)
Definition data.h:83
void Validate(int32_t device) const
Validate all metainfo.
Definition data.cc:757
uint64_t num_col_
number of columns in the data
Definition data.h:56
std::vector< std::string > feature_type_names
Name of type for each feature provided by users. Eg. "int"/"float"/"i"/"q".
Definition data.h:88
HostDeviceVector< bst_float > weights_
weights of each instance, optional
Definition data.h:69
bool IsVerticalFederated() const
A convenient method to check if we are doing vertical federated learning, which requires some special...
Definition data.cc:807
void SynchronizeNumberOfColumns()
Synchronize the number of columns across all workers.
Definition data.cc:731
bool IsColumnSplit() const
Whether the data is split column-wise.
Definition data.h:189
bst_float GetWeight(size_t i) const
Get the weight of each instance.
Definition data.h:123
DataSplitMode data_split_mode
data split mode
Definition data.h:62
void LoadBinary(dmlc::Stream *fi)
Load the Meta info from binary stream.
Definition data.cc:295
std::vector< bst_group_t > group_ptr_
the index of begin and end of a group needed when the learning task is ranking.
Definition data.h:67
uint64_t num_row_
number of rows in the data
Definition data.h:54
void Clear()
clear all the information
Definition data.cc:202
static constexpr uint64_t kNumField
number of data fields in MetaInfo
Definition data.h:51
bool IsRanking() const
Whether this is a learning to rank data.
Definition data.h:191
uint64_t num_nonzero_
number of nonzero entries in the data
Definition data.h:58
linalg::Tensor< float, 2 > labels
label of each instance
Definition data.h:60
const std::vector< size_t > & LabelAbsSort(Context const *ctx) const
get sorted indexes (argsort) of labels by absolute value (used by cox loss)
Definition data.cc:282
void SaveBinary(dmlc::Stream *fo) const
Save the Meta info to binary stream.
Definition data.cc:233
bool ShouldHaveLabels() const
A convenient method to check if the MetaInfo should contain labels.
Definition data.cc:811
MetaInfo()=default
default constructor
void SetInfo(Context const &ctx, const char *key, const void *dptr, DataType dtype, size_t num)
Set information in the meta info.
Definition data.cc:562
bool IsRowSplit() const
Whether the data is split row-wise.
Definition data.h:184
HostDeviceVector< bst_float > labels_lower_bound_
lower bound of the label, to be used for survival analysis (censored regression)
Definition data.h:79
Definition data.h:437
In-memory storage unit of sparse batch, stored in CSR format.
Definition data.h:328
SparsePage()
constructor
Definition data.h:345
void SetBaseRowId(size_t row_id)
Set the base row id for this page.
Definition data.h:375
void Reindex(uint64_t feature_offset, int32_t n_threads)
Reindex the column index with an offset.
Definition data.cc:1081
uint64_t Push(const AdapterBatchT &batch, float missing, int nthread)
Pushes external data batch onto this page.
Definition data.cc:1117
void PushCSC(const SparsePage &batch)
Push a SparsePage stored in CSC format.
Definition data.cc:1216
bool IsIndicesSorted(int32_t n_threads) const
Check whether the column index is sorted.
Definition data.cc:1053
void SortIndices(int32_t n_threads)
Sort the column index.
Definition data.cc:1070
HostDeviceVector< Entry > data
the data of the segments
Definition data.h:333
size_t MemCostBytes() const
Definition data.h:361
void Clear()
clear the page
Definition data.h:366
size_t Size() const
Definition data.h:356
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition span.h:424
A tensor storage.
Definition linalg.h:742
defines configuration macros
defines common input data structure, and interface for handling the input data
XGB_EXTERN_C typedef int XGDMatrixCallbackNext(DataIterHandle iter)
Callback function prototype for getting next batch of data.
XGB_EXTERN_C typedef void DataIterResetCallback(DataIterHandle handle)
Callback function prototype for resetting external iterator.
A device-and-host vector abstraction layer.
Copyright 2015-2023 by XGBoost Contributors.
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition base.h:64
#define DECLARE_FIELD_ENUM_CLASS(EnumClass)
Specialization of FieldEntry for enum class (backed by int)
Definition parameter.h:50
Copyright 2021-2023 by XGBoost Contributors.
namespace for dmlc
Definition array_view.h:12
Definition feature_weights.py:1
namespace of xgboost
Definition base.h:90
uint32_t bst_feature_t
Type for data column (feature) index.
Definition base.h:101
uint64_t bst_ulong
unsigned long integers
Definition base.h:95
int32_t bst_bin_t
Type for histogram bin index.
Definition base.h:103
DataType
data type accepted by xgboost interface
Definition data.h:33
float bst_float
float type, used for storing statistics
Definition base.h:97
serializer template class that helps serialization. This file does not need to be directly used by most...
generic serialization handler
Definition serializer.h:259
static void Write(Stream *strm, const T &data)
write data to stream
Definition serializer.h:265
static bool Read(Stream *strm, T *data)
read data to stream
Definition serializer.h:283
Parameters for constructing histogram index batches.
Definition data.h:244
bool forbid_regen
Forbid regenerating the gradient index.
Definition data.h:261
bst_bin_t max_bin
Maximum number of bins per feature for histograms.
Definition data.h:248
common::Span< float const > hess
Hessian, used for sketching with future approx implementation.
Definition data.h:252
bool regen
Whether we should force DMatrix to regenerate the batch.
Definition data.h:257
BatchParam()=default
Exact or others that don't need histogram.
double sparse_thresh
Parameter used to generate column matrix for hist.
Definition data.h:265
BatchParam(bst_bin_t max_bin, common::Span< float const > hessian, bool regenerate)
Used by the approx tree method.
Definition data.h:282
BatchParam MakeCache() const
Make a copy of self for DMatrix to describe how its existing index was generated.
Definition data.h:300
BatchParam(bst_bin_t max_bin, double sparse_thresh)
Used by the hist tree method.
Definition data.h:274
Runtime context for XGBoost.
Definition context.h:84
Element from a sparse vector.
Definition data.h:216
XGBOOST_DEVICE Entry(bst_feature_t index, bst_float fvalue)
constructor with index and value
Definition data.h:228
Entry()=default
default constructor
bst_feature_t index
feature index
Definition data.h:218
bst_float fvalue
feature value
Definition data.h:220
static bool CmpValue(const Entry &a, const Entry &b)
compare feature values (true when a.fvalue < b.fvalue)
Definition data.h:230
Definition data.h:310
Definition string_view.h:15
entry to easily hold returning information
Definition api_entry.h:16
#define DMLC_DECLARE_TRAITS(Trait, Type, Value)
macro to quickly declare traits information
Definition type_traits.h:126