Medial Code Documentation
Loading...
Searching...
No Matches
histogram_helpers.h
1#if defined(__CUDACC__)
2#include "../../src/data/ellpack_page.cuh"
3#endif
4
5#include <xgboost/data.h> // for SparsePage
6
7#include "./helpers.h" // for RandomDataGenerator
8
9namespace xgboost {
10#if defined(__CUDACC__)
namespace {
/**
 * @brief Test-only subclass of common::HistogramCuts that exposes setters for
 *        the base class's cut pointers, cut values and per-feature minimums.
 *
 * Each setter takes its argument by value and moves it into the host-side
 * vector of the corresponding base-class container, replacing its contents.
 */
class HistogramCutsWrapper : public common::HistogramCuts {
 public:
  using SuperT = common::HistogramCuts;

  /** @brief Replace the host copy of the cut values with @p values. */
  void SetValues(std::vector<float> values) {
    auto& host_values = this->cut_values_.HostVector();
    host_values = std::move(values);
  }

  /** @brief Replace the host copy of the per-feature cut pointers with @p offsets. */
  void SetPtrs(std::vector<uint32_t> offsets) {
    auto& host_ptrs = this->cut_ptrs_.HostVector();
    host_ptrs = std::move(offsets);
  }

  /** @brief Replace the host copy of the per-feature minimum values with @p minimums. */
  void SetMins(std::vector<float> minimums) {
    auto& host_mins = this->min_vals_.HostVector();
    host_mins = std::move(minimums);
  }
};
}  // anonymous namespace
26
/**
 * @brief Build an EllpackPageImpl from a randomly generated DMatrix, for tests.
 *
 * The DMatrix comes from RandomDataGenerator with a fixed seed (3). Its first
 * SparsePage batch is compressed against a fixed set of hand-written histogram
 * cuts, with 3 ascending cut values per feature.
 *
 * The hand-written cut template covers 8 features:
 * - For n_cols <= 8, the cuts are exactly that template (8 features, 24 cut
 *   values), as before.
 * - For n_cols > 8, the template is repeated cyclically so that every generated
 *   column has an entry in the cut pointers. Previously, columns beyond the
 *   eighth had no cuts at all.
 *
 * @param n_rows   Number of rows to generate.
 * @param n_cols   Number of columns (features) to generate.
 * @param sparsity Sparsity forwarded to RandomDataGenerator.
 * @return Owning pointer to the newly constructed page.
 */
inline std::unique_ptr<EllpackPageImpl> BuildEllpackPage(
    int n_rows, int n_cols, bst_float sparsity = 0) {
  auto dmat = RandomDataGenerator(n_rows, n_cols, sparsity).Seed(3).GenerateDMatrix();
  // NOTE(review): assumes the page is owned by `dmat` (not by the temporary
  // iterator), so this reference stays valid while `dmat` is alive.
  const SparsePage& batch = *dmat->GetBatches<xgboost::SparsePage>().begin();

  // Hand-written cut template: 8 features, 3 ascending cut values each.
  constexpr int kTemplateFeatures = 8;
  constexpr int kCutsPerFeature = 3;
  constexpr float kTemplateValues[kTemplateFeatures][kCutsPerFeature] = {
      {0.30f, 0.67f, 1.64f},
      {0.32f, 0.77f, 1.95f},
      {0.29f, 0.70f, 1.80f},
      {0.32f, 0.75f, 1.85f},
      {0.18f, 0.59f, 1.69f},
      {0.25f, 0.74f, 2.00f},
      {0.26f, 0.74f, 1.98f},
      {0.26f, 0.71f, 1.83f}};
  constexpr float kTemplateMins[kTemplateFeatures] = {0.1f, 0.2f, 0.3f, 0.1f,
                                                      0.2f, 0.3f, 0.2f, 0.2f};

  // Keep at least the 8 template features (the original fixed layout), but
  // never fewer features than generated columns, so that every column index
  // present in the batch has an entry in the cut pointers.
  const int n_features = std::max(n_cols, kTemplateFeatures);
  std::vector<uint32_t> ptrs{0};
  std::vector<float> values;
  std::vector<float> mins;
  for (int f = 0; f < n_features; ++f) {
    const int t = f % kTemplateFeatures;  // cycle through the template
    for (int c = 0; c < kCutsPerFeature; ++c) {
      values.push_back(kTemplateValues[t][c]);
    }
    mins.push_back(kTemplateMins[t]);
    // Cut pointers are cumulative offsets into `values`.
    ptrs.push_back(static_cast<uint32_t>(values.size()));
  }

  HistogramCutsWrapper cmat;
  cmat.SetPtrs(std::move(ptrs));
  cmat.SetValues(std::move(values));
  cmat.SetMins(std::move(mins));

  // row_stride is the largest number of stored entries in any single row,
  // computed from consecutive differences of the CSR row offsets.
  bst_row_t row_stride = 0;
  const auto& offset_vec = batch.offset.ConstHostVector();
  for (size_t i = 1; i < offset_vec.size(); ++i) {
    row_stride = std::max(row_stride, offset_vec[i] - offset_vec[i - 1]);
  }

  // `{}` is passed for the last constructor argument, exactly as before; a
  // braced-init-list cannot be forwarded through std::make_unique, so `new`
  // is kept here.
  return std::unique_ptr<EllpackPageImpl>(
      new EllpackPageImpl(0, cmat, batch, dmat->IsDense(), row_stride, {}));
}
56#endif
57} // namespace xgboost
In-memory storage unit of a sparse batch, stored in CSR format.
Definition data.h:328
Copyright 2015-2023 by XGBoost Contributors.
namespace of xgboost
Definition base.h:90
std::size_t bst_row_t
Type for data row index.
Definition base.h:110
float bst_float
Float type used for storing statistics.
Definition base.h:97