Medial Code Documentation
Loading...
Searching...
No Matches
proxy_dmatrix.h
1
4#ifndef XGBOOST_DATA_PROXY_DMATRIX_H_
5#define XGBOOST_DATA_PROXY_DMATRIX_H_
6
#include <any>          // for any, any_cast
#include <memory>
#include <string>
#include <type_traits>  // for invoke_result_t
#include <typeinfo>     // for typeid
#include <utility>

#include "adapter.h"
#include "xgboost/c_api.h"
#include "xgboost/context.h"
#include "xgboost/data.h"
16
17namespace xgboost::data {
18/*
19 * \brief A proxy to external iterator.
20 */
21template <typename ResetFn, typename NextFn>
23 DataIterHandle iter_;
24 ResetFn* reset_;
25 NextFn* next_;
26
27 public:
28 DataIterProxy(DataIterHandle iter, ResetFn* reset, NextFn* next)
29 : iter_{iter}, reset_{reset}, next_{next} {}
30
31 bool Next() { return next_(iter_); }
32 void Reset() { reset_(iter_); }
33};
34
35/*
36 * \brief A proxy of DMatrix used by external iterator.
37 */
38class DMatrixProxy : public DMatrix {
39 MetaInfo info_;
40 std::any batch_;
41 Context ctx_;
42
43#if defined(XGBOOST_USE_CUDA)
44 void FromCudaColumnar(StringView interface_str);
45 void FromCudaArray(StringView interface_str);
46#endif // defined(XGBOOST_USE_CUDA)
47
48 public:
49 int DeviceIdx() const { return ctx_.gpu_id; }
50
51 void SetCUDAArray(char const* c_interface) {
52 common::AssertGPUSupport();
53 CHECK(c_interface);
54#if defined(XGBOOST_USE_CUDA)
55 StringView interface_str{c_interface};
56 Json json_array_interface = Json::Load(interface_str);
57 if (IsA<Array>(json_array_interface)) {
58 this->FromCudaColumnar(interface_str);
59 } else {
60 this->FromCudaArray(interface_str);
61 }
62#endif // defined(XGBOOST_USE_CUDA)
63 }
64
65 void SetArrayData(StringView interface_str);
66 void SetCSRData(char const* c_indptr, char const* c_indices, char const* c_values,
67 bst_feature_t n_features, bool on_host);
68
69 MetaInfo& Info() override { return info_; }
70 MetaInfo const& Info() const override { return info_; }
71 Context const* Ctx() const override { return &ctx_; }
72
73 bool SingleColBlock() const override { return false; }
74 bool EllpackExists() const override { return false; }
75 bool GHistIndexExists() const override { return false; }
76 bool SparsePageExists() const override { return false; }
77
78 template <typename Page>
79 BatchSet<Page> NoBatch() {
80 LOG(FATAL) << "Proxy DMatrix cannot return data batch.";
81 return BatchSet<Page>(BatchIterator<Page>(nullptr));
82 }
83
84 DMatrix* Slice(common::Span<int32_t const> /*ridxs*/) override {
85 LOG(FATAL) << "Slicing DMatrix is not supported for Proxy DMatrix.";
86 return nullptr;
87 }
88 DMatrix* SliceCol(int, int) override {
89 LOG(FATAL) << "Slicing DMatrix columns is not supported for Proxy DMatrix.";
90 return nullptr;
91 }
92 BatchSet<SparsePage> GetRowBatches() override { return NoBatch<SparsePage>(); }
93 BatchSet<CSCPage> GetColumnBatches(Context const*) override { return NoBatch<CSCPage>(); }
94 BatchSet<SortedCSCPage> GetSortedColumnBatches(Context const*) override {
95 return NoBatch<SortedCSCPage>();
96 }
97 BatchSet<EllpackPage> GetEllpackBatches(Context const*, BatchParam const&) override {
98 return NoBatch<EllpackPage>();
99 }
100 BatchSet<GHistIndexMatrix> GetGradientIndex(Context const*, BatchParam const&) override {
101 return NoBatch<GHistIndexMatrix>();
102 }
103 BatchSet<ExtSparsePage> GetExtBatches(Context const*, BatchParam const&) override {
104 return NoBatch<ExtSparsePage>();
105 }
106 std::any Adapter() const { return batch_; }
107};
108
109inline DMatrixProxy* MakeProxy(DMatrixHandle proxy) {
110 auto proxy_handle = static_cast<std::shared_ptr<DMatrix>*>(proxy);
111 CHECK(proxy_handle) << "Invalid proxy handle.";
112 DMatrixProxy* typed = static_cast<DMatrixProxy*>(proxy_handle->get());
113 CHECK(typed) << "Invalid proxy handle.";
114 return typed;
115}
116
130template <bool get_value = true, typename Fn>
131decltype(auto) HostAdapterDispatch(DMatrixProxy const* proxy, Fn fn, bool* type_error = nullptr) {
132 if (proxy->Adapter().type() == typeid(std::shared_ptr<CSRArrayAdapter>)) {
133 if constexpr (get_value) {
134 auto value = std::any_cast<std::shared_ptr<CSRArrayAdapter>>(proxy->Adapter())->Value();
135 return fn(value);
136 } else {
137 auto value = std::any_cast<std::shared_ptr<CSRArrayAdapter>>(proxy->Adapter());
138 return fn(value);
139 }
140 if (type_error) {
141 *type_error = false;
142 }
143 } else if (proxy->Adapter().type() == typeid(std::shared_ptr<ArrayAdapter>)) {
144 if constexpr (get_value) {
145 auto value = std::any_cast<std::shared_ptr<ArrayAdapter>>(proxy->Adapter())->Value();
146 return fn(value);
147 } else {
148 auto value = std::any_cast<std::shared_ptr<ArrayAdapter>>(proxy->Adapter());
149 return fn(value);
150 }
151 if (type_error) {
152 *type_error = false;
153 }
154 } else {
155 if (type_error) {
156 *type_error = true;
157 } else {
158 LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name();
159 }
160 if constexpr (get_value) {
161 return std::result_of_t<Fn(
162 decltype(std::declval<std::shared_ptr<ArrayAdapter>>()->Value()))>();
163 } else {
164 return std::result_of_t<Fn(decltype(std::declval<std::shared_ptr<ArrayAdapter>>()))>();
165 }
166 }
167}
168
172std::shared_ptr<DMatrix> CreateDMatrixFromProxy(Context const* ctx,
173 std::shared_ptr<DMatrixProxy> proxy, float missing);
174} // namespace xgboost::data
175#endif // XGBOOST_DATA_PROXY_DMATRIX_H_
Definition data.h:458
Definition data.h:494
Data structure representing JSON format.
Definition json.h:357
static Json Load(StringView str, std::ios::openmode mode=std::ios::in)
Decode the JSON object.
Definition json.cc:652
Meta information about dataset, always sit in memory.
Definition data.h:48
Definition json.h:26
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition span.h:424
Definition core.py:748
Definition proxy_dmatrix.h:38
Definition proxy_dmatrix.h:22
Copyright 2014-2023, XGBoost Contributors.
Copyright 2015-2023 by XGBoost Contributors.
Copyright 2019-2023, XGBoost Contributors.
Definition data.py:1
decltype(auto) HostAdapterDispatch(DMatrixProxy const *proxy, Fn fn, bool *type_error=nullptr)
Dispatch function call based on input type.
Definition proxy_dmatrix.h:131
std::shared_ptr< DMatrix > CreateDMatrixFromProxy(Context const *ctx, std::shared_ptr< DMatrixProxy > proxy, float missing)
Create a SimpleDMatrix instance from a DMatrixProxy.
Definition proxy_dmatrix.cc:39
uint32_t bst_feature_t
Type for data column (feature) index.
Definition base.h:101
Parameters for constructing histogram index batches.
Definition data.h:244
Runtime context for XGBoost.
Definition context.h:84
Definition string_view.h:15
Copyright 2015~2023 by XGBoost Contributors.