4#ifndef XGBOOST_DATA_PROXY_DMATRIX_H_
5#define XGBOOST_DATA_PROXY_DMATRIX_H_
21template <
typename ResetFn,
typename NextFn>
28 DataIterProxy(DataIterHandle iter, ResetFn* reset, NextFn* next)
29 : iter_{iter}, reset_{reset}, next_{next} {}
31 bool Next() {
return next_(iter_); }
32 void Reset() { reset_(iter_); }
43#if defined(XGBOOST_USE_CUDA)
// Loads data into the proxy from a CUDA columnar array interface passed as a
// JSON string. SetCUDAArray dispatches here when the parsed interface is a
// JSON Array (otherwise it uses FromCudaArray).
// Only declared when XGBOOST_USE_CUDA is defined.
// NOTE(review): the definition is not visible here — presumably in a CUDA
// translation unit.
44 void FromCudaColumnar(
StringView interface_str);
49 int DeviceIdx()
const {
return ctx_.gpu_id; }
51 void SetCUDAArray(
char const* c_interface) {
52 common::AssertGPUSupport();
54#if defined(XGBOOST_USE_CUDA)
57 if (IsA<Array>(json_array_interface)) {
58 this->FromCudaColumnar(interface_str);
60 this->FromCudaArray(interface_str);
66 void SetCSRData(
char const* c_indptr,
char const* c_indices,
char const* c_values,
69 MetaInfo& Info()
override {
return info_; }
70 MetaInfo const& Info()
const override {
return info_; }
71 Context const* Ctx()
const override {
return &ctx_; }
73 bool SingleColBlock()
const override {
return false; }
74 bool EllpackExists()
const override {
return false; }
75 bool GHistIndexExists()
const override {
return false; }
76 bool SparsePageExists()
const override {
return false; }
78 template <
typename Page>
80 LOG(FATAL) <<
"Proxy DMatrix cannot return data batch.";
85 LOG(FATAL) <<
"Slicing DMatrix is not supported for Proxy DMatrix.";
88 DMatrix* SliceCol(
int,
int)
override {
89 LOG(FATAL) <<
"Slicing DMatrix columns is not supported for Proxy DMatrix.";
95 return NoBatch<SortedCSCPage>();
98 return NoBatch<EllpackPage>();
101 return NoBatch<GHistIndexMatrix>();
104 return NoBatch<ExtSparsePage>();
106 std::any Adapter()
const {
return batch_; }
110 auto proxy_handle =
static_cast<std::shared_ptr<DMatrix>*
>(proxy);
111 CHECK(proxy_handle) <<
"Invalid proxy handle.";
113 CHECK(typed) <<
"Invalid proxy handle.";
130template <
bool get_value = true,
typename Fn>
132 if (proxy->Adapter().type() ==
typeid(std::shared_ptr<CSRArrayAdapter>)) {
133 if constexpr (get_value) {
134 auto value = std::any_cast<std::shared_ptr<CSRArrayAdapter>>(proxy->Adapter())->
Value();
137 auto value = std::any_cast<std::shared_ptr<CSRArrayAdapter>>(proxy->Adapter());
143 }
else if (proxy->Adapter().type() ==
typeid(std::shared_ptr<ArrayAdapter>)) {
144 if constexpr (get_value) {
145 auto value = std::any_cast<std::shared_ptr<ArrayAdapter>>(proxy->Adapter())->
Value();
148 auto value = std::any_cast<std::shared_ptr<ArrayAdapter>>(proxy->Adapter());
158 LOG(FATAL) <<
"Unknown type: " << proxy->Adapter().type().name();
160 if constexpr (get_value) {
161 return std::result_of_t<Fn(
162 decltype(std::declval<std::shared_ptr<ArrayAdapter>>()->
Value()))>();
164 return std::result_of_t<Fn(
decltype(std::declval<std::shared_ptr<ArrayAdapter>>()))>();
173 std::shared_ptr<DMatrixProxy> proxy,
float missing);
Data structure representing JSON format.
Definition json.h:357
static Json Load(StringView str, std::ios::openmode mode=std::ios::in)
Decode the JSON object.
Definition json.cc:652
Span class implementation, based on the ISO C++20 std::span<T>; the interface should be the same.
Definition span.h:424
Definition proxy_dmatrix.h:38
Definition proxy_dmatrix.h:22
Copyright 2014-2023, XGBoost Contributors.
Copyright 2015-2023 by XGBoost Contributors.
Copyright 2019-2023, XGBoost Contributors.
Definition data.py:1
decltype(auto) HostAdapterDispatch(DMatrixProxy const *proxy, Fn fn, bool *type_error=nullptr)
Dispatch function call based on input type.
Definition proxy_dmatrix.h:131
std::shared_ptr< DMatrix > CreateDMatrixFromProxy(Context const *ctx, std::shared_ptr< DMatrixProxy > proxy, float missing)
Create a SimpleDMatrix instance from a DMatrixProxy.
Definition proxy_dmatrix.cc:39
uint32_t bst_feature_t
Type for data column (feature) index.
Definition base.h:101
Parameters for constructing histogram index batches.
Definition data.h:244
Runtime context for XGBoost.
Definition context.h:84
Definition string_view.h:15
Copyright 2015~2023 by XGBoost Contributors.