Medial Code Documentation
Loading...
Searching...
No Matches
row_set.h
Go to the documentation of this file.
1
7#ifndef XGBOOST_COMMON_ROW_SET_H_
8#define XGBOOST_COMMON_ROW_SET_H_
9
10#include <xgboost/data.h>
11#include <algorithm>
12#include <vector>
13#include <utility>
14#include <memory>
15
16namespace xgboost {
17namespace common {
20 public:
21 RowSetCollection() = default;
22 RowSetCollection(RowSetCollection const&) = delete;
24 RowSetCollection& operator=(RowSetCollection const&) = delete;
25 RowSetCollection& operator=(RowSetCollection&&) = default;
26
30 struct Elem {
31 const size_t* begin{nullptr};
32 const size_t* end{nullptr};
33 bst_node_t node_id{-1};
34 // id of node associated with this instance set; -1 means uninitialized
35 Elem()
36 = default;
37 Elem(const size_t* begin,
38 const size_t* end,
39 bst_node_t node_id = -1)
40 : begin(begin), end(end), node_id(node_id) {}
41
42 inline size_t Size() const {
43 return end - begin;
44 }
45 };
46
47 std::vector<Elem>::const_iterator begin() const { // NOLINT
48 return elem_of_each_node_.begin();
49 }
50
51 std::vector<Elem>::const_iterator end() const { // NOLINT
52 return elem_of_each_node_.end();
53 }
54
55 size_t Size() const { return std::distance(begin(), end()); }
56
58 inline const Elem& operator[](unsigned node_id) const {
59 const Elem& e = elem_of_each_node_[node_id];
60 return e;
61 }
62
64 inline Elem& operator[](unsigned node_id) {
65 Elem& e = elem_of_each_node_[node_id];
66 return e;
67 }
68
69 // clear up things
70 inline void Clear() {
71 elem_of_each_node_.clear();
72 }
73 // initialize node id 0->everything
74 inline void Init() {
75 CHECK_EQ(elem_of_each_node_.size(), 0U);
76
77 if (row_indices_.empty()) { // edge case: empty instance set
78 constexpr size_t* kBegin = nullptr;
79 constexpr size_t* kEnd = nullptr;
80 static_assert(kEnd - kBegin == 0);
81 elem_of_each_node_.emplace_back(kBegin, kEnd, 0);
82 return;
83 }
84
85 const size_t* begin = dmlc::BeginPtr(row_indices_);
86 const size_t* end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
87 elem_of_each_node_.emplace_back(begin, end, 0);
88 }
89
90 std::vector<size_t>* Data() { return &row_indices_; }
91 std::vector<size_t> const* Data() const { return &row_indices_; }
92
93 // split rowset into two
94 inline void AddSplit(unsigned node_id, unsigned left_node_id, unsigned right_node_id,
95 size_t n_left, size_t n_right) {
96 const Elem e = elem_of_each_node_[node_id];
97
98 size_t* all_begin{nullptr};
99 size_t* begin{nullptr};
100 if (e.begin == nullptr) {
101 CHECK_EQ(n_left, 0);
102 CHECK_EQ(n_right, 0);
103 } else {
104 all_begin = dmlc::BeginPtr(row_indices_);
105 begin = all_begin + (e.begin - all_begin);
106 }
107
108 CHECK_EQ(n_left + n_right, e.Size());
109 CHECK_LE(begin + n_left, e.end);
110 CHECK_EQ(begin + n_left + n_right, e.end);
111
112 if (left_node_id >= elem_of_each_node_.size()) {
113 elem_of_each_node_.resize(left_node_id + 1, Elem(nullptr, nullptr, -1));
114 }
115 if (right_node_id >= elem_of_each_node_.size()) {
116 elem_of_each_node_.resize(right_node_id + 1, Elem(nullptr, nullptr, -1));
117 }
118
119 elem_of_each_node_[left_node_id] = Elem(begin, begin + n_left, left_node_id);
120 elem_of_each_node_[right_node_id] = Elem(begin + n_left, e.end, right_node_id);
121 elem_of_each_node_[node_id] = Elem(nullptr, nullptr, -1);
122 }
123
124 private:
125 // stores the row indexes in the set
126 std::vector<size_t> row_indices_;
127 // vector: node_id -> elements
128 std::vector<Elem> elem_of_each_node_;
129};
130} // namespace common
131} // namespace xgboost
132
133#endif // XGBOOST_COMMON_ROW_SET_H_
collection of rowset
Definition row_set.h:19
Elem & operator[](unsigned node_id)
return corresponding element set given the node_id
Definition row_set.h:64
const Elem & operator[](unsigned node_id) const
return corresponding element set given the node_id
Definition row_set.h:58
Copyright 2015-2023 by XGBoost Contributors.
T * BeginPtr(std::vector< T > &vec)
safely get the beginning address of a vector
Definition base.h:284
namespace of xgboost
Definition base.h:90
std::int32_t bst_node_t
Type for tree node index.
Definition base.h:112
data structure to store an instance set, a subset of rows (instances) associated with a particular node
Definition row_set.h:30