evaluate_splits.h
/**
 * Copyright 2021-2023 by XGBoost Contributors
 */
#ifndef XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
#define XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_

#include <algorithm>  // for copy
#include <cstddef>    // for size_t
#include <limits>     // for numeric_limits
#include <memory>     // for shared_ptr
#include <numeric>    // for accumulate
#include <utility>    // for move
#include <vector>     // for vector

#include "../../common/categorical.h"  // for CatBitField
#include "../../common/hist_util.h"    // for GHistRow, HistogramCuts
#include "../../common/linalg_op.h"    // for cbegin, cend, begin
#include "../../common/random.h"       // for ColumnSampler
#include "../constraints.h"            // for FeatureInteractionConstraintHost
#include "../param.h"                  // for TrainParam
#include "../split_evaluator.h"        // for TreeEvaluator
#include "expand_entry.h"              // for MultiExpandEntry
#include "hist_cache.h"                // for BoundedHistCollection
#include "xgboost/base.h"              // for bst_node_t, bst_target_t, bst_feature_t
#include "xgboost/context.h"           // for Context
#include "xgboost/linalg.h"            // for Constants, Vector

namespace xgboost::tree {
class HistEvaluator {
 private:
  struct NodeEntry {
    // sum of gradient statistics for the node
    GradStats stats;
    // gain at the node if it is not split; used as the parent gain during evaluation
    bst_float root_gain{0.0f};
  };

 private:
  Context const *ctx_;
  TrainParam const *param_;
  std::shared_ptr<common::ColumnSampler> column_sampler_;
  TreeEvaluator tree_evaluator_;
  bool is_col_split_{false};
  FeatureInteractionConstraintHost interaction_constraints_;
  std::vector<NodeEntry> snode_;

  // If the sum of statistics for the non-missing values in the node equals the sum of
  // statistics for all values, there are no missing values; otherwise there are.
  static bool SplitContainsMissingValues(const GradStats e, const NodeEntry &snode) {
    return !(e.GetGrad() == snode.stats.GetGrad() && e.GetHess() == snode.stats.GetHess());
  }

  [[nodiscard]] bool IsValid(GradStats const &left, GradStats const &right) const {
    return left.GetHess() >= param_->min_child_weight &&
           right.GetHess() >= param_->min_child_weight;
  }

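  // Enumerate splits for a categorical feature using one-hot expansion: each category is
  // tried as the sole member of one branch, with the missing values placed on either
  // side in turn.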
  void EnumerateOneHot(common::HistogramCuts const &cut, common::ConstGHistRow hist,
                       bst_feature_t fidx, bst_node_t nidx,
                       TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
                       SplitEntry *p_best) const {
    const std::vector<uint32_t> &cut_ptr = cut.Ptrs();
    const std::vector<bst_float> &cut_val = cut.Values();

    bst_bin_t ibegin = static_cast<bst_bin_t>(cut_ptr[fidx]);
    bst_bin_t iend = static_cast<bst_bin_t>(cut_ptr[fidx + 1]);
    bst_bin_t n_bins = iend - ibegin;

    GradStats left_sum;
    GradStats right_sum;
    // best split so far
    SplitEntry best;
    best.is_cat = false;  // marker for whether it's updated or not
    auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
    auto feature_sum = GradStats{
        std::accumulate(f_hist.data(), f_hist.data() + f_hist.size(), GradientPairPrecise{})};
    GradStats missing;
    auto const &parent = snode_[nidx];
    missing.SetSubstract(parent.stats, feature_sum);

    for (bst_bin_t i = ibegin; i != iend; i += 1) {
      auto split_pt = cut_val[i];

      // missing on left (treat missing as other categories)
      right_sum = GradStats{hist[i]};
      left_sum.SetSubstract(parent.stats, right_sum);
      if (IsValid(left_sum, right_sum)) {
        auto missing_left_chg =
            static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
                                                       GradStats{right_sum}) -
                               parent.root_gain);
        best.Update(missing_left_chg, fidx, split_pt, true, true, left_sum, right_sum);
      }

      // missing on right (treat missing as chosen category)
      right_sum.Add(missing);
      left_sum.SetSubstract(parent.stats, right_sum);
      if (IsValid(left_sum, right_sum)) {
        auto missing_right_chg =
            static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
                                                       GradStats{right_sum}) -
                               parent.root_gain);
        best.Update(missing_right_chg, fidx, split_pt, false, true, left_sum, right_sum);
      }
    }

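    // If a one-hot candidate was accepted, record the chosen category in the bit field
    // used for partitioning rows.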
    if (best.is_cat) {
      auto n = common::CatBitField::ComputeStorageSize(n_bins + 1);
      best.cat_bits.resize(n, 0);
      common::CatBitField cat_bits{best.cat_bits};
      cat_bits.Set(best.split_value);
    }

    p_best->Update(best);
  }

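  // Enumerate splits for a categorical feature using partition-based splits. Categories
  // are scanned in the order given by sorted_idx (sorted by per-category leaf weight),
  // and the best scanned prefix becomes one side of the partition; d_step selects the
  // scan direction and thereby the side that receives the missing values.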
  template <int d_step>
  void EnumeratePart(common::HistogramCuts const &cut, common::Span<size_t const> sorted_idx,
                     common::ConstGHistRow hist, bst_feature_t fidx, bst_node_t nidx,
                     TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
                     SplitEntry *p_best) {
    static_assert(d_step == +1 || d_step == -1, "Invalid step.");

    auto const &cut_ptr = cut.Ptrs();
    auto const &cut_val = cut.Values();
    auto const &parent = snode_[nidx];

    bst_bin_t f_begin = cut_ptr[fidx];
    bst_bin_t f_end = cut_ptr[fidx + 1];
    bst_bin_t n_bins_feature{f_end - f_begin};
    auto n_bins = std::min(param_->max_cat_threshold, n_bins_feature);

    // statistics on both sides of the split
    GradStats left_sum;
    GradStats right_sum;
    // best split so far
    SplitEntry best;

    auto f_hist = hist.subspan(f_begin, n_bins_feature);
    bst_bin_t it_begin, it_end;
    if (d_step > 0) {
      it_begin = f_begin;
      it_end = it_begin + n_bins - 1;
    } else {
      it_begin = f_end - 1;
      it_end = it_begin - n_bins + 1;
    }

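    // Grow one side of the partition a category at a time; the other side is obtained
    // by subtracting from the parent sum, so the missing values implicitly follow it.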
    bst_bin_t best_thresh{-1};
    for (bst_bin_t i = it_begin; i != it_end; i += d_step) {
      auto j = i - f_begin;  // index local to the current feature
      if (d_step == 1) {
        right_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
        left_sum.SetSubstract(parent.stats, right_sum);  // missing on left
      } else {
        left_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
        right_sum.SetSubstract(parent.stats, left_sum);  // missing on right
      }
      if (IsValid(left_sum, right_sum)) {
        auto loss_chg = evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
                                                GradStats{right_sum}) -
                        parent.root_gain;
        // We don't have a numeric split point; NaN here is a dummy split.
        if (best.Update(loss_chg, fidx, std::numeric_limits<float>::quiet_NaN(), d_step == 1, true,
                        left_sum, right_sum)) {
          best_thresh = i;
        }
      }
    }

    if (best_thresh != -1) {
      auto n = common::CatBitField::ComputeStorageSize(n_bins_feature);
      best.cat_bits = decltype(best.cat_bits)(n, 0);
      common::CatBitField cat_bits{best.cat_bits};
      bst_bin_t partition = d_step == 1 ? (best_thresh - it_begin + 1) : (best_thresh - f_begin);
      CHECK_GT(partition, 0);
      std::for_each(sorted_idx.begin(), sorted_idx.begin() + partition, [&](size_t c) {
        auto cat = cut_val[c + f_begin];
        cat_bits.Set(cat);
      });
    }

    p_best->Update(best);
  }

  // Enumerate/scan the split values of a specific feature.
  // Returns the sum of gradients corresponding to the data points that contain
  // a non-missing value for the feature fidx.
  template <int d_step>
  GradStats EnumerateSplit(common::HistogramCuts const &cut, common::ConstGHistRow hist,
                           bst_feature_t fidx, bst_node_t nidx,
                           TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
                           SplitEntry *p_best) const {
    static_assert(d_step == +1 || d_step == -1, "Invalid step.");

    // aliases
    const std::vector<uint32_t> &cut_ptr = cut.Ptrs();
    const std::vector<bst_float> &cut_val = cut.Values();
    auto const &parent = snode_[nidx];

    // statistics on both sides of the split
    GradStats left_sum;
    GradStats right_sum;
    // best split so far
    SplitEntry best;

    // bin boundaries
    CHECK_LE(cut_ptr[fidx], static_cast<uint32_t>(std::numeric_limits<bst_bin_t>::max()));
    CHECK_LE(cut_ptr[fidx + 1], static_cast<uint32_t>(std::numeric_limits<bst_bin_t>::max()));
    // imin: index (offset) of the minimum value for feature fidx;
    // needed for backward enumeration
    const auto imin = static_cast<bst_bin_t>(cut_ptr[fidx]);
    // ibegin, iend: smallest/largest cut points for feature fidx;
    // a signed type is used to allow for the value -1
    bst_bin_t ibegin, iend;
    if (d_step > 0) {
      ibegin = static_cast<bst_bin_t>(cut_ptr[fidx]);
      iend = static_cast<bst_bin_t>(cut_ptr.at(fidx + 1));
    } else {
      ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
      iend = static_cast<bst_bin_t>(cut_ptr[fidx]) - 1;
    }

    for (bst_bin_t i = ibegin; i != iend; i += d_step) {
      // try to find a split at this bin boundary
      left_sum.Add(hist[i].GetGrad(), hist[i].GetHess());
      right_sum.SetSubstract(parent.stats, left_sum);
      if (IsValid(left_sum, right_sum)) {
        bst_float loss_chg;
        bst_float split_pt;
        if (d_step > 0) {
          // forward enumeration: split at the right bound of each bin
          loss_chg =
              static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
                                                         GradStats{right_sum}) -
                                 parent.root_gain);
          split_pt = cut_val[i];  // not used for partition based
          best.Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum);
        } else {
          // backward enumeration: split at the left bound of each bin
          loss_chg =
              static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{right_sum},
                                                         GradStats{left_sum}) -
                                 parent.root_gain);
          if (i == imin) {
            split_pt = cut.MinValues()[fidx];
          } else {
            split_pt = cut_val[i - 1];
          }
          best.Update(loss_chg, fidx, split_pt, d_step == -1, false, right_sum, left_sum);
        }
      }
    }

    p_best->Update(best);
    return left_sum;
  }

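  // Gather the best splits from all workers when the data is split column-wise, so that
  // every worker sees the global best candidate for each node.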
  std::vector<CPUExpandEntry> Allgather(std::vector<CPUExpandEntry> const &entries) {
    auto const world = collective::GetWorldSize();
    auto const rank = collective::GetRank();
    auto const num_entries = entries.size();

    // First, gather all the primitive fields.
    std::vector<CPUExpandEntry> all_entries(num_entries * world);
    std::vector<uint32_t> cat_bits;
    std::vector<std::size_t> cat_bits_sizes;
    for (std::size_t i = 0; i < num_entries; i++) {
      all_entries[num_entries * rank + i].CopyAndCollect(entries[i], &cat_bits, &cat_bits_sizes);
    }
    collective::Allgather(all_entries.data(), all_entries.size() * sizeof(CPUExpandEntry));

    // Gather all the cat_bits.
    auto gathered = collective::AllgatherV(cat_bits, cat_bits_sizes);

    common::ParallelFor(num_entries * world, ctx_->Threads(), [&](auto i) {
      // Copy the cat_bits back into all expand entries.
      all_entries[i].split.cat_bits.resize(gathered.sizes[i]);
      std::copy_n(gathered.result.cbegin() + gathered.offsets[i], gathered.sizes[i],
                  all_entries[i].split.cat_bits.begin());
    });

    return all_entries;
  }

 public:
  void EvaluateSplits(const BoundedHistCollection &hist, common::HistogramCuts const &cut,
                      common::Span<FeatureType const> feature_types, const RegTree &tree,
                      std::vector<CPUExpandEntry> *p_entries) {
    auto n_threads = ctx_->Threads();
    auto &entries = *p_entries;
    // All nodes are on the same level, so we can store the shared ptr.
    std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> features(entries.size());
    for (size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
      auto nidx = entries[nidx_in_set].nid;
      features[nidx_in_set] = column_sampler_->GetFeatureSet(tree.GetDepth(nidx));
    }
    CHECK(!features.empty());
    const size_t grain_size = std::max<size_t>(1, features.front()->Size() / n_threads);
    common::BlockedSpace2d space(
        entries.size(), [&](size_t nidx_in_set) { return features[nidx_in_set]->Size(); },
        grain_size);

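    // Each thread scans a disjoint block of (node, feature) pairs and keeps its own
    // copy of the best split, which is reduced across threads afterwards.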
    std::vector<CPUExpandEntry> tloc_candidates(n_threads * entries.size());
    for (size_t i = 0; i < entries.size(); ++i) {
      for (decltype(n_threads) j = 0; j < n_threads; ++j) {
        tloc_candidates[i * n_threads + j] = entries[i];
      }
    }
    auto evaluator = tree_evaluator_.GetEvaluator();
    auto const &cut_ptrs = cut.Ptrs();

    common::ParallelFor2d(space, n_threads, [&](size_t nidx_in_set, common::Range1d r) {
      auto tidx = omp_get_thread_num();
      auto entry = &tloc_candidates[n_threads * nidx_in_set + tidx];
      auto best = &entry->split;
      auto nidx = entry->nid;
      auto histogram = hist[nidx];
      auto features_set = features[nidx_in_set]->ConstHostSpan();
      for (auto fidx_in_set = r.begin(); fidx_in_set < r.end(); fidx_in_set++) {
        auto fidx = features_set[fidx_in_set];
        bool is_cat = common::IsCat(feature_types, fidx);
        if (!interaction_constraints_.Query(nidx, fidx)) {
          continue;
        }
        if (is_cat) {
          auto n_bins = cut_ptrs.at(fidx + 1) - cut_ptrs[fidx];
          if (common::UseOneHot(n_bins, param_->max_cat_to_onehot)) {
            EnumerateOneHot(cut, histogram, fidx, nidx, evaluator, best);
          } else {
            std::vector<size_t> sorted_idx(n_bins);
            std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
            auto feat_hist = histogram.subspan(cut_ptrs[fidx], n_bins);
            // Sort the histogram to get contiguous partitions.
            std::stable_sort(sorted_idx.begin(), sorted_idx.end(), [&](size_t l, size_t r) {
              auto ret = evaluator.CalcWeightCat(*param_, feat_hist[l]) <
                         evaluator.CalcWeightCat(*param_, feat_hist[r]);
              return ret;
            });
            EnumeratePart<+1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
            EnumeratePart<-1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
          }
        } else {
          auto grad_stats = EnumerateSplit<+1>(cut, histogram, fidx, nidx, evaluator, best);
          if (SplitContainsMissingValues(grad_stats, snode_[nidx])) {
            EnumerateSplit<-1>(cut, histogram, fidx, nidx, evaluator, best);
          }
        }
      }
    });

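    // Reduce the per-thread candidates into the final best split for each node.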
    for (unsigned nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
      for (auto tidx = 0; tidx < n_threads; ++tidx) {
        entries[nidx_in_set].split.Update(tloc_candidates[n_threads * nidx_in_set + tidx].split);
      }
    }

    if (is_col_split_) {
      // With column-wise data split, we gather the best splits from all the workers and
      // update the expand entries accordingly.
      auto all_entries = Allgather(entries);
      for (auto worker = 0; worker < collective::GetWorldSize(); ++worker) {
        for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
          entries[nidx_in_set].split.Update(
              all_entries[worker * entries.size() + nidx_in_set].split);
        }
      }
    }
  }

  // Add the split to the tree and update all the statistics.
  void ApplyTreeSplit(CPUExpandEntry const &candidate, RegTree *p_tree) {
    auto evaluator = tree_evaluator_.GetEvaluator();
    RegTree &tree = *p_tree;

    GradStats parent_sum = candidate.split.left_sum;
    parent_sum.Add(candidate.split.right_sum);
    auto base_weight = evaluator.CalcWeight(candidate.nid, *param_, GradStats{parent_sum});
    auto left_weight =
        evaluator.CalcWeight(candidate.nid, *param_, GradStats{candidate.split.left_sum});
    auto right_weight =
        evaluator.CalcWeight(candidate.nid, *param_, GradStats{candidate.split.right_sum});

    if (candidate.split.is_cat) {
      tree.ExpandCategorical(
          candidate.nid, candidate.split.SplitIndex(), candidate.split.cat_bits,
          candidate.split.DefaultLeft(), base_weight, left_weight * param_->learning_rate,
          right_weight * param_->learning_rate, candidate.split.loss_chg, parent_sum.GetHess(),
          candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
    } else {
      tree.ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value,
                      candidate.split.DefaultLeft(), base_weight,
                      left_weight * param_->learning_rate, right_weight * param_->learning_rate,
                      candidate.split.loss_chg, parent_sum.GetHess(),
                      candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
    }

    // Set up child constraints.
    auto left_child = tree[candidate.nid].LeftChild();
    auto right_child = tree[candidate.nid].RightChild();
    tree_evaluator_.AddSplit(candidate.nid, left_child, right_child,
                             tree[candidate.nid].SplitIndex(), left_weight, right_weight);
    evaluator = tree_evaluator_.GetEvaluator();

    snode_.resize(tree.GetNodes().size());
    snode_.at(left_child).stats = candidate.split.left_sum;
    snode_.at(left_child).root_gain =
        evaluator.CalcGain(candidate.nid, *param_, GradStats{candidate.split.left_sum});
    snode_.at(right_child).stats = candidate.split.right_sum;
    snode_.at(right_child).root_gain =
        evaluator.CalcGain(candidate.nid, *param_, GradStats{candidate.split.right_sum});

    interaction_constraints_.Split(candidate.nid, tree[candidate.nid].SplitIndex(), left_child,
                                   right_child);
  }

  [[nodiscard]] auto Evaluator() const { return tree_evaluator_.GetEvaluator(); }
  [[nodiscard]] auto const &Stats() const { return snode_; }

  float InitRoot(GradStats const &root_sum) {
    snode_.resize(1);
    auto root_evaluator = tree_evaluator_.GetEvaluator();

    snode_[0].stats = GradStats{root_sum.GetGrad(), root_sum.GetHess()};
    snode_[0].root_gain =
        root_evaluator.CalcGain(RegTree::kRoot, *param_, GradStats{snode_[0].stats});
    auto weight = root_evaluator.CalcWeight(RegTree::kRoot, *param_, GradStats{snode_[0].stats});
    return weight;
  }

 public:
  // The column sampler must be constructed by the caller since we need to preserve the
  // rng for the entire training session.
  explicit HistEvaluator(Context const *ctx, TrainParam const *param, MetaInfo const &info,
                         std::shared_ptr<common::ColumnSampler> sampler)
      : ctx_{ctx},
        param_{param},
        column_sampler_{std::move(sampler)},
        tree_evaluator_{*param, static_cast<bst_feature_t>(info.num_col_), Context::kCpuId},
        is_col_split_{info.IsColumnSplit()} {
    interaction_constraints_.Configure(*param, info.num_col_);
    column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
                          param_->colsample_bynode, param_->colsample_bylevel,
                          param_->colsample_bytree);
  }
};
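
// A minimal usage sketch (hypothetical driver, not part of this header): a hist updater
// owns one HistEvaluator per tree, seeds it with the root statistics, then alternates
// split evaluation and split application while growing the tree. Here `ctx`, `param`,
// `info`, `sampler`, `root_sum`, `hist`, `cut`, `ftypes` and `tree` stand for state the
// updater builds elsewhere.
//
//   HistEvaluator evaluator{ctx, &param, info, sampler};
//   evaluator.InitRoot(root_sum);
//   std::vector<CPUExpandEntry> candidates = /* one entry per expandable node */;
//   evaluator.EvaluateSplits(hist, cut, ftypes, tree, &candidates);
//   for (auto const &c : candidates) {
//     evaluator.ApplyTreeSplit(c, &tree);
//   }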

class HistMultiEvaluator {
  std::vector<double> gain_;
  linalg::Matrix<GradientPairPrecise> stats_;
  TrainParam const *param_;
  FeatureInteractionConstraintHost interaction_constraints_;
  std::shared_ptr<common::ColumnSampler> column_sampler_;
  Context const *ctx_;
  bool is_col_split_{false};

 private:
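  // Evaluate a candidate split by summing the gain of both children; the child weights
  // are written to the provided views as a side effect.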
  static double MultiCalcSplitGain(TrainParam const &param,
                                   linalg::VectorView<GradientPairPrecise const> left_sum,
                                   linalg::VectorView<GradientPairPrecise const> right_sum,
                                   linalg::VectorView<float> left_weight,
                                   linalg::VectorView<float> right_weight) {
    CalcWeight(param, left_sum, left_weight);
    CalcWeight(param, right_sum, right_weight);

    auto left_gain = CalcGainGivenWeight(param, left_sum, left_weight);
    auto right_gain = CalcGainGivenWeight(param, right_sum, right_weight);
    return left_gain + right_gain;
  }

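  // Enumerate split candidates for one feature across all targets. The forward pass
  // returns true when the feature has missing values in this node, in which case the
  // caller runs the backward pass to try placing them on the other side.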
  template <bst_bin_t d_step>
  bool EnumerateSplit(common::HistogramCuts const &cut, bst_feature_t fidx,
                      common::Span<common::ConstGHistRow const> hist,
                      linalg::VectorView<GradientPairPrecise const> parent_sum, double parent_gain,
                      SplitEntryContainer<std::vector<GradientPairPrecise>> *p_best) const {
    auto const &cut_ptr = cut.Ptrs();
    auto const &cut_val = cut.Values();
    auto const &min_val = cut.MinValues();

    auto sum = linalg::Empty<GradientPairPrecise>(ctx_, 2, hist.size());
    auto left_sum = sum.Slice(0, linalg::All());
    auto right_sum = sum.Slice(1, linalg::All());

    bst_bin_t ibegin, iend;
    if (d_step > 0) {
      ibegin = static_cast<bst_bin_t>(cut_ptr[fidx]);
      iend = static_cast<bst_bin_t>(cut_ptr[fidx + 1]);
    } else {
      ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
      iend = static_cast<bst_bin_t>(cut_ptr[fidx]) - 1;
    }
    const auto imin = static_cast<bst_bin_t>(cut_ptr[fidx]);

    auto n_targets = hist.size();
    auto weight = linalg::Empty<float>(ctx_, 2, n_targets);
    auto left_weight = weight.Slice(0, linalg::All());
    auto right_weight = weight.Slice(1, linalg::All());

    for (bst_bin_t i = ibegin; i != iend; i += d_step) {
      for (bst_target_t t = 0; t < n_targets; ++t) {
        auto t_hist = hist[t];
        auto t_p = parent_sum(t);
        left_sum(t) += t_hist[i];
        right_sum(t) = t_p - left_sum(t);
      }

      if (d_step > 0) {
        auto split_pt = cut_val[i];
        auto loss_chg =
            MultiCalcSplitGain(*param_, right_sum, left_sum, right_weight, left_weight) -
            parent_gain;
        p_best->Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum);
      } else {
        float split_pt;
        if (i == imin) {
          split_pt = min_val[fidx];
        } else {
          split_pt = cut_val[i - 1];
        }
        auto loss_chg =
            MultiCalcSplitGain(*param_, right_sum, left_sum, left_weight, right_weight) -
            parent_gain;
        p_best->Update(loss_chg, fidx, split_pt, d_step == -1, false, right_sum, left_sum);
      }
    }
    // Return true if there are missing values. Doesn't handle floating-point error well.
    if (d_step == +1) {
      return !std::equal(linalg::cbegin(left_sum), linalg::cend(left_sum),
                         linalg::cbegin(parent_sum));
    }
    return false;
  }

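  // Column-split counterpart of HistEvaluator::Allgather: besides the category bits,
  // the per-target gradient sums of each candidate also have to be collected.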
  std::vector<MultiExpandEntry> Allgather(std::vector<MultiExpandEntry> const &entries) {
    auto const world = collective::GetWorldSize();
    auto const rank = collective::GetRank();
    auto const num_entries = entries.size();

    // First, gather all the primitive fields.
    std::vector<MultiExpandEntry> all_entries(num_entries * world);
    std::vector<uint32_t> cat_bits;
    std::vector<std::size_t> cat_bits_sizes;
    std::vector<GradientPairPrecise> gradients;
    for (std::size_t i = 0; i < num_entries; i++) {
      all_entries[num_entries * rank + i].CopyAndCollect(entries[i], &cat_bits, &cat_bits_sizes,
                                                         &gradients);
    }
    collective::Allgather(all_entries.data(), all_entries.size() * sizeof(MultiExpandEntry));

    // Gather all the cat_bits.
    auto gathered_cat_bits = collective::AllgatherV(cat_bits, cat_bits_sizes);

    // Gather all the gradients.
    auto const num_gradients = gradients.size();
    std::vector<GradientPairPrecise> all_gradients(num_gradients * world);
    std::copy_n(gradients.cbegin(), num_gradients, all_gradients.begin() + num_gradients * rank);
    collective::Allgather(all_gradients.data(), all_gradients.size() * sizeof(GradientPairPrecise));

    auto const total_entries = num_entries * world;
    auto const gradients_per_entry = num_gradients / num_entries;
    auto const gradients_per_side = gradients_per_entry / 2;
    common::ParallelFor(total_entries, ctx_->Threads(), [&](auto i) {
      // Copy the cat_bits back into all expand entries.
      all_entries[i].split.cat_bits.resize(gathered_cat_bits.sizes[i]);
      std::copy_n(gathered_cat_bits.result.cbegin() + gathered_cat_bits.offsets[i],
                  gathered_cat_bits.sizes[i], all_entries[i].split.cat_bits.begin());

      // Copy the gradients back into all expand entries.
      all_entries[i].split.left_sum.resize(gradients_per_side);
      std::copy_n(all_gradients.cbegin() + i * gradients_per_entry, gradients_per_side,
                  all_entries[i].split.left_sum.begin());
      all_entries[i].split.right_sum.resize(gradients_per_side);
      std::copy_n(all_gradients.cbegin() + i * gradients_per_entry + gradients_per_side,
                  gradients_per_side, all_entries[i].split.right_sum.begin());
    });

    return all_entries;
  }

 public:
  void EvaluateSplits(RegTree const &tree, common::Span<const BoundedHistCollection *> hist,
                      common::HistogramCuts const &cut, std::vector<MultiExpandEntry> *p_entries) {
    auto &entries = *p_entries;
    std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> features(entries.size());

    for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
      auto nidx = entries[nidx_in_set].nid;
      features[nidx_in_set] = column_sampler_->GetFeatureSet(tree.GetDepth(nidx));
    }
    CHECK(!features.empty());

    std::int32_t n_threads = ctx_->Threads();
    std::size_t const grain_size = std::max<std::size_t>(1, features.front()->Size() / n_threads);
    common::BlockedSpace2d space(
        entries.size(), [&](std::size_t nidx_in_set) { return features[nidx_in_set]->Size(); },
        grain_size);

    std::vector<MultiExpandEntry> tloc_candidates(n_threads * entries.size());
    for (std::size_t i = 0; i < entries.size(); ++i) {
      for (std::int32_t j = 0; j < n_threads; ++j) {
        tloc_candidates[i * n_threads + j] = entries[i];
      }
    }
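    // As in the single-target evaluator, each thread keeps a private candidate per node
    // and scans its assigned block of features, collecting one histogram row per target.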
    common::ParallelFor2d(space, n_threads, [&](std::size_t nidx_in_set, common::Range1d r) {
      auto tidx = omp_get_thread_num();
      auto entry = &tloc_candidates[n_threads * nidx_in_set + tidx];
      auto best = &entry->split;
      auto parent_sum = stats_.Slice(entry->nid, linalg::All());
      std::vector<common::ConstGHistRow> node_hist;
      for (auto t_hist : hist) {
        node_hist.emplace_back((*t_hist)[entry->nid]);
      }
      auto features_set = features[nidx_in_set]->ConstHostSpan();

      for (auto fidx_in_set = r.begin(); fidx_in_set < r.end(); fidx_in_set++) {
        auto fidx = features_set[fidx_in_set];
        if (!interaction_constraints_.Query(entry->nid, fidx)) {
          continue;
        }
        auto parent_gain = gain_[entry->nid];
        bool missing =
            this->EnumerateSplit<+1>(cut, fidx, node_hist, parent_sum, parent_gain, best);
        if (missing) {
          this->EnumerateSplit<-1>(cut, fidx, node_hist, parent_sum, parent_gain, best);
        }
      }
    });

    for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
      for (auto tidx = 0; tidx < n_threads; ++tidx) {
        entries[nidx_in_set].split.Update(tloc_candidates[n_threads * nidx_in_set + tidx].split);
      }
    }

    if (is_col_split_) {
      // With column-wise data split, we gather the best splits from all the workers and
      // update the expand entries accordingly.
      auto all_entries = Allgather(entries);
      for (auto worker = 0; worker < collective::GetWorldSize(); ++worker) {
        for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
          entries[nidx_in_set].split.Update(
              all_entries[worker * entries.size() + nidx_in_set].split);
        }
      }
    }
  }

  linalg::Vector<float> InitRoot(linalg::VectorView<GradientPairPrecise const> root_sum) {
    auto n_targets = root_sum.Size();
    stats_ = linalg::Constant(ctx_, GradientPairPrecise{}, 1, n_targets);
    gain_.resize(1);

    linalg::Vector<float> weight({n_targets}, ctx_->gpu_id);
    CalcWeight(*param_, root_sum, weight.HostView());
    auto root_gain = CalcGainGivenWeight(*param_, root_sum, weight.HostView());
    gain_.front() = root_gain;

    auto h_stats = stats_.HostView();
    std::copy(linalg::cbegin(root_sum), linalg::cend(root_sum), linalg::begin(h_stats));

    return weight;
  }

  void ApplyTreeSplit(MultiExpandEntry const &candidate, RegTree *p_tree) {
    auto n_targets = p_tree->NumTargets();
    auto parent_sum = stats_.Slice(candidate.nid, linalg::All());

    auto weight = linalg::Empty<float>(ctx_, 3, n_targets);
    auto base_weight = weight.Slice(0, linalg::All());
    CalcWeight(*param_, parent_sum, base_weight);

    auto left_weight = weight.Slice(1, linalg::All());
    auto left_sum =
        linalg::MakeVec(candidate.split.left_sum.data(), candidate.split.left_sum.size());
    CalcWeight(*param_, left_sum, param_->learning_rate, left_weight);

    auto right_weight = weight.Slice(2, linalg::All());
    auto right_sum =
        linalg::MakeVec(candidate.split.right_sum.data(), candidate.split.right_sum.size());
    CalcWeight(*param_, right_sum, param_->learning_rate, right_weight);

    p_tree->ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value,
                       candidate.split.DefaultLeft(), base_weight, left_weight, right_weight);
    CHECK(p_tree->IsMultiTarget());
    auto left_child = p_tree->LeftChild(candidate.nid);
    CHECK_GT(left_child, candidate.nid);
    auto right_child = p_tree->RightChild(candidate.nid);
    CHECK_GT(right_child, candidate.nid);

    std::size_t n_nodes = p_tree->Size();
    gain_.resize(n_nodes);
    gain_[left_child] = CalcGainGivenWeight(*param_, left_sum, left_weight);
    gain_[right_child] = CalcGainGivenWeight(*param_, right_sum, right_weight);

    if (n_nodes >= stats_.Shape(0)) {
      stats_.Reshape(n_nodes * 2, stats_.Shape(1));
    }
    CHECK_EQ(stats_.Shape(1), n_targets);
    auto left_sum_stat = stats_.Slice(left_child, linalg::All());
    std::copy(candidate.split.left_sum.cbegin(), candidate.split.left_sum.cend(),
              linalg::begin(left_sum_stat));
    auto right_sum_stat = stats_.Slice(right_child, linalg::All());
    std::copy(candidate.split.right_sum.cbegin(), candidate.split.right_sum.cend(),
              linalg::begin(right_sum_stat));
  }

  explicit HistMultiEvaluator(Context const *ctx, MetaInfo const &info, TrainParam const *param,
                              std::shared_ptr<common::ColumnSampler> sampler)
      : param_{param},
        column_sampler_{std::move(sampler)},
        ctx_{ctx},
        is_col_split_{info.IsColumnSplit()} {
    interaction_constraints_.Configure(*param, info.num_col_);
    column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
                          param_->colsample_bynode, param_->colsample_bylevel,
                          param_->colsample_bytree);
  }
};

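/**
 * \brief CPU implementation of update prediction cache, which calculates the leaf value
 *        for the last tree and accumulates it to the prediction vector.
 */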
template <typename Partitioner>
void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree,
                               std::vector<Partitioner> const &partitioner,
                               linalg::VectorView<float> out_preds) {
  auto const &tree = *p_last_tree;
  CHECK_EQ(out_preds.DeviceIdx(), Context::kCpuId);
  size_t n_nodes = p_last_tree->GetNodes().size();
  for (auto &part : partitioner) {
    CHECK_EQ(part.Size(), n_nodes);
    common::BlockedSpace2d space(
        part.Size(), [&](size_t node) { return part[node].Size(); }, 1024);
    common::ParallelFor2d(space, ctx->Threads(), [&](bst_node_t nidx, common::Range1d r) {
      if (!tree[nidx].IsDeleted() && tree[nidx].IsLeaf()) {
        auto const &rowset = part[nidx];
        auto leaf_value = tree[nidx].LeafValue();
        for (const size_t *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
          out_preds(*it) += leaf_value;
        }
      }
    });
  }
}

template <typename Partitioner>
void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree,
                               std::vector<Partitioner> const &partitioner,
                               linalg::MatrixView<float> out_preds) {
  CHECK_GT(out_preds.Size(), 0U);
  CHECK(p_last_tree);

  auto const &tree = *p_last_tree;
  if (!tree.IsMultiTarget()) {
    UpdatePredictionCacheImpl(ctx, p_last_tree, partitioner, out_preds.Slice(linalg::All(), 0));
    return;
  }

  auto const *mttree = tree.GetMultiTargetTree();
  auto n_nodes = mttree->Size();
  auto n_targets = tree.NumTargets();
  CHECK_EQ(out_preds.Shape(1), n_targets);
  CHECK_EQ(out_preds.DeviceIdx(), Context::kCpuId);

  for (auto &part : partitioner) {
    CHECK_EQ(part.Size(), n_nodes);
    common::BlockedSpace2d space(
        part.Size(), [&](size_t node) { return part[node].Size(); }, 1024);
    common::ParallelFor2d(space, ctx->Threads(), [&](bst_node_t nidx, common::Range1d r) {
      if (tree.IsLeaf(nidx)) {
        auto const &rowset = part[nidx];
        auto leaf_value = mttree->LeafValue(nidx);
        for (std::size_t const *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
          for (std::size_t i = 0; i < n_targets; ++i) {
            out_preds(*it, i) += leaf_value(i);
          }
        }
      }
    });
  }
}
}  // namespace xgboost::tree
#endif  // XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_