4#include <gtest/gtest.h>
8#include <xgboost/span.h>
20#include "../../../src/common/hist_util.h"
21#include "../../../src/tree/hist/hist_cache.h"
22#include "../../../src/tree/hist/param.h"
23#include "../../../src/tree/param.h"
24#include "../../../src/tree/split_evaluator.h"
25#include "../helpers.h"
26#include "gtest/gtest_pred_impl.h"
// Ordering of histogram bins. SetUp() sizes it to n_bins_ and fills it with
// 0, 1, 2, ...; the search loop later walks through its permutations.
35 std::vector<size_t> sorted_idx_;
// Best score found so far. Starts at -infinity so that the first score
// compared with `score > best_score_` always replaces it.
38 float best_score_{-std::numeric_limits<float>::infinity()};
// Fixture setup. Sets the training parameters, builds the identity bin
// ordering, sizes the histogram cuts and resets the histogram cache.
// NOTE(review): this listing omits several lines of the original body (for
// example the values written through h_cuts and the loop body), so only the
// visible statements are described here.
43 void SetUp()
override {
// Both "min_child_weight" and "reg_lambda" are set to "0".
44 param_.UpdateAllowUnknown(Args{{
"min_child_weight",
"0"}, {
"reg_lambda",
"0"}});
// sorted_idx_ becomes 0, 1, ..., n_bins_ - 1.
45 sorted_idx_.resize(n_bins_);
46 std::iota(sorted_idx_.begin(), sorted_idx_.end(), 0);
// Two cut pointers; presumably the begin/end offsets of a single feature
// (the assignments are not visible here - TODO confirm).
50 cuts_.cut_ptrs_.Resize(2);
52 auto &h_cuts = cuts_.cut_ptrs_.HostVector();
// Cut values are 0.0, 1.0, ..., one per bin.
55 auto &h_vals = cuts_.cut_values_.HostVector();
56 h_vals.resize(n_bins_);
57 std::iota(h_vals.begin(), h_vals.end(), 0.0);
59 cuts_.min_vals_.Resize(1);
// Reset the histogram cache for the total number of bins, then take the
// histogram of node 0 and iterate over its entries.
62 hist_.Reset(cuts_.TotalBins(), hist_param.max_cached_hist_node);
64 auto node_hist = hist_[0];
70 for (
auto &e : node_hist) {
// Fragment of the split enumeration: scans candidate thresholds over `hist`
// and returns the best one together with its score.
// NOTE(review): the enclosing function/lambda header and several body lines
// are not visible in this listing.
// -1 / -infinity mean "no split found yet".
77 int32_t best_thresh = -1;
78 float best_score{-std::numeric_limits<float>::infinity()};
80 auto tree_evaluator = evaluator.GetEvaluator<
TrainParam>();
// Gain of the parent node, computed from total_gpair_.
82 auto parent_gain = tree_evaluator.CalcGain(0, param_, GradStats{total_gpair_});
// The last index is excluded from the scan.
// NOTE(review): if hist.size() is unsigned, `hist.size() - 1` wraps around
// for an empty histogram - confirm that callers never pass one.
83 for (
size_t i = 0; i < hist.size() - 1; ++i) {
// The right-hand statistics are the parent's minus the left-hand ones.
85 auto right_sum = parent_sum - left_sum;
87 tree_evaluator.CalcSplitGain(param_, 0, 0, GradStats{left_sum}, GradStats{right_sum}) -
89 if (gain > best_score) {
// Result is the pair (best threshold, best score).
94 return std::make_tuple(best_thresh, best_score);
// Body of the search over bin orderings: for the current permutation in
// sorted_idx_, gather the node histogram in that order, evaluate it with
// `enumerate`, and compare the resulting score with best_score_.
// NOTE(review): the loop opening and the body of the `if` are not visible
// in this listing.
101 std::vector<GradientPairPrecise> sorted_hist(node_hist.size());
// sorted_hist[i] is the histogram entry of the bin placed at position i.
102 for (
size_t i = 0; i < sorted_hist.size(); ++i) {
103 sorted_hist[i] = node_hist[sorted_idx_[i]];
105 std::tie(thresh, score) = enumerate({sorted_hist}, total_gpair_);
106 if (score > best_score_) {
109 }
// Continue until std::next_permutation reports that no further permutation
// of sorted_idx_ exists.
while (std::next_permutation(sorted_idx_.begin(), sorted_idx_.end()));
// Test helper that fills a `cuts` object from plain vectors.
//   values     - copied into cuts.cut_values_
//   ptrs       - copied into cuts.cut_ptrs_
//   min_values - copied into cuts.min_vals_
//   device     - passed to SetDevice() on each of the three vectors
// NOTE(review): the declaration of `cuts` and the return statement are not
// visible in this listing.
113inline auto MakeCutsForTest(std::vector<float> values, std::vector<uint32_t> ptrs,
114 std::vector<float> min_values, int32_t device) {
// Assign the host-side contents first.
116 cuts.cut_values_.HostVector() = values;
117 cuts.cut_ptrs_.HostVector() = ptrs;
118 cuts.min_vals_.HostVector() = min_values;
// Then set the requested device on every vector.
121 cuts.cut_ptrs_.SetDevice(device);
122 cuts.cut_values_.SetDevice(device);
123 cuts.min_vals_.SetDevice(device);
// Fixed histogram of four gradient pairs used by this fixture.
134 std::vector<GradientPairPrecise> feature_histogram_{
135 {0.5, 0.5}, {0.5, 0.5}, {1.0, 1.0}, {1.0, 1.0}};
// Fixture setup: builds cuts with the four values 0..3, cut pointers {0, 4}
// and minimum value 0.0 on device -1, then sets the training parameters.
// NOTE(review): some lines of the original body are not visible in this
// listing.
138 void SetUp()
override {
139 cuts_ = MakeCutsForTest({0.0, 1.0, 2.0, 3.0}, {0, 4}, {0.0}, -1);
// Largest cut value; presumably passed to SetCategorical() as the maximum
// category on a line not shown here - TODO confirm.
140 auto max_cat = *std::max_element(cuts_.cut_values_.HostVector().begin(),
141 cuts_.cut_values_.HostVector().end());
// "min_child_weight" and "reg_lambda" are set to "0", "max_cat_to_onehot"
// to "1".
143 param_.UpdateAllowUnknown(
144 Args{{
"min_child_weight",
"0"}, {
"reg_lambda",
"0"}, {
"max_cat_to_onehot",
"1"}});
// Asserts that an evaluated split matches the expected result for this
// fixture.
// NOTE(review): the rest of the parameter list (dft_left, left_sum and
// right_sum are used below) and some body lines are not visible in this
// listing.
147 void CheckResult(
float loss_chg,
bst_feature_t split_ind,
float fvalue,
bool is_cat,
// The loss change must be within kRtEps of 2.97619.
157 ASSERT_NEAR(loss_chg, 2.97619,
kRtEps);
// The split value must be NaN, on feature 0, with default direction not
// left.
159 ASSERT_TRUE(std::isnan(fvalue));
160 ASSERT_EQ(split_ind, 0);
161 ASSERT_FALSE(dft_left);
// The left child's hessian sum must be 2.5 and the right child's must be
// the parent's hessian minus the left one.
162 ASSERT_EQ(left_sum.GetHess(), 2.5);
163 ASSERT_EQ(right_sum.GetHess(), parent_sum_.GetHess() - left_sum.GetHess());
Linear congruential generator.
Definition helpers.h:124
Definition hist_util.h:37
void SetCategorical(bool has_cat, float max_cat)
Set meta info about categorical features.
Definition hist_util.h:101
Span class implementation, based on the ISO C++20 span<T>. The interface should be the same.
Definition span.h:424
A persistent cache for CPU histogram.
Definition hist_cache.h:30
void AllocateHistograms(common::Span< bst_node_t const > nodes_to_build, common::Span< bst_node_t const > nodes_to_sub)
Allocate histogram buffers for all nodes.
Definition hist_cache.h:83
Definition test_evaluate_splits.h:129
Enumerate all possible partitions for categorical split.
Definition test_evaluate_splits.h:32
Definition split_evaluator.h:28
A device-and-host vector abstraction layer.
Copyright 2015-2023 by XGBoost Contributors.
Copyright 2015-2023 by XGBoost Contributors.
Copyright 2021-2023 by XGBoost Contributors.
Definition tree_updater.h:25
uint32_t bst_feature_t
Type for data column (feature) index.
Definition base.h:101
constexpr bst_float kRtEps
small eps gap for minimum split decision.
Definition base.h:319
training parameters for regression tree
Definition param.h:28