Medial Code Documentation
Loading...
Searching...
No Matches
test_evaluate_splits.h
1
4#include <gtest/gtest.h>
5#include <xgboost/base.h> // for GradientPairInternal, GradientPairPrecise
6#include <xgboost/data.h> // for MetaInfo
7#include <xgboost/host_device_vector.h> // for HostDeviceVector
8#include <xgboost/span.h> // for operator!=, Span, SpanIterator
9
10#include <algorithm> // for max, max_element, next_permutation, copy
11#include <cmath> // for isnan
12#include <cstddef> // for size_t
13#include <cstdint> // for int32_t, uint64_t, uint32_t
14#include <limits> // for numeric_limits
15#include <numeric> // for iota
16#include <tuple> // for make_tuple, tie, tuple
17#include <utility> // for pair
18#include <vector> // for vector
19
20#include "../../../src/common/hist_util.h" // for HistogramCuts, HistCollection, GHistRow
21#include "../../../src/tree/hist/hist_cache.h" // for HistogramCollection
22#include "../../../src/tree/hist/param.h" // for HistMakerTrainParam
23#include "../../../src/tree/param.h" // for TrainParam, GradStats
24#include "../../../src/tree/split_evaluator.h" // for TreeEvaluator
25#include "../helpers.h" // for SimpleLCG, SimpleRealUniformDistribution
26#include "gtest/gtest_pred_impl.h" // for AssertionResult, ASSERT_EQ, ASSERT_TRUE
27
28namespace xgboost::tree {
// Test fixture that enumerates every ordering of the bins of a single categorical feature and
// records the best split gain found across all of them in `best_score_`.
// NOTE(review): SetUp() reads and writes `cuts_` and `hist_`, but their declarations are not
// visible in this listing (the numbering skips 39-40) -- confirm against the real header.
32class TestPartitionBasedSplit : public ::testing::Test {
33 protected:
// Number of histogram bins of the single feature. SetUp() walks all n_bins_! orderings, so the
// cost of the fixture grows factorially with this value.
34 size_t n_bins_ = 6;
// Current ordering of the bin indices; initialised to 0..n_bins_-1 and permuted in SetUp().
35 std::vector<size_t> sorted_idx_;
36 TrainParam param_;
37 MetaInfo info_;
// Best gain over every ordering and split point; starts at -inf so the first gain replaces it.
38 float best_score_{-std::numeric_limits<float>::infinity()};
// Sum of every gradient pair written into the node histogram in SetUp().
41 GradientPairPrecise total_gpair_;
42
// Builds the cuts and the node-0 histogram, then searches all bin orderings for the best gain.
43 void SetUp() override {
// Set both min_child_weight and reg_lambda to 0 for the gain computations below.
44 param_.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}});
// Start from the ascending ordering 0, 1, ..., n_bins_ - 1 so that std::next_permutation
// below visits every permutation before returning false.
45 sorted_idx_.resize(n_bins_);
46 std::iota(sorted_idx_.begin(), sorted_idx_.end(), 0);
47
48 info_.num_col_ = 1;
49
// One feature owning all bins: cut_ptrs_ = {0, n_bins_}, cut_values_ = 0, 1, ..., n_bins_ - 1.
50 cuts_.cut_ptrs_.Resize(2);
51 cuts_.SetCategorical(true, n_bins_);
52 auto &h_cuts = cuts_.cut_ptrs_.HostVector();
53 h_cuts[0] = 0;
54 h_cuts[1] = n_bins_;
55 auto &h_vals = cuts_.cut_values_.HostVector();
56 h_vals.resize(n_bins_);
57 std::iota(h_vals.begin(), h_vals.end(), 0.0);
58
59 cuts_.min_vals_.Resize(1);
60
// Size the histogram cache from the cuts and a default-constructed HistMakerTrainParam, then
// allocate the histogram of node 0 only.
61 HistMakerTrainParam hist_param;
62 hist_.Reset(cuts_.TotalBins(), hist_param.max_cached_hist_node);
63 hist_.AllocateHistograms({0});
64 auto node_hist = hist_[0];
65
// Fill every bin with generated values; the distributions are constructed with (-4.0, 4.0)
// for the gradient and (0.0, 4.0) for the hessian.
66 SimpleLCG lcg;
67 SimpleRealUniformDistribution<double> grad_dist{-4.0, 4.0};
68 SimpleRealUniformDistribution<double> hess_dist{0.0, 4.0};
69
70 for (auto &e : node_hist) {
71 e = GradientPairPrecise{grad_dist(&lcg), hess_dist(&lcg)};
72 total_gpair_ += e;
73 }
74
// For one fixed bin ordering, scan the prefix split points and return
// (index of best split point, best gain relative to the parent gain).
// Note: the parent gain is computed from `total_gpair_`, not from the `parent_sum` argument;
// the only call below passes `total_gpair_` as `parent_sum`, so both are the same value here.
// NOTE(review): the trailing -1 passed to TreeEvaluator is presumably a device ordinal
// meaning "CPU" -- confirm against split_evaluator.h.
75 auto enumerate = [this, n_feat = info_.num_col_](common::GHistRow hist,
76 GradientPairPrecise parent_sum) {
77 int32_t best_thresh = -1;
78 float best_score{-std::numeric_limits<float>::infinity()};
79 TreeEvaluator evaluator{param_, static_cast<bst_feature_t>(n_feat), -1};
80 auto tree_evaluator = evaluator.GetEvaluator<TrainParam>();
81 GradientPairPrecise left_sum;
82 auto parent_gain = tree_evaluator.CalcGain(0, param_, GradStats{total_gpair_});
// Stop before the last bin: including it would place every bin on the left side.
83 for (size_t i = 0; i < hist.size() - 1; ++i) {
84 left_sum += hist[i];
85 auto right_sum = parent_sum - left_sum;
86 auto gain =
87 tree_evaluator.CalcSplitGain(param_, 0, 0, GradStats{left_sum}, GradStats{right_sum}) -
88 parent_gain;
89 if (gain > best_score) {
90 best_score = gain;
91 best_thresh = i;
92 }
93 }
94 return std::make_tuple(best_thresh, best_score);
95 };
96
97 // enumerate all possible partitions to find the optimal split
98 do {
99 int32_t thresh;
100 float score;
// Reorder the node histogram according to the current permutation of bin indices.
101 std::vector<GradientPairPrecise> sorted_hist(node_hist.size());
102 for (size_t i = 0; i < sorted_hist.size(); ++i) {
103 sorted_hist[i] = node_hist[sorted_idx_[i]];
104 }
// Only the score is kept; `thresh` is not used after this point.
105 std::tie(thresh, score) = enumerate({sorted_hist}, total_gpair_);
106 if (score > best_score_) {
107 best_score_ = score;
108 }
109 } while (std::next_permutation(sorted_idx_.begin(), sorted_idx_.end()));
110 }
111};
112
// Builds a cuts object for tests from plain host vectors.
//
// values     - copied into cut_values_
// ptrs       - copied into cut_ptrs_
// min_values - copied into min_vals_
// device     - when >= 0, SetDevice(device) is called on all three vectors; a negative value
//              leaves them untouched
// Returns the populated `cuts` object.
// NOTE(review): the declaration of the local `cuts` is not visible in this listing (the
// numbering skips 115); presumably a common::HistogramCuts -- confirm against the real header.
113inline auto MakeCutsForTest(std::vector<float> values, std::vector<uint32_t> ptrs,
114 std::vector<float> min_values, int32_t device) {
116 cuts.cut_values_.HostVector() = values;
117 cuts.cut_ptrs_.HostVector() = ptrs;
118 cuts.min_vals_.HostVector() = min_values;
119
120 if (device >= 0) {
121 cuts.cut_ptrs_.SetDevice(device);
122 cuts.cut_values_.SetDevice(device);
123 cuts.min_vals_.SetDevice(device);
124 }
125
126 return cuts;
127}
128
// Test fixture providing a four-bin categorical feature histogram, a parent sum that differs
// from the histogram total, and a CheckResult() helper asserting the expected split.
// NOTE(review): `cuts_` is assigned in SetUp() but its declaration is not visible in this
// listing (the numbering skips 131) -- confirm against the real header.
129class TestCategoricalSplitWithMissing : public testing::Test {
130 protected:
132 // Setup gradients and parent sum with missing values.
// The four bins below sum to (grad 3.0, hess 3.0) while parent_sum_ is (1.0, 6.0); the
// difference of (-2.0, 3.0) is the part not covered by any bin.
133 GradientPairPrecise parent_sum_{1.0, 6.0};
134 std::vector<GradientPairPrecise> feature_histogram_{
135 {0.5, 0.5}, {0.5, 0.5}, {1.0, 1.0}, {1.0, 1.0}};
136 TrainParam param_;
137
// Builds cuts for one feature with values 0..3 (ptrs {0, 4}, min value 0.0, device -1), marks
// them categorical using the largest cut value as max_cat, and sets the training parameters.
138 void SetUp() override {
139 cuts_ = MakeCutsForTest({0.0, 1.0, 2.0, 3.0}, {0, 4}, {0.0}, -1);
140 auto max_cat = *std::max_element(cuts_.cut_values_.HostVector().begin(),
141 cuts_.cut_values_.HostVector().end());
142 cuts_.SetCategorical(true, max_cat);
143 param_.UpdateAllowUnknown(
144 Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}, {"max_cat_to_onehot", "1"}});
145 }
146
// Asserts that the split described by the arguments is the one this fixture expects:
// loss change ~2.97619, categorical, NaN split value, feature 0, dft_left false, and a left
// hessian of 2.5 with the right hessian being the remainder of parent_sum_.
147 void CheckResult(float loss_chg, bst_feature_t split_ind, float fvalue, bool is_cat,
148 bool dft_left, GradientPairPrecise left_sum, GradientPairPrecise right_sum) {
149 // forward
150 // it: 0, gain: 0.545455
151 // it: 1, gain: 1.000000
152 // it: 2, gain: 2.250000
153 // backward
154 // it: 3, gain: 1.000000
155 // it: 2, gain: 2.250000
156 // it: 1, gain: 3.142857
// NOTE(review): 2.97619 equals the largest gain listed above (3.142857) minus 1/6, i.e.
// 1.0^2 / 6.0 from parent_sum_ -- presumably the parent gain being subtracted; confirm.
157 ASSERT_NEAR(loss_chg, 2.97619, kRtEps);
158 ASSERT_TRUE(is_cat);
159 ASSERT_TRUE(std::isnan(fvalue));
160 ASSERT_EQ(split_ind, 0);
// NOTE(review): dft_left presumably means "missing values default to the left child" -- confirm.
161 ASSERT_FALSE(dft_left);
162 ASSERT_EQ(left_sum.GetHess(), 2.5);
// The left and right hessians must add up to the parent hessian.
163 ASSERT_EQ(right_sum.GetHess(), parent_sum_.GetHess() - left_sum.GetHess());
164 }
165};
166} // namespace xgboost::tree
Meta information about the dataset; always sits in memory.
Definition data.h:48
uint64_t num_col_
number of columns in the data
Definition data.h:56
Linear congruential generator.
Definition helpers.h:124
Definition hist_util.h:37
void SetCategorical(bool has_cat, float max_cat)
Set meta info about categorical features.
Definition hist_util.h:101
span class implementation, based on the ISO C++20 span<T>. The interface should be the same.
Definition span.h:424
A persistent cache for CPU histogram.
Definition hist_cache.h:30
void AllocateHistograms(common::Span< bst_node_t const > nodes_to_build, common::Span< bst_node_t const > nodes_to_sub)
Allocate histogram buffers for all nodes.
Definition hist_cache.h:83
Definition test_evaluate_splits.h:129
Enumerate all possible partitions for categorical split.
Definition test_evaluate_splits.h:32
Definition split_evaluator.h:28
A device-and-host vector abstraction layer.
Copyright 2015-2023 by XGBoost Contributors.
Copyright 2015-2023 by XGBoost Contributors.
Copyright 2021-2023 by XGBoost Contributors.
Definition tree_updater.h:25
uint32_t bst_feature_t
Type for data column (feature) index.
Definition base.h:101
constexpr bst_float kRtEps
small eps gap for minimum split decision.
Definition base.h:319
training parameters for regression tree
Definition param.h:28