Medial Code Documentation
Loading...
Searching...
No Matches
test_iterative_dmatrix.h
1
4#pragma once
5#include <xgboost/context.h> // for Context
6
7#include <limits> // for numeric_limits
8#include <memory> // for make_shared
9
10#include "../../../src/data/iterative_dmatrix.h"
11#include "../helpers.h"
12
13namespace xgboost {
14namespace data {
15template <typename Page, typename Iter, typename Cuts>
16void TestRefDMatrix(Context const* ctx, Cuts&& get_cuts) {
17 int n_bins = 256;
18 Iter iter(0.3, 2048);
19 auto m = std::make_shared<IterativeDMatrix>(&iter, iter.Proxy(), nullptr, Reset, Next,
20 std::numeric_limits<float>::quiet_NaN(), 0, n_bins);
21
22 Iter iter_1(0.8, 32, Iter::Cols(), 13);
23 auto m_1 = std::make_shared<IterativeDMatrix>(&iter_1, iter_1.Proxy(), m, Reset, Next,
24 std::numeric_limits<float>::quiet_NaN(), 0, n_bins);
25
26 for (auto const& page_0 : m->template GetBatches<Page>(ctx, {})) {
27 for (auto const& page_1 : m_1->template GetBatches<Page>(ctx, {})) {
28 auto const& cuts_0 = get_cuts(page_0);
29 auto const& cuts_1 = get_cuts(page_1);
30 ASSERT_EQ(cuts_0.Values(), cuts_1.Values());
31 ASSERT_EQ(cuts_0.Ptrs(), cuts_1.Ptrs());
32 ASSERT_EQ(cuts_0.MinValues(), cuts_1.MinValues());
33 }
34 }
35
36 m_1 = std::make_shared<IterativeDMatrix>(&iter_1, iter_1.Proxy(), nullptr, Reset, Next,
37 std::numeric_limits<float>::quiet_NaN(), 0, n_bins);
38 for (auto const& page_0 : m->template GetBatches<Page>(ctx, {})) {
39 for (auto const& page_1 : m_1->template GetBatches<Page>(ctx, {})) {
40 auto const& cuts_0 = get_cuts(page_0);
41 auto const& cuts_1 = get_cuts(page_1);
42 ASSERT_NE(cuts_0.Values(), cuts_1.Values());
43 ASSERT_NE(cuts_0.Ptrs(), cuts_1.Ptrs());
44 }
45 }
46
47 // Use DMatrix as ref
48 auto dm = RandomDataGenerator(2048, Iter::Cols(), 0.5).GenerateDMatrix(true);
49 auto dqm = std::make_shared<IterativeDMatrix>(&iter_1, iter_1.Proxy(), dm, Reset, Next,
50 std::numeric_limits<float>::quiet_NaN(), 0, n_bins);
51 for (auto const& page_0 : dm->template GetBatches<Page>(ctx, {})) {
52 for (auto const& page_1 : dqm->template GetBatches<Page>(ctx, {})) {
53 auto const& cuts_0 = get_cuts(page_0);
54 auto const& cuts_1 = get_cuts(page_1);
55 ASSERT_EQ(cuts_0.Values(), cuts_1.Values());
56 ASSERT_EQ(cuts_0.Ptrs(), cuts_1.Ptrs());
57 ASSERT_EQ(cuts_0.MinValues(), cuts_1.MinValues());
58 }
59 }
60}
61} // namespace data
62} // namespace xgboost
Copyright 2014-2023, XGBoost Contributors.
namespace of xgboost
Definition base.h:90