// NOTE(review): this chunk is a garbled extraction — original file line numbers
// ("41", "43", ...) are fused into the text and several lines are missing.
// Code is left byte-identical; comments annotate intent only.
// Feature sampler implementing colsample_by{tree,level,node}.
41 std::shared_ptr<common::ColumnSampler> column_sampler_;
// Presumably true when training data is partitioned column-wise across
// distributed workers — TODO confirm against the ctor initializer (not visible).
43 bool is_col_split_{
false};
// Per-node entries (gradient statistics and root gain), indexed by node id.
45 std::vector<NodeEntry> snode_;
// Checks whether the forward histogram scan left some rows unaccounted for,
// i.e. whether this feature has missing values whose direction must also be
// enumerated. Body is truncated in this extraction (return statements missing).
51 bool static SplitContainsMissingValues(
const GradStats e,
const NodeEntry &snode) {
// If the scanned sums equal the parent's sums, every row was covered by the
// feature's bins — no missing values for this feature.
52 if (e.GetGrad() == snode.stats.GetGrad() && e.GetHess() == snode.stats.GetHess()) {
// A candidate split is usable only when both children retain at least
// min_child_weight hessian mass. (Closing brace missing from this extraction.)
59 [[nodiscard]]
bool IsValid(GradStats
const &left, GradStats
const &right)
const {
60 return left.GetHess() >= param_->min_child_weight &&
61 right.GetHess() >= param_->min_child_weight;
// One-hot style enumeration for a categorical feature: each category bin is
// tried as the sole "right" partition. Interior of the function only — the
// signature and several lines are missing from this extraction.
73 const std::vector<uint32_t> &cut_ptr = cut.Ptrs();
74 const std::vector<bst_float> &cut_val = cut.Values();
// View over this feature's slice of the histogram.
86 auto f_hist = hist.
subspan(cut_ptr[fidx], n_bins);
87 auto feature_sum = GradStats{
// Statistics of rows with a missing value for this feature:
// parent total minus what the feature's bins account for.
90 auto const &parent = snode_[nidx];
91 missing.SetSubstract(parent.stats, feature_sum);
// Scan every bin (category) of the feature.
93 for (
bst_bin_t i = ibegin; i != iend; i += 1) {
94 auto split_pt = cut_val[i];
// Candidate 1: this category (plus, implicitly, missing) goes right.
97 right_sum = GradStats{hist[i]};
98 left_sum.SetSubstract(parent.stats, right_sum);
99 if (IsValid(left_sum, right_sum)) {
100 auto missing_left_chg =
101 static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
102 GradStats{right_sum}) -
104 best.
Update(missing_left_chg, fidx, split_pt,
true,
true, left_sum, right_sum);
// Candidate 2: lump the missing-value statistics into the right child instead.
108 right_sum.Add(missing);
109 left_sum.SetSubstract(parent.stats, right_sum);
110 if (IsValid(left_sum, right_sum)) {
111 auto missing_right_chg =
112 static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
113 GradStats{right_sum}) -
115 best.
Update(missing_right_chg, fidx, split_pt,
false,
true, left_sum, right_sum);
// Materialize the chosen category as a bit field on the best split.
120 auto n = common::CatBitField::ComputeStorageSize(n_bins + 1);
121 best.cat_bits.resize(n, 0);
123 cat_bits.Set(best.split_value);
// Partition-based categorical split enumeration: categories are pre-sorted by
// gradient-derived weight (see sorted_idx built by the caller) and scanned
// from one end, d_step == +1 forward or -1 backward. Several lines are
// missing from this extraction.
145 template <
int d_step>
150 static_assert(d_step == +1 || d_step == -1,
"Invalid step.");
152 auto const &cut_ptr = cut.Ptrs();
153 auto const &cut_val = cut.Values();
154 auto const &parent = snode_[nidx];
// Cap the number of scanned categories by max_cat_threshold.
158 bst_bin_t n_bins_feature{f_end - f_begin};
159 auto n_bins = std::min(param_->max_cat_threshold, n_bins_feature);
167 auto f_hist = hist.
subspan(f_begin, n_bins_feature);
// Scan bounds depend on direction (forward vs. backward).
171 it_end = it_begin + n_bins - 1;
173 it_begin = f_end - 1;
174 it_end = it_begin - n_bins + 1;
178 for (
bst_bin_t i = it_begin; i != it_end; i += d_step) {
179 auto j = i - f_begin;
// Accumulate the sorted category's statistics on the growing side and derive
// the other side from the parent total.
181 right_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
182 left_sum.SetSubstract(parent.stats, right_sum);
184 left_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
185 right_sum.SetSubstract(parent.stats, left_sum);
187 if (IsValid(left_sum, right_sum)) {
188 auto loss_chg = evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
189 GradStats{right_sum}) -
// NaN split value: a partition split is described by cat_bits, not a threshold.
192 if (best.
Update(loss_chg, fidx, std::numeric_limits<float>::quiet_NaN(), d_step == 1,
true,
193 left_sum, right_sum)) {
// Encode the winning prefix of sorted categories into the bit field.
199 if (best_thresh != -1) {
200 auto n = common::CatBitField::ComputeStorageSize(n_bins_feature);
201 best.cat_bits =
decltype(best.cat_bits)(n, 0);
203 bst_bin_t partition = d_step == 1 ? (best_thresh - it_begin + 1) : (best_thresh - f_begin);
204 CHECK_GT(partition, 0);
205 std::for_each(sorted_idx.begin(), sorted_idx.begin() + partition, [&](
size_t c) {
206 auto cat = cut_val[c + f_begin];
// Numerical split enumeration over the feature's histogram bins. d_step == +1
// scans forward (missing values default right), -1 scans backward (missing
// values default left). Several lines are missing from this extraction.
217 template <
int d_step>
222 static_assert(d_step == +1 || d_step == -1,
"Invalid step.");
225 const std::vector<uint32_t> &cut_ptr = cut.Ptrs();
226 const std::vector<bst_float> &cut_val = cut.Values();
227 auto const &parent = snode_[nidx];
// Guard against bin indices overflowing the signed bst_bin_t used below.
236 CHECK_LE(cut_ptr[fidx],
static_cast<uint32_t
>(std::numeric_limits<bst_bin_t>::max()));
237 CHECK_LE(cut_ptr[fidx + 1],
static_cast<uint32_t
>(std::numeric_limits<bst_bin_t>::max()));
240 const auto imin =
static_cast<bst_bin_t>(cut_ptr[fidx]);
// Direction-dependent scan bounds over [cut_ptr[fidx], cut_ptr[fidx+1]).
245 ibegin =
static_cast<bst_bin_t>(cut_ptr[fidx]);
246 iend =
static_cast<bst_bin_t>(cut_ptr.at(fidx + 1));
248 ibegin =
static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
249 iend =
static_cast<bst_bin_t>(cut_ptr[fidx]) - 1;
252 for (
bst_bin_t i = ibegin; i != iend; i += d_step) {
// Running prefix sum; right side derived from the parent total.
255 left_sum.Add(hist[i].GetGrad(), hist[i].GetHess());
256 right_sum.SetSubstract(parent.stats, left_sum);
257 if (IsValid(left_sum, right_sum)) {
// Forward scan: split at the bin's upper cut value.
263 static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
264 GradStats{right_sum}) -
266 split_pt = cut_val[i];
267 best.
Update(loss_chg, fidx, split_pt, d_step == -1,
false, left_sum, right_sum);
// Backward scan: children swap roles, and the split point is the previous
// cut (or the feature minimum at the first bin).
271 static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{right_sum},
272 GradStats{left_sum}) -
275 split_pt = cut.MinValues()[fidx];
277 split_pt = cut_val[i - 1];
279 best.
Update(loss_chg, fidx, split_pt, d_step == -1,
false, right_sum, left_sum);
// Gathers split candidates from all distributed workers (column-split mode).
// Each worker places its entries at its rank's offset; variable-length
// cat_bits are flattened and collected separately, then re-attached per entry.
// Several lines (world/rank setup, the allgather calls, the return) are
// missing from this extraction.
293 std::vector<CPUExpandEntry> Allgather(std::vector<CPUExpandEntry>
const &entries) {
296 auto const num_entries = entries.size();
299 std::vector<CPUExpandEntry> all_entries(num_entries * world);
300 std::vector<uint32_t> cat_bits;
301 std::vector<std::size_t> cat_bits_sizes;
302 for (std::size_t i = 0; i < num_entries; i++) {
303 all_entries[num_entries * rank + i].CopyAndCollect(entries[i], &cat_bits, &cat_bits_sizes);
// Restore each gathered entry's cat_bits from the flattened buffer.
310 common::ParallelFor(num_entries * world, ctx_->
Threads(), [&] (
auto i) {
312 all_entries[i].split.cat_bits.resize(gathered.sizes[i]);
313 std::copy_n(gathered.result.cbegin() + gathered.offsets[i], gathered.sizes[i],
314 all_entries[i].split.cat_bits.begin());
// Finds the best split for each candidate node: samples features per node,
// evaluates every permitted feature in parallel (thread-local bests), then
// reduces across threads and — in column-split mode — across workers.
// The signature's first lines and several interior lines are missing.
323 std::vector<CPUExpandEntry> *p_entries) {
324 auto n_threads = ctx_->
Threads();
325 auto& entries = *p_entries;
// One sampled feature set per candidate node, depth-dependent.
327 std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> features(
329 for (
size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
330 auto nidx = entries[nidx_in_set].nid;
331 features[nidx_in_set] =
332 column_sampler_->GetFeatureSet(tree.
GetDepth(nidx));
334 CHECK(!features.empty());
335 const size_t grain_size =
336 std::max<size_t>(1, features.front()->Size() / n_threads);
338 return features[nidx_in_set]->Size();
// Thread-local candidate copies to avoid synchronization during the scan.
341 std::vector<CPUExpandEntry> tloc_candidates(n_threads * entries.size());
342 for (
size_t i = 0; i < entries.size(); ++i) {
343 for (
decltype(n_threads) j = 0; j < n_threads; ++j) {
344 tloc_candidates[i * n_threads + j] = entries[i];
347 auto evaluator = tree_evaluator_.GetEvaluator();
348 auto const& cut_ptrs = cut.Ptrs();
350 common::ParallelFor2d(space, n_threads, [&](
size_t nidx_in_set,
common::Range1d r) {
351 auto tidx = omp_get_thread_num();
352 auto entry = &tloc_candidates[n_threads * nidx_in_set + tidx];
353 auto best = &entry->split;
354 auto nidx = entry->nid;
355 auto histogram = hist[nidx];
356 auto features_set = features[nidx_in_set]->ConstHostSpan();
357 for (
auto fidx_in_set = r.begin(); fidx_in_set < r.end(); fidx_in_set++) {
358 auto fidx = features_set[fidx_in_set];
359 bool is_cat = common::IsCat(feature_types, fidx);
// Skip features disallowed by interaction constraints for this node.
360 if (!interaction_constraints_.Query(nidx, fidx)) {
364 auto n_bins = cut_ptrs.at(fidx + 1) - cut_ptrs[fidx];
// Categorical: one-hot for few categories, otherwise partition-based
// enumeration over weight-sorted categories in both directions.
366 EnumerateOneHot(cut, histogram, fidx, nidx, evaluator, best);
368 std::vector<size_t> sorted_idx(n_bins);
369 std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
370 auto feat_hist = histogram.subspan(cut_ptrs[fidx], n_bins);
372 std::stable_sort(sorted_idx.begin(), sorted_idx.end(), [&](
size_t l,
size_t r) {
373 auto ret = evaluator.CalcWeightCat(*param_, feat_hist[l]) <
374 evaluator.CalcWeightCat(*param_, feat_hist[r]);
377 EnumeratePart<+1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
378 EnumeratePart<-1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
// Numerical: backward scan only needed when missing values exist.
381 auto grad_stats = EnumerateSplit<+1>(cut, histogram, fidx, nidx, evaluator, best);
382 if (SplitContainsMissingValues(grad_stats, snode_[nidx])) {
383 EnumerateSplit<-1>(cut, histogram, fidx, nidx, evaluator, best);
// Reduce the thread-local bests into each entry.
389 for (
unsigned nidx_in_set = 0; nidx_in_set < entries.size();
391 for (
auto tidx = 0; tidx < n_threads; ++tidx) {
392 entries[nidx_in_set].split.Update(
393 tloc_candidates[n_threads * nidx_in_set + tidx].split);
// Column-split mode: merge bests gathered from every worker.
400 auto all_entries = Allgather(entries);
402 for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
403 entries[nidx_in_set].split.Update(
404 all_entries[worker * entries.size() + nidx_in_set].split);
// Applies a chosen split to the tree: computes node weights, expands the node
// (categorical or numerical form), registers the split with the evaluator and
// interaction constraints, and records child statistics in snode_.
// Several lines (signature, ExpandNode/ExpandCategorical call heads) are
// missing from this extraction.
412 auto evaluator = tree_evaluator_.GetEvaluator();
// Parent statistics reconstructed from the two children.
415 GradStats parent_sum = candidate.split.left_sum;
416 parent_sum.Add(candidate.split.right_sum);
417 auto base_weight = evaluator.CalcWeight(candidate.nid, *param_, GradStats{parent_sum});
419 evaluator.CalcWeight(candidate.nid, *param_, GradStats{candidate.split.left_sum});
421 evaluator.CalcWeight(candidate.nid, *param_, GradStats{candidate.split.right_sum});
// Categorical splits carry the category bit field; child leaf values are
// scaled by the learning rate.
423 if (candidate.split.is_cat) {
425 candidate.nid, candidate.split.
SplitIndex(), candidate.split.cat_bits,
426 candidate.split.
DefaultLeft(), base_weight, left_weight * param_->learning_rate,
427 right_weight * param_->learning_rate, candidate.split.
loss_chg, parent_sum.GetHess(),
428 candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
432 left_weight * param_->learning_rate, right_weight * param_->learning_rate,
433 candidate.split.
loss_chg, parent_sum.GetHess(),
434 candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
438 auto left_child = tree[candidate.nid].LeftChild();
439 auto right_child = tree[candidate.nid].RightChild();
// Monotone-constraint bookkeeping; refresh the evaluator afterwards.
440 tree_evaluator_.AddSplit(candidate.nid, left_child, right_child,
441 tree[candidate.nid].SplitIndex(), left_weight,
443 evaluator = tree_evaluator_.GetEvaluator();
// Record children's statistics and gains for subsequent evaluations.
445 snode_.resize(tree.
GetNodes().size());
446 snode_.at(left_child).stats = candidate.split.left_sum;
447 snode_.at(left_child).root_gain =
448 evaluator.CalcGain(candidate.nid, *param_, GradStats{candidate.split.left_sum});
449 snode_.at(right_child).stats = candidate.split.right_sum;
450 snode_.at(right_child).root_gain =
451 evaluator.CalcGain(candidate.nid, *param_, GradStats{candidate.split.right_sum});
453 interaction_constraints_.Split(candidate.nid,
454 tree[candidate.nid].SplitIndex(), left_child,
458 [[nodiscard]]
auto Evaluator()
const {
return tree_evaluator_.GetEvaluator(); }
459 [[nodiscard]]
auto const &Stats()
const {
return snode_; }
// Initializes the root node's statistics and gain from the global gradient
// sum and returns a weight-derived value (the return statement is missing
// from this extraction — presumably the root leaf weight; confirm upstream).
461 float InitRoot(GradStats
const &root_sum) {
463 auto root_evaluator = tree_evaluator_.GetEvaluator();
465 snode_[0].stats = GradStats{root_sum.GetGrad(), root_sum.GetHess()};
466 snode_[0].root_gain =
467 root_evaluator.CalcGain(RegTree::kRoot, *param_, GradStats{snode_[0].stats});
468 auto weight = root_evaluator.CalcWeight(RegTree::kRoot, *param_, GradStats{snode_[0].stats});
// Constructor tail (head and most of the initializer list are missing from
// this extraction): configures interaction constraints from the training
// parameters and initializes the column sampler with the data's feature
// count, feature weights, and the colsample_by* rates.
476 std::shared_ptr<common::ColumnSampler> sampler)
479 column_sampler_{std::move(sampler)},
482 interaction_constraints_.Configure(*param, info.
num_col_);
483 column_sampler_->Init(ctx, info.
num_col_, info.feature_weights.HostVector(),
484 param_->colsample_bynode, param_->colsample_bylevel,
485 param_->colsample_bytree);
// Per-node gain cache for the multi-target evaluator, indexed by node id.
490 std::vector<double> gain_;
// Feature sampler, as in the single-target evaluator above.
494 std::shared_ptr<common::ColumnSampler> column_sampler_;
// Presumably true under distributed column-wise data split — TODO confirm.
496 bool is_col_split_{
false};
// Multi-target split gain: sum of the two children's gains given their
// computed weights. Parameter list is truncated in this extraction, and
// "¶m" below looks like a mis-encoded "&param" — fix the encoding.
499 static double MultiCalcSplitGain(
TrainParam const ¶m,
// Weights are written into the provided slices, then used for the gain terms.
504 CalcWeight(param, left_sum, left_weight);
505 CalcWeight(param, right_sum, right_weight);
507 auto left_gain = CalcGainGivenWeight(param, left_sum, left_weight);
508 auto right_gain = CalcGainGivenWeight(param, right_sum, right_weight);
509 return left_gain + right_gain;
// Multi-target numerical split enumeration: like the single-target version,
// but left/right sums and weights are vectors over targets (one histogram per
// target in `hist`). Returns whether any rows were left unaccounted for
// (missing values). Several lines are missing from this extraction.
512 template <bst_bin_t d_step>
517 auto const &cut_ptr = cut.Ptrs();
518 auto const &cut_val = cut.Values();
519 auto const &min_val = cut.MinValues();
// Row 0: left sums, row 1: right sums — one column per target.
521 auto sum = linalg::Empty<GradientPairPrecise>(ctx_, 2, hist.size());
// Direction-dependent scan bounds, as in the single-target scan.
527 ibegin =
static_cast<bst_bin_t>(cut_ptr[fidx]);
528 iend =
static_cast<bst_bin_t>(cut_ptr[fidx + 1]);
530 ibegin =
static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
531 iend =
static_cast<bst_bin_t>(cut_ptr[fidx]) - 1;
533 const auto imin =
static_cast<bst_bin_t>(cut_ptr[fidx]);
535 auto n_targets = hist.size();
536 auto weight = linalg::Empty<float>(ctx_, 2, n_targets);
538 auto right_weight = weight.Slice(1,
linalg::All());
540 for (
bst_bin_t i = ibegin; i != iend; i += d_step) {
// Accumulate this bin for every target; right side from parent totals.
542 auto t_hist = hist[t];
543 auto t_p = parent_sum(t);
544 left_sum(t) += t_hist[i];
545 right_sum(t) = t_p - left_sum(t);
// Forward scan: split at the bin's cut value.
549 auto split_pt = cut_val[i];
551 MultiCalcSplitGain(*param_, right_sum, left_sum, right_weight, left_weight) -
553 p_best->Update(loss_chg, fidx, split_pt, d_step == -1,
false, left_sum, right_sum);
// Backward scan: previous cut (or feature minimum) and swapped children.
557 split_pt = min_val[fidx];
559 split_pt = cut_val[i - 1];
562 MultiCalcSplitGain(*param_, right_sum, left_sum, left_weight, right_weight) -
564 p_best->Update(loss_chg, fidx, split_pt, d_step == -1,
false, right_sum, left_sum);
// Missing values present iff the accumulated left sums differ from the
// parent sums after a full forward scan.
569 return !std::equal(linalg::cbegin(left_sum), linalg::cend(left_sum),
570 linalg::cbegin(parent_sum));
// Multi-target variant of the worker allgather: besides the flattened
// cat_bits, per-target gradient sums (left then right halves per entry) are
// collected and redistributed. Several lines (world/rank setup, collective
// calls, return) are missing from this extraction.
580 std::vector<MultiExpandEntry> Allgather(std::vector<MultiExpandEntry>
const &entries) {
583 auto const num_entries = entries.size();
586 std::vector<MultiExpandEntry> all_entries(num_entries * world);
587 std::vector<uint32_t> cat_bits;
588 std::vector<std::size_t> cat_bits_sizes;
589 std::vector<GradientPairPrecise> gradients;
590 for (std::size_t i = 0; i < num_entries; i++) {
591 all_entries[num_entries * rank + i].CopyAndCollect(entries[i], &cat_bits, &cat_bits_sizes,
// Place this worker's flattened gradients at its rank offset before gathering.
600 auto const num_gradients = gradients.size();
601 std::vector<GradientPairPrecise> all_gradients(num_gradients * world);
602 std::copy_n(gradients.cbegin(), num_gradients, all_gradients.begin() + num_gradients * rank);
605 auto const total_entries = num_entries * world;
606 auto const gradients_per_entry = num_gradients / num_entries;
// Each entry's gradients split evenly into left and right halves.
607 auto const gradients_per_side = gradients_per_entry / 2;
608 common::ParallelFor(total_entries, ctx_->
Threads(), [&] (
auto i) {
610 all_entries[i].split.cat_bits.resize(gathered_cat_bits.sizes[i]);
611 std::copy_n(gathered_cat_bits.result.cbegin() + gathered_cat_bits.offsets[i],
612 gathered_cat_bits.sizes[i], all_entries[i].split.cat_bits.begin());
// Rebuild each gathered entry's left/right per-target sums.
615 all_entries[i].split.left_sum.resize(gradients_per_side);
616 std::copy_n(all_gradients.cbegin() + i * gradients_per_entry, gradients_per_side,
617 all_entries[i].split.left_sum.begin());
618 all_entries[i].split.right_sum.resize(gradients_per_side);
619 std::copy_n(all_gradients.cbegin() + i * gradients_per_entry + gradients_per_side,
620 gradients_per_side, all_entries[i].split.right_sum.begin());
// Multi-target best-split search: per-node feature sampling, parallel
// per-feature enumeration into thread-local candidates, thread reduction,
// then worker reduction under column split. Signature and several interior
// lines are missing from this extraction.
629 auto &entries = *p_entries;
630 std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> features(entries.size());
632 for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
633 auto nidx = entries[nidx_in_set].nid;
634 features[nidx_in_set] = column_sampler_->GetFeatureSet(tree.
GetDepth(nidx));
636 CHECK(!features.empty());
638 std::int32_t n_threads = ctx_->
Threads();
639 std::size_t
const grain_size = std::max<std::size_t>(1, features.front()->Size() / n_threads);
641 entries.size(), [&](std::size_t nidx_in_set) { return features[nidx_in_set]->Size(); },
// Thread-local copies avoid synchronization during the parallel scan.
644 std::vector<MultiExpandEntry> tloc_candidates(n_threads * entries.size());
645 for (std::size_t i = 0; i < entries.size(); ++i) {
646 for (std::int32_t j = 0; j < n_threads; ++j) {
647 tloc_candidates[i * n_threads + j] = entries[i];
650 common::ParallelFor2d(space, n_threads, [&](std::size_t nidx_in_set,
common::Range1d r) {
651 auto tidx = omp_get_thread_num();
652 auto entry = &tloc_candidates[n_threads * nidx_in_set + tidx];
653 auto best = &entry->split;
// One histogram row per target for this node.
655 std::vector<common::ConstGHistRow> node_hist;
656 for (
auto t_hist : hist) {
657 node_hist.emplace_back((*t_hist)[entry->nid]);
659 auto features_set = features[nidx_in_set]->ConstHostSpan();
661 for (
auto fidx_in_set = r.begin(); fidx_in_set < r.end(); fidx_in_set++) {
662 auto fidx = features_set[fidx_in_set];
// Skip features disallowed by interaction constraints.
663 if (!interaction_constraints_.Query(entry->nid, fidx)) {
666 auto parent_gain = gain_[entry->nid];
// Backward scan presumably guarded by the forward scan's missing-value
// result (guard line not visible here).
668 this->EnumerateSplit<+1>(cut, fidx, node_hist, parent_sum, parent_gain, best);
670 this->EnumerateSplit<-1>(cut, fidx, node_hist, parent_sum, parent_gain, best);
// Reduce thread-local bests into the entries.
675 for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
676 for (
auto tidx = 0; tidx < n_threads; ++tidx) {
677 entries[nidx_in_set].split.Update(tloc_candidates[n_threads * nidx_in_set + tidx].split);
// Column-split mode: merge bests gathered from every worker.
684 auto all_entries = Allgather(entries);
686 for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
687 entries[nidx_in_set].split.Update(
688 all_entries[worker * entries.size() + nidx_in_set].split);
// Multi-target root initialization fragment: computes per-target root weights
// and gain, stores the gain in gain_[0] and the root's gradient sums into the
// statistics matrix. Signature and surrounding lines are missing.
695 auto n_targets = root_sum.
Size();
700 CalcWeight(*param_, root_sum, weight.HostView());
701 auto root_gain = CalcGainGivenWeight(*param_, root_sum, weight.HostView());
702 gain_.front() = root_gain;
704 auto h_stats = stats_.HostView();
705 std::copy(linalg::cbegin(root_sum), linalg::cend(root_sum), linalg::begin(h_stats));
// Multi-target split application fragment: rows 0/1/2 of `weight` hold the
// base/left/right per-target weights; the node is expanded, per-child gains
// cached, and child statistics copied into the (grown-on-demand) stats_
// matrix. Signature and several lines are missing from this extraction.
714 auto weight = linalg::Empty<float>(ctx_, 3, n_targets);
716 CalcWeight(*param_, parent_sum, base_weight);
// Child weights are pre-scaled by the learning rate here.
720 linalg::MakeVec(candidate.split.left_sum.data(), candidate.split.left_sum.size());
721 CalcWeight(*param_, left_sum, param_->learning_rate, left_weight);
723 auto right_weight = weight.Slice(2,
linalg::All());
725 linalg::MakeVec(candidate.split.right_sum.data(), candidate.split.right_sum.size());
726 CalcWeight(*param_, right_sum, param_->learning_rate, right_weight);
728 p_tree->
ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value,
729 candidate.split.DefaultLeft(), base_weight, left_weight, right_weight);
// Children always get larger ids than their parent.
731 auto left_child = p_tree->LeftChild(candidate.nid);
732 CHECK_GT(left_child, candidate.nid);
733 auto right_child = p_tree->RightChild(candidate.nid);
734 CHECK_GT(right_child, candidate.nid);
736 std::size_t n_nodes = p_tree->Size();
737 gain_.resize(n_nodes);
738 gain_[left_child] = CalcGainGivenWeight(*param_, left_sum, left_weight);
739 gain_[right_child] = CalcGainGivenWeight(*param_, right_sum, right_weight);
// Geometric growth of the statistics matrix to amortize reshapes.
741 if (n_nodes >= stats_.Shape(0)) {
742 stats_.
Reshape(n_nodes * 2, stats_.Shape(1));
744 CHECK_EQ(stats_.Shape(1), n_targets);
746 std::copy(candidate.split.left_sum.cbegin(), candidate.split.left_sum.cend(),
747 linalg::begin(left_sum_stat));
749 std::copy(candidate.split.right_sum.cbegin(), candidate.split.right_sum.cend(),
750 linalg::begin(right_sum_stat));
// Multi-target evaluator constructor tail (head and most of the initializer
// list are missing): mirrors the single-target ctor — configures interaction
// constraints and initializes the column sampler with feature count, feature
// weights, and the colsample_by* rates.
754 std::shared_ptr<common::ColumnSampler> sampler)
756 column_sampler_{std::move(sampler)},
759 interaction_constraints_.Configure(*param, info.
num_col_);
760 column_sampler_->Init(ctx, info.
num_col_, info.feature_weights.HostVector(),
761 param_->colsample_bynode, param_->colsample_bylevel,
762 param_->colsample_bytree);