Medial Code Documentation
Loading...
Searching...
No Matches
common_row_partitioner.h
Go to the documentation of this file.
1
6#ifndef XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_
7#define XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_
8
9#include <algorithm> // std::all_of
10#include <cinttypes> // std::uint32_t
11#include <limits> // std::numeric_limits
12#include <vector>
13
14#include "../collective/communicator-inl.h"
15#include "../common/linalg_op.h" // cbegin
16#include "../common/numeric.h" // Iota
17#include "../common/partition_builder.h"
18#include "hist/expand_entry.h" // CPUExpandEntry
19#include "xgboost/base.h"
20#include "xgboost/context.h" // Context
21#include "xgboost/linalg.h" // TensorView
22
23namespace xgboost::tree {
24
25static constexpr size_t kPartitionBlockSize = 2048;
26
28 public:
29 ColumnSplitHelper() = default;
30
// Constructor tail: stores the partition builder and row set collection pointers and sizes the
// two bit-vector backing stores.
// NOTE(review): the opening of this constructor's signature (which declares `num_row` and
// `partition_builder`) is not visible in this listing -- confirm against the real header.
33 common::RowSetCollection* row_set_collection)
34 : partition_builder_{partition_builder}, row_set_collection_{row_set_collection} {
// Both storages get `num_row` elements; the BitVector members are views (spans) over them.
35 decision_storage_.resize(num_row);
36 decision_bits_ = BitVector(common::Span<BitVector::value_type>(decision_storage_));
37 missing_storage_.resize(num_row);
38 missing_bits_ = BitVector(common::Span<BitVector::value_type>(missing_storage_));
39 }
40
41 template <typename BinIdxType, bool any_missing, bool any_cat, typename ExpandEntry>
42 void Partition(common::BlockedSpace2d const& space, std::int32_t n_threads,
43 GHistIndexMatrix const& gmat, common::ColumnMatrix const& column_matrix,
44 std::vector<ExpandEntry> const& nodes,
45 std::vector<int32_t> const& split_conditions, RegTree const* p_tree) {
46 // When data is split by column, we don't have all the feature values in the local worker, so
47 // we first collect all the decisions and whether the feature is missing into bit vectors.
48 std::fill(decision_storage_.begin(), decision_storage_.end(), 0);
49 std::fill(missing_storage_.begin(), missing_storage_.end(), 0);
50 common::ParallelFor2d(space, n_threads, [&](size_t node_in_set, common::Range1d r) {
51 const int32_t nid = nodes[node_in_set].nid;
52 bst_bin_t split_cond = column_matrix.IsInitialized() ? split_conditions[node_in_set] : 0;
53 partition_builder_->MaskRows<BinIdxType, any_missing, any_cat>(
54 node_in_set, nodes, r, split_cond, gmat, column_matrix, *p_tree,
55 (*row_set_collection_)[nid].begin, &decision_bits_, &missing_bits_);
56 });
57
58 // Then aggregate the bit vectors across all the workers.
59 collective::Allreduce<collective::Operation::kBitwiseOR>(decision_storage_.data(),
60 decision_storage_.size());
61 collective::Allreduce<collective::Operation::kBitwiseAND>(missing_storage_.data(),
62 missing_storage_.size());
63
64 // Finally use the bit vectors to partition the rows.
65 common::ParallelFor2d(space, n_threads, [&](size_t node_in_set, common::Range1d r) {
66 size_t begin = r.begin();
67 const int32_t nid = nodes[node_in_set].nid;
68 const size_t task_id = partition_builder_->GetTaskIdx(node_in_set, begin);
69 partition_builder_->AllocateForTask(task_id);
70 partition_builder_->PartitionByMask(node_in_set, nodes, r, gmat, *p_tree,
71 (*row_set_collection_)[nid].begin, decision_bits_,
72 missing_bits_);
73 });
74 }
75
76 private:
77 using BitVector = RBitField8;
78 std::vector<BitVector::value_type> decision_storage_{};
79 BitVector decision_bits_{};
80 std::vector<BitVector::value_type> missing_storage_{};
81 BitVector missing_bits_{};
83 common::RowSetCollection* row_set_collection_;
84};
85
87 public:
88 bst_row_t base_rowid = 0;
89
90 CommonRowPartitioner() = default;
91 CommonRowPartitioner(Context const* ctx, bst_row_t num_row, bst_row_t _base_rowid,
92 bool is_col_split)
93 : base_rowid{_base_rowid}, is_col_split_{is_col_split} {
94 row_set_collection_.Clear();
95 std::vector<size_t>& row_indices = *row_set_collection_.Data();
96 row_indices.resize(num_row);
97
98 std::size_t* p_row_indices = row_indices.data();
99 common::Iota(ctx, p_row_indices, p_row_indices + row_indices.size(), base_rowid);
100 row_set_collection_.Init();
101
102 if (is_col_split_) {
103 column_split_helper_ = ColumnSplitHelper{num_row, &partition_builder_, &row_set_collection_};
104 }
105 }
106
107 template <typename ExpandEntry>
108 void FindSplitConditions(const std::vector<ExpandEntry>& nodes, const RegTree& tree,
109 const GHistIndexMatrix& gmat, std::vector<int32_t>* split_conditions) {
110 auto const& ptrs = gmat.cut.Ptrs();
111 auto const& vals = gmat.cut.Values();
112
113 for (std::size_t i = 0; i < nodes.size(); ++i) {
114 bst_node_t const nidx = nodes[i].nid;
115 bst_feature_t const fidx = tree.SplitIndex(nidx);
116 float const split_pt = tree.SplitCond(nidx);
117 std::uint32_t const lower_bound = ptrs[fidx];
118 std::uint32_t const upper_bound = ptrs[fidx + 1];
119 bst_bin_t split_cond = -1;
120 // convert floating-point split_pt into corresponding bin_id
121 // split_cond = -1 indicates that split_pt is less than all known cut points
122 CHECK_LT(upper_bound, static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
123 for (auto bound = lower_bound; bound < upper_bound; ++bound) {
124 if (split_pt == vals[bound]) {
125 split_cond = static_cast<bst_bin_t>(bound);
126 }
127 }
128 (*split_conditions)[i] = split_cond;
129 }
130 }
131
132 template <typename ExpandEntry>
133 void AddSplitsToRowSet(const std::vector<ExpandEntry>& nodes, RegTree const* p_tree) {
134 const size_t n_nodes = nodes.size();
135 for (unsigned int i = 0; i < n_nodes; ++i) {
136 const int32_t nidx = nodes[i].nid;
137 const size_t n_left = partition_builder_.GetNLeftElems(i);
138 const size_t n_right = partition_builder_.GetNRightElems(i);
139 CHECK_EQ(p_tree->LeftChild(nidx) + 1, p_tree->RightChild(nidx));
140 row_set_collection_.AddSplit(nidx, p_tree->LeftChild(nidx), p_tree->RightChild(nidx), n_left,
141 n_right);
142 }
143 }
144
145 template <typename ExpandEntry>
146 void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
147 std::vector<ExpandEntry> const& nodes, RegTree const* p_tree) {
148 auto const& column_matrix = gmat.Transpose();
149 if (column_matrix.IsInitialized()) {
150 if (gmat.cut.HasCategorical()) {
151 this->template UpdatePosition<true>(ctx, gmat, column_matrix, nodes, p_tree);
152 } else {
153 this->template UpdatePosition<false>(ctx, gmat, column_matrix, nodes, p_tree);
154 }
155 } else {
156 /* ColumnMatrix is not initilized.
157 * It means that we use 'approx' method.
158 * any_missing and any_cat don't metter in this case.
159 * Jump directly to the main method.
160 */
161 this->template UpdatePosition<uint8_t, true, true>(ctx, gmat, column_matrix, nodes, p_tree);
162 }
163 }
164
165 template <bool any_cat, typename ExpandEntry>
166 void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
167 const common::ColumnMatrix& column_matrix,
168 std::vector<ExpandEntry> const& nodes, RegTree const* p_tree) {
169 if (column_matrix.AnyMissing()) {
170 this->template UpdatePosition<true, any_cat>(ctx, gmat, column_matrix, nodes, p_tree);
171 } else {
172 this->template UpdatePosition<false, any_cat>(ctx, gmat, column_matrix, nodes, p_tree);
173 }
174 }
175
176 template <bool any_missing, bool any_cat, typename ExpandEntry>
177 void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
178 const common::ColumnMatrix& column_matrix,
179 std::vector<ExpandEntry> const& nodes, RegTree const* p_tree) {
180 common::DispatchBinType(column_matrix.GetTypeSize(), [&](auto t) {
181 using T = decltype(t);
182 this->template UpdatePosition<T, any_missing, any_cat>(ctx, gmat, column_matrix, nodes,
183 p_tree);
184 });
185 }
186
// Main implementation of the position update for the nodes in `nodes`.
// Steps: (1) find the split condition of each node, (2) partition the rows of each node into
// per-task buffers of partition_builder_, (3) compute row offsets, (4) merge the buffers back
// into row_set_collection_, (5) record the new splits in row_set_collection_.
187 template <typename BinIdxType, bool any_missing, bool any_cat, typename ExpandEntry>
188 void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
189 const common::ColumnMatrix& column_matrix,
190 std::vector<ExpandEntry> const& nodes, RegTree const* p_tree) {
191 // 1. Find split condition for each split
192 size_t n_nodes = nodes.size();
193
// split_conditions stays empty when the column matrix is not initialized; in that case 0 is
// used as the split condition further below.
194 std::vector<int32_t> split_conditions;
195 if (column_matrix.IsInitialized()) {
196 split_conditions.resize(n_nodes);
197 FindSplitConditions(nodes, *p_tree, gmat, &split_conditions);
198 }
199
200 // 2.1 Create a blocked space of size SUM(samples in each node)
// NOTE(review): the line that declares `space` (used below) is not visible in this listing;
// the following lines are its arguments: node count, per-node size callback, block size.
202 n_nodes,
203 [&](size_t node_in_set) {
204 int32_t nid = nodes[node_in_set].nid;
205 return row_set_collection_[nid].Size();
206 },
207 kPartitionBlockSize);
208
209 // 2.2 Initialize the partition builder
210 // allocate buffers for storage intermediate results by each thread
211 partition_builder_.Init(space.Size(), n_nodes, [&](size_t node_in_set) {
212 const int32_t nid = nodes[node_in_set].nid;
213 const size_t size = row_set_collection_[nid].Size();
// Number of tasks is size / kPartitionBlockSize rounded up.
214 const size_t n_tasks = size / kPartitionBlockSize + !!(size % kPartitionBlockSize);
215 return n_tasks;
216 });
// The partitioner and the gradient index matrix must agree on the base row id.
217 CHECK_EQ(base_rowid, gmat.base_rowid);
218
219 // 2.3 Split elements of row_set_collection_ to left and right child-nodes for each node
220 // Store results in intermediate buffers from partition_builder_
221 if (is_col_split_) {
222 column_split_helper_.Partition<BinIdxType, any_missing, any_cat>(
223 space, ctx->Threads(), gmat, column_matrix, nodes, split_conditions, p_tree);
224 } else {
225 common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
226 size_t begin = r.begin();
227 const int32_t nid = nodes[node_in_set].nid;
228 const size_t task_id = partition_builder_.GetTaskIdx(node_in_set, begin);
229 partition_builder_.AllocateForTask(task_id);
230 bst_bin_t split_cond = column_matrix.IsInitialized() ? split_conditions[node_in_set] : 0;
231 partition_builder_.template Partition<BinIdxType, any_missing, any_cat>(
232 node_in_set, nodes, r, split_cond, gmat, column_matrix, *p_tree,
233 row_set_collection_[nid].begin);
234 });
235 }
236
237 // 3. Compute offsets to copy blocks of row-indexes
238 // from partition_builder_ to row_set_collection_
239 partition_builder_.CalculateRowOffsets();
240
241 // 4. Copy elements from partition_builder_ to row_set_collection_ back
242 // with updated row-indexes for each tree-node
243 common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
244 const int32_t nid = nodes[node_in_set].nid;
// `begin` is a pointer-to-const here; constness is cast away so the merged row indices can
// be written into the node's range.
245 partition_builder_.MergeToArray(node_in_set, r.begin(),
246 const_cast<size_t*>(row_set_collection_[nid].begin));
247 });
248
249 // 5. Add info about splits into row_set_collection_
250 AddSplitsToRowSet(nodes, p_tree);
251 }
252
253 [[nodiscard]] auto const& Partitions() const { return row_set_collection_; }
254
255 [[nodiscard]] std::size_t Size() const {
256 return std::distance(row_set_collection_.begin(), row_set_collection_.end());
257 }
258
259 auto& operator[](bst_node_t nidx) { return row_set_collection_[nidx]; }
260 auto const& operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; }
261
262 void LeafPartition(Context const* ctx, RegTree const& tree, common::Span<float const> hess,
263 std::vector<bst_node_t>* p_out_position) const {
264 partition_builder_.LeafPartition(ctx, tree, this->Partitions(), p_out_position,
265 [&](size_t idx) -> bool { return hess[idx] - .0f == .0f; });
266 }
267
// Forward to PartitionBuilder::LeafPartition using the hessians stored in `gpair`.
// NOTE(review): the parameter line declaring `gpair` is not visible in this listing; from its
// use below it is sliced as a 2-dim view (row, column) of GradientPair -- confirm.
268 void LeafPartition(Context const* ctx, RegTree const& tree,
270 std::vector<bst_node_t>* p_out_position) const {
// More than one column: the predicate is true for a row only if every gradient pair of
// that row has a hessian that compares equal to zero.
271 if (gpair.Shape(1) > 1) {
272 partition_builder_.LeafPartition(
273 ctx, tree, this->Partitions(), p_out_position, [&](std::size_t idx) -> bool {
274 auto sample = gpair.Slice(idx, linalg::All());
275 return std::all_of(linalg::cbegin(sample), linalg::cend(sample),
276 [](GradientPair const& g) { return g.GetHess() - .0f == .0f; });
277 });
278 } else {
// Otherwise only column 0 is inspected.
279 auto s = gpair.Slice(linalg::All(), 0);
280 partition_builder_.LeafPartition(
281 ctx, tree, this->Partitions(), p_out_position,
282 [&](std::size_t idx) -> bool { return s(idx).GetHess() - .0f == .0f; });
283 }
284 }
// Forward to PartitionBuilder::LeafPartition; the predicate is true for row `idx` when the
// hessian of `gpair[idx]` compares equal to zero.
// NOTE(review): the parameter line declaring `gpair` is not visible in this listing.
285 void LeafPartition(Context const* ctx, RegTree const& tree,
287 std::vector<bst_node_t>* p_out_position) const {
288 partition_builder_.LeafPartition(
289 ctx, tree, this->Partitions(), p_out_position,
290 [&](std::size_t idx) -> bool { return gpair[idx].GetHess() - .0f == .0f; });
291 }
292
293 private:
295 common::RowSetCollection row_set_collection_;
// Whether the data is split by column. Defaults to false so that a default-constructed
// partitioner never reads an indeterminate value when choosing the partitioning path.
bool is_col_split_{false};
297 ColumnSplitHelper column_split_helper_;
298};
299
300} // namespace xgboost::tree
301#endif // XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_
preprocessed global index matrix, in CSR format.
Definition gradient_index.h:38
bst_row_t base_rowid
base row index for current page (used by external memory)
Definition gradient_index.h:152
common::HistogramCuts cut
The corresponding cuts.
Definition gradient_index.h:148
define regression tree to be the most common tree model.
Definition tree_model.h:158
Definition threading_utils.h:74
Column major matrix for gradient index on CPU.
Definition column_matrix.h:148
Definition partition_builder.h:32
void PartitionByMask(const size_t node_in_set, std::vector< ExpandEntry > const &nodes, const common::Range1d range, GHistIndexMatrix const &gmat, const RegTree &tree, const size_t *rid, BitVector const &decision_bits, BitVector const &missing_bits)
Once we've aggregated the decision and missing bits from all the workers, we can then use them to par...
Definition partition_builder.h:264
void MaskRows(const size_t node_in_set, std::vector< ExpandEntry > const &nodes, const common::Range1d range, bst_bin_t split_cond, GHistIndexMatrix const &gmat, const common::ColumnMatrix &column_matrix, const RegTree &tree, const size_t *rid, BitVector *decision_bits, BitVector *missing_bits)
When data is split by column, we don't have all the features locally on the current worker,...
Definition partition_builder.h:206
Definition threading_utils.h:39
collection of rowset
Definition row_set.h:19
span class implementation, based on ISO C++20 std::span<T>. The interface should be the same.
Definition span.h:424
Implementation of gradient statistics pair. Template specialisation may be used to overload different...
Definition base.h:137
A tensor view with static type and dimension.
Definition linalg.h:293
LINALG_HD auto Slice(S &&...slices) const
Slice the tensor.
Definition linalg.h:506
Definition common_row_partitioner.h:27
Definition common_row_partitioner.h:86
Copyright 2014-2023, XGBoost Contributors.
Copyright 2015-2023 by XGBoost Contributors.
Copyright 2021-2023 by XGBoost Contributors.
auto DispatchBinType(BinTypeSize type, Fn &&fn)
Dispatch for bin type, fn is a function that accepts a scalar of the bin type.
Definition hist_util.h:187
constexpr detail::AllTag All()
Specify all elements in the axis for slicing.
Definition linalg.h:265
Copyright 2021-2023 by XGBoost Contributors.
Definition tree_updater.h:25
uint32_t bst_feature_t
Type for data column (feature) index.
Definition base.h:101
std::int32_t bst_node_t
Type for tree node index.
Definition base.h:112
std::size_t bst_row_t
Type for data row index.
Definition base.h:110
int32_t bst_bin_t
Type for histogram bin index.
Definition base.h:103
Runtime context for XGBoost.
Definition context.h:84
std::int32_t Threads() const
Returns the automatically chosen number of threads based on the nthread parameter and the system sett...
Definition context.cc:203