#include <gtest/gtest.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/sequence.h>
#include <xgboost/json.h>

#include <cstddef>
#include <string>
#include <vector>

#include "../../../src/common/bitfield.h"
#include "../../../src/common/device_helpers.cuh"
/**
 * \brief Fill a device vector with an arithmetic sequence and describe it as a
 *        one-dimensional array-interface style JSON column.
 *
 * The returned object carries the keys "shape", "strides", "stream", "data",
 * "version" and "typestr".  "data" is a two element list holding the raw device
 * pointer (stored as an integer) and the boolean `false`.
 *
 * \tparam T          Element type stored in the device vector.
 * \param typestr     Type string written verbatim into the "typestr" field.
 * \param kRows       Number of elements reported in "shape".
 * \param out_d_data  Device vector that backs the column.  It is overwritten
 *                    with 0, 2, 4, ... and must outlive the returned Json, which
 *                    only stores its address.
 *
 * \return Json object describing the column.
 */
template <typename T>
Json GenerateDenseColumn(std::string const& typestr, size_t kRows,
                         thrust::device_vector<T>* out_d_data) {
  auto& d_data = *out_d_data;

  // The descriptor advertises kRows elements, so make sure the buffer really
  // holds that many before its address is published.
  // NOTE(review): this resize is not visible in the extracted listing -- confirm
  // against the upstream file.
  d_data.resize(kRows);
  thrust::sequence(thrust::device, d_data.begin(), d_data.end(), 0.0f, 2.0f);

  Json column{Object()};

  std::vector<Json> j_shape{Json(Integer(static_cast<Integer::Int>(kRows)))};
  column["shape"] = Array(j_shape);
  // Single stride entry: the byte size of one element.
  column["strides"] =
      Array(std::vector<Json>{Json(Integer(static_cast<Integer::Int>(sizeof(T))))});
  column["stream"] = nullptr;

  // Take the raw pointer only after the resize above; resizing may reallocate.
  auto p_d_data = d_data.data().get();
  std::vector<Json> j_data{
      Json(Integer(reinterpret_cast<Integer::Int>(p_d_data))),
      Json(Boolean(false))};
  column["data"] = j_data;

  column["version"] = 3;
  column["typestr"] = String(typestr);
  return column;
}
/**
 * \brief Fill a device vector element by element from the host and describe it
 *        as a one-dimensional array-interface style JSON column.
 *
 * The returned object has the same keys as the one produced by
 * GenerateDenseColumn: "shape", "strides", "stream", "data", "version" and
 * "typestr".
 *
 * \tparam T          Element type stored in the device vector.
 * \param typestr     Type string written verbatim into the "typestr" field.
 * \param kRows       Number of elements reported in "shape".
 * \param out_d_data  Device vector that backs the column.  It must outlive the
 *                    returned Json, which only stores its address.
 *
 * \return Json object describing the column.
 */
template <typename T>
Json GenerateSparseColumn(std::string const& typestr, size_t kRows,
                          thrust::device_vector<T>* out_d_data) {
  auto& d_data = *out_d_data;

  Json column{Object()};
  std::vector<Json> j_shape{Json(Integer(static_cast<Integer::Int>(kRows)))};
  column["shape"] = Array(j_shape);
  // Single stride entry: the byte size of one element.
  column["strides"] =
      Array(std::vector<Json>{Json(Integer(static_cast<Integer::Int>(sizeof(T))))});
  column["stream"] = nullptr;

  // The descriptor advertises kRows elements, so size the buffer to match.
  // NOTE(review): this resize is not visible in the extracted listing -- confirm
  // against the upstream file.
  d_data.resize(kRows);
  for (size_t i = 0; i < d_data.size(); ++i) {
    // NOTE(review): the loop body was lost in extraction.  It is filled here with
    // the same 0, 2, 4, ... values GenerateDenseColumn produces -- confirm
    // against the upstream file.
    // Each assignment goes through a device reference, i.e. one host-to-device
    // write per element; acceptable only for small test inputs.
    d_data[i] = static_cast<T>(i * 2.0);
  }

  auto p_d_data = d_data.data().get();
  std::vector<Json> j_data{
      Json(Integer(reinterpret_cast<Integer::Int>(p_d_data))),
      Json(Boolean(false))};
  column["data"] = j_data;

  column["version"] = 3;
  column["typestr"] = String(typestr);
  return column;
}
/**
 * \brief Overwrite a device vector with 0, 1, 2, ... and describe it as a
 *        two-dimensional array-interface style JSON object.
 *
 * The returned object carries the keys "shape" ([rows, cols]), "data",
 * "version", "typestr" and "stream".  No "strides" entry is written.
 *
 * \tparam T       Element type stored in the device vector.
 * \param rows     First entry of "shape".
 * \param cols     Second entry of "shape".
 * \param typestr  Type string written verbatim into the "typestr" field.
 * \param p_data   Device vector that backs the array.  This function does not
 *                 resize it (presumably the caller sizes it to rows * cols --
 *                 TODO confirm).  It must outlive the returned Json, which only
 *                 stores its address.
 *
 * \return Json object describing the array.
 */
template <typename T>
Json Generate2dArrayInterface(int rows, int cols, std::string typestr,
                              thrust::device_vector<T>* p_data) {
  // NOTE(review): the declaration of `data` was lost in extraction; the only
  // object it can refer to is the vector behind p_data.
  auto& data = *p_data;
  thrust::sequence(data.begin(), data.end());

  Json array_interface{Object()};

  std::vector<Json> shape = {Json(static_cast<Integer::Int>(rows)),
                             Json(static_cast<Integer::Int>(cols))};
  array_interface["shape"] = Array(shape);

  std::vector<Json> j_data{
      Json(Integer(reinterpret_cast<Integer::Int>(data.data().get()))),
      Json(Boolean(false))};
  array_interface["data"] = j_data;

  array_interface["version"] = 3;
  array_interface["typestr"] = String(typestr);
  array_interface["stream"] = nullptr;
  return array_interface;
}
// Copyright 2015-2023 by XGBoost Contributors.
// Doxygen cross-reference residue: namespace xgboost -- definition at base.h:90