Medial Code Documentation
Loading...
Searching...
No Matches
expand_entry.h
1
4#ifndef XGBOOST_TREE_HIST_EXPAND_ENTRY_H_
5#define XGBOOST_TREE_HIST_EXPAND_ENTRY_H_
6
7#include <algorithm> // for all_of
8#include <ostream> // for ostream
9#include <utility> // for move
10#include <vector> // for vector
11
12#include "../param.h" // for SplitEntry, SplitEntryContainer, TrainParam
13#include "xgboost/base.h" // for GradientPairPrecise, bst_node_t
14
15namespace xgboost::tree {
19template <typename Impl>
21 bst_node_t nid{0};
22 bst_node_t depth{0};
23
24 [[nodiscard]] float GetLossChange() const {
25 return static_cast<Impl const*>(this)->split.loss_chg;
26 }
27 [[nodiscard]] bst_node_t GetNodeId() const { return nid; }
28
29 [[nodiscard]] bool IsValid(TrainParam const& param, bst_node_t num_leaves) const {
30 return static_cast<Impl const*>(this)->IsValidImpl(param, num_leaves);
31 }
32};
33
34struct CPUExpandEntry : public ExpandEntryImpl<CPUExpandEntry> {
35 SplitEntry split;
36
37 CPUExpandEntry() = default;
39 : ExpandEntryImpl{nidx, depth}, split(std::move(split)) {}
40 CPUExpandEntry(bst_node_t nidx, bst_node_t depth) : ExpandEntryImpl{nidx, depth} {}
41
42 [[nodiscard]] bool IsValidImpl(TrainParam const& param, bst_node_t num_leaves) const {
43 if (split.loss_chg <= kRtEps) return false;
44 if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
45 return false;
46 }
47 if (split.loss_chg < param.min_split_loss) {
48 return false;
49 }
50 if (param.max_depth > 0 && depth == param.max_depth) {
51 return false;
52 }
53 if (param.max_leaves > 0 && num_leaves == param.max_leaves) {
54 return false;
55 }
56 return true;
57 }
58
59 friend std::ostream& operator<<(std::ostream& os, CPUExpandEntry const& e) {
60 os << "ExpandEntry:\n";
61 os << "nidx: " << e.nid << "\n";
62 os << "depth: " << e.depth << "\n";
63 os << "loss: " << e.split.loss_chg << "\n";
64 os << "split:\n" << e.split << std::endl;
65 return os;
66 }
67
77 void CopyAndCollect(CPUExpandEntry const& that, std::vector<uint32_t>* collected_cat_bits,
78 std::vector<std::size_t>* cat_bits_sizes) {
79 nid = that.nid;
80 depth = that.depth;
81 split.CopyAndCollect(that.split, collected_cat_bits, cat_bits_sizes);
82 }
83};
84
85struct MultiExpandEntry : public ExpandEntryImpl<MultiExpandEntry> {
87
88 MultiExpandEntry() = default;
89 MultiExpandEntry(bst_node_t nidx, bst_node_t depth) : ExpandEntryImpl{nidx, depth} {}
90
91 [[nodiscard]] bool IsValidImpl(TrainParam const& param, bst_node_t num_leaves) const {
92 if (split.loss_chg <= kRtEps) return false;
93 auto is_zero = [](auto const& sum) {
94 return std::all_of(sum.cbegin(), sum.cend(),
95 [&](auto const& g) { return g.GetHess() - .0 == .0; });
96 };
97 if (is_zero(split.left_sum) || is_zero(split.right_sum)) {
98 return false;
99 }
100 if (split.loss_chg < param.min_split_loss) {
101 return false;
102 }
103 if (param.max_depth > 0 && depth == param.max_depth) {
104 return false;
105 }
106 if (param.max_leaves > 0 && num_leaves == param.max_leaves) {
107 return false;
108 }
109 return true;
110 }
111
112 friend std::ostream& operator<<(std::ostream& os, MultiExpandEntry const& e) {
113 os << "ExpandEntry: \n";
114 os << "nidx: " << e.nid << "\n";
115 os << "depth: " << e.depth << "\n";
116 os << "loss: " << e.split.loss_chg << "\n";
117 os << "split cond:" << e.split.split_value << "\n";
118 os << "split ind:" << e.split.SplitIndex() << "\n";
119 os << "left_sum: [";
120 for (auto v : e.split.left_sum) {
121 os << v << ", ";
122 }
123 os << "]\n";
124
125 os << "right_sum: [";
126 for (auto v : e.split.right_sum) {
127 os << v << ", ";
128 }
129 os << "]\n";
130 return os;
131 }
132
143 void CopyAndCollect(MultiExpandEntry const& that, std::vector<uint32_t>* collected_cat_bits,
144 std::vector<std::size_t>* cat_bits_sizes,
145 std::vector<GradientPairPrecise>* collected_gradients) {
146 nid = that.nid;
147 depth = that.depth;
148 split.CopyAndCollect(that.split, collected_cat_bits, cat_bits_sizes, collected_gradients);
149 }
150};
151} // namespace xgboost::tree
152#endif // XGBOOST_TREE_HIST_EXPAND_ENTRY_H_
Copyright 2015-2023 by XGBoost Contributors.
Copyright 2021-2023 by XGBoost Contributors.
Definition tree_updater.h:25
std::int32_t bst_node_t
Type for tree node index.
Definition base.h:112
constexpr bst_float kRtEps
Small epsilon gap used in the minimum-split decision.
Definition base.h:319
Definition expand_entry.h:34
void CopyAndCollect(CPUExpandEntry const &that, std::vector< uint32_t > *collected_cat_bits, std::vector< std::size_t > *cat_bits_sizes)
Copy primitive fields into `this` and collect cat_bits into a vector.
Definition expand_entry.h:77
Structure for storing a tree split candidate.
Definition expand_entry.h:20
Definition expand_entry.h:85
void CopyAndCollect(MultiExpandEntry const &that, std::vector< uint32_t > *collected_cat_bits, std::vector< std::size_t > *cat_bits_sizes, std::vector< GradientPairPrecise > *collected_gradients)
Copy primitive fields into `this` and collect cat_bits and gradients into vectors.
Definition expand_entry.h:143
Training parameters for regression trees.
Definition param.h:28