4#ifndef XGBOOST_TESTS_CPP_DATA_TEST_METAINFO_H_
5#define XGBOOST_TESTS_CPP_DATA_TEST_METAINFO_H_
6#include <gtest/gtest.h>
13#include "../../../src/common/linalg_op.h"
14#include "../../../src/data/array_interface.h"
17inline void TestMetaInfoStridedData(int32_t device) {
20 ctx.UpdateAllowUnknown(Args{{
"gpu_id", std::to_string(device)}});
23 linalg::Tensor<float, 3> labels;
24 labels.Reshape(4, 2, 3);
25 auto& h_label = labels.Data()->HostVector();
26 std::iota(h_label.begin(), h_label.end(), 0.0);
28 ASSERT_EQ(t_labels.Shape().size(), 2);
31 auto const& h_result = info.labels.View(-1);
32 ASSERT_EQ(h_result.Shape().size(), 2);
33 auto in_labels = labels.View(-1);
34 linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](
size_t i,
float& v_0) {
36 auto i0 = std::get<0>(tup);
37 auto i1 = std::get<1>(tup);
39 auto v_1 = in_labels(i0, 0, i1);
45 linalg::Tensor<uint64_t, 2> qid;
47 auto& h_qid = qid.Data()->HostVector();
48 std::iota(h_qid.begin(), h_qid.end(), 0);
51 info.SetInfo(ctx,
"qid", StringView{str});
52 auto const& h_result = info.group_ptr_;
53 ASSERT_EQ(h_result.size(), s.Size() + 1);
57 linalg::Tensor<float, 3> base_margin;
58 base_margin.Reshape(4, 2, 3);
59 auto& h_margin = base_margin.Data()->HostVector();
60 std::iota(h_margin.begin(), h_margin.end(), 0.0);
62 ASSERT_EQ(t_margin.Shape().size(), 2);
65 auto const& h_result = info.base_margin_.View(-1);
66 ASSERT_EQ(h_result.Shape().size(), 2);
67 auto in_margin = base_margin.View(-1);
68 linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](
size_t i,
float v_0) {
70 auto i0 = std::get<0>(tup);
71 auto i1 = std::get<1>(tup);
73 auto v_1 = in_margin(i0, 0, i1);
A device-and-host vector abstraction layer.
Copyright 2015-2023 by XGBoost Contributors.
Copyright 2021-2023 by XGBoost Contributors.
auto ArrayInterfaceStr(TensorView< T const, D > const &t)
Return string representation of array interface.
Definition linalg.h:724
LINALG_HD auto UnravelIndex(size_t idx, common::Span< size_t const, D > shape)
Turns linear index into multi-dimension index.
Definition linalg.h:613
constexpr detail::AllTag All()
Specify all elements in the axis for slicing.
Definition linalg.h:265
namespace of xgboost
Definition base.h:90