43 if (split.loss_chg <=
kRtEps)
return false;
44 if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
47 if (split.loss_chg < param.min_split_loss) {
50 if (param.max_depth > 0 && depth == param.max_depth) {
53 if (param.max_leaves > 0 && num_leaves == param.max_leaves) {
59 friend std::ostream& operator<<(std::ostream& os,
CPUExpandEntry const& e) {
60 os <<
"ExpandEntry:\n";
61 os <<
"nidx: " << e.nid <<
"\n";
62 os <<
"depth: " << e.depth <<
"\n";
63 os <<
"loss: " << e.split.loss_chg <<
"\n";
64 os <<
"split:\n" << e.split << std::endl;
78 std::vector<std::size_t>* cat_bits_sizes) {
81 split.CopyAndCollect(that.split, collected_cat_bits, cat_bits_sizes);
92 if (split.loss_chg <=
kRtEps)
return false;
93 auto is_zero = [](
auto const& sum) {
94 return std::all_of(sum.cbegin(), sum.cend(),
95 [&](
auto const& g) { return g.GetHess() - .0 == .0; });
97 if (is_zero(split.left_sum) || is_zero(split.right_sum)) {
100 if (split.loss_chg < param.min_split_loss) {
103 if (param.max_depth > 0 && depth == param.max_depth) {
106 if (param.max_leaves > 0 && num_leaves == param.max_leaves) {
112 friend std::ostream& operator<<(std::ostream& os,
MultiExpandEntry const& e) {
113 os <<
"ExpandEntry: \n";
114 os <<
"nidx: " << e.nid <<
"\n";
115 os <<
"depth: " << e.depth <<
"\n";
116 os <<
"loss: " << e.split.loss_chg <<
"\n";
117 os <<
"split cond:" << e.split.split_value <<
"\n";
118 os <<
"split ind:" << e.split.SplitIndex() <<
"\n";
120 for (
auto v : e.split.left_sum) {
125 os <<
"right_sum: [";
126 for (
auto v : e.split.right_sum) {
144 std::vector<std::size_t>* cat_bits_sizes,
145 std::vector<GradientPairPrecise>* collected_gradients) {
148 split.CopyAndCollect(that.split, collected_cat_bits, cat_bits_sizes, collected_gradients);
Definition expand_entry.h:34
void CopyAndCollect(CPUExpandEntry const &that, std::vector< uint32_t > *collected_cat_bits, std::vector< std::size_t > *cat_bits_sizes)
Copy primitive fields into this entry, and collect cat_bits into a vector.
Definition expand_entry.h:77
Structure for storing a tree split candidate.
Definition expand_entry.h:20
Definition expand_entry.h:85
void CopyAndCollect(MultiExpandEntry const &that, std::vector< uint32_t > *collected_cat_bits, std::vector< std::size_t > *cat_bits_sizes, std::vector< GradientPairPrecise > *collected_gradients)
Copy primitive fields into this entry, and collect cat_bits and gradients into vectors.
Definition expand_entry.h:143