Medial Code Documentation
Loading...
Searching...
No Matches
test_array_interface.h
1// Copyright (c) 2019 by Contributors
#pragma once
#include <gtest/gtest.h>
#include <xgboost/data.h>
#include <xgboost/json.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/sequence.h>

#include <cstddef>
#include <memory>
#include <string>
#include <vector>

#include "../../../src/common/bitfield.h"
#include "../../../src/common/device_helpers.cuh"
10
11namespace xgboost {
12
13template <typename T>
14Json GenerateDenseColumn(std::string const& typestr, size_t kRows,
15 thrust::device_vector<T>* out_d_data) {
16 auto& d_data = *out_d_data;
17 d_data.resize(kRows);
18 Json column { Object() };
19 std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))};
20 column["shape"] = Array(j_shape);
21 column["strides"] = Array(std::vector<Json>{Json(Integer(static_cast<Integer::Int>(sizeof(T))))});
22 column["stream"] = nullptr;
23
24 d_data.resize(kRows);
25 thrust::sequence(thrust::device, d_data.begin(), d_data.end(), 0.0f, 2.0f);
26
27 auto p_d_data = d_data.data().get();
28
29 std::vector<Json> j_data {
30 Json(Integer(reinterpret_cast<Integer::Int>(p_d_data))),
31 Json(Boolean(false))};
32 column["data"] = j_data;
33
34 column["version"] = 3;
35 column["typestr"] = String(typestr);
36 return column;
37}
38
39template <typename T>
40Json GenerateSparseColumn(std::string const& typestr, size_t kRows,
41 thrust::device_vector<T>* out_d_data) {
42 auto& d_data = *out_d_data;
43 Json column { Object() };
44 std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))};
45 column["shape"] = Array(j_shape);
46 column["strides"] = Array(std::vector<Json>{Json(Integer(static_cast<Integer::Int>(sizeof(T))))});
47 column["stream"] = nullptr;
48
49 d_data.resize(kRows);
50 for (size_t i = 0; i < d_data.size(); ++i) {
51 d_data[i] = i * 2.0;
52 }
53
54 auto p_d_data = d_data.data().get();
55
56 std::vector<Json> j_data {
57 Json(Integer(reinterpret_cast<Integer::Int>(p_d_data))),
58 Json(Boolean(false))};
59 column["data"] = j_data;
60
61 column["version"] = 3;
62 column["typestr"] = String(typestr);
63 return column;
64}
65
66template <typename T>
67Json Generate2dArrayInterface(int rows, int cols, std::string typestr,
68 thrust::device_vector<T> *p_data) {
69 auto& data = *p_data;
70 thrust::sequence(data.begin(), data.end());
71
72 Json array_interface{Object()};
73 std::vector<Json> shape = {Json(static_cast<Integer::Int>(rows)),
74 Json(static_cast<Integer::Int>(cols))};
75 array_interface["shape"] = Array(shape);
76 std::vector<Json> j_data{
77 Json(Integer(reinterpret_cast<Integer::Int>(data.data().get()))),
78 Json(Boolean(false))};
79 array_interface["data"] = j_data;
80 array_interface["version"] = 3;
81 array_interface["typestr"] = String(typestr);
82 array_interface["stream"] = nullptr;
83 return array_interface;
84}
85} // namespace xgboost
Copyright 2015-2023 by XGBoost Contributors.
namespace of xgboost
Definition base.h:90