Medial Code Documentation
Loading...
Searching...
No Matches
tree_model.h
Go to the documentation of this file.
1
7#ifndef XGBOOST_TREE_MODEL_H_
8#define XGBOOST_TREE_MODEL_H_
9
10#include <dmlc/io.h>
11#include <dmlc/parameter.h>
12#include <xgboost/base.h>
13#include <xgboost/data.h>
14#include <xgboost/feature_map.h>
15#include <xgboost/linalg.h> // for VectorView
16#include <xgboost/logging.h>
17#include <xgboost/model.h>
18#include <xgboost/multi_target_tree_model.h> // for MultiTargetTree
19
20#include <algorithm>
21#include <cstring>
22#include <limits>
23#include <memory> // for make_unique
24#include <stack>
25#include <string>
26#include <tuple>
27#include <vector>
28
29namespace xgboost {
30class Json;
31
32// FIXME(trivialfis): Once binary IO is gone, make this parameter internal as it should
33// not be configured by users.
35struct TreeParam : public dmlc::Parameter<TreeParam> {
39 int num_nodes{1};
52 int reserved[31];
55 // assert compact alignment
56 static_assert(sizeof(TreeParam) == (31 + 6) * sizeof(int), "TreeParam: 64 bit align");
57 std::memset(reserved, 0, sizeof(reserved));
58 }
59
60 // Swap byte order for all fields. Useful for transporting models between machines with different
61 // endianness (big endian vs little endian)
62 [[nodiscard]] TreeParam ByteSwap() const {
63 TreeParam x = *this;
64 dmlc::ByteSwap(&x.deprecated_num_roots, sizeof(x.deprecated_num_roots), 1);
65 dmlc::ByteSwap(&x.num_nodes, sizeof(x.num_nodes), 1);
66 dmlc::ByteSwap(&x.num_deleted, sizeof(x.num_deleted), 1);
67 dmlc::ByteSwap(&x.deprecated_max_depth, sizeof(x.deprecated_max_depth), 1);
68 dmlc::ByteSwap(&x.num_feature, sizeof(x.num_feature), 1);
69 dmlc::ByteSwap(&x.size_leaf_vector, sizeof(x.size_leaf_vector), 1);
70 dmlc::ByteSwap(x.reserved, sizeof(x.reserved[0]), sizeof(x.reserved) / sizeof(x.reserved[0]));
71 return x;
72 }
73
74 // Declare the TreeParam fields that are registered with the dmlc parameter framework.
75 DMLC_DECLARE_PARAMETER(TreeParam) {
76 // Only declare the parameters that can be set by the user;
77 // other arguments are set by the algorithm.
// The deprecated_* fields and the reserved array are intentionally not declared here.
// num_nodes: lower bound 1, default 1.
78 DMLC_DECLARE_FIELD(num_nodes).set_lower_bound(1).set_default(1);
// num_feature: default 0.
79 DMLC_DECLARE_FIELD(num_feature)
80 .set_default(0)
81 .describe("Number of features used in tree construction.");
// num_deleted: default 0.
82 DMLC_DECLARE_FIELD(num_deleted).set_default(0);
// size_leaf_vector: lower bound 0, default 1.
83 DMLC_DECLARE_FIELD(size_leaf_vector)
84 .set_lower_bound(0)
85 .set_default(1)
86 .describe("Size of leaf vector, reserved for vector tree");
87 }
88
89 bool operator==(const TreeParam& b) const {
90 return num_nodes == b.num_nodes && num_deleted == b.num_deleted &&
91 num_feature == b.num_feature && size_leaf_vector == b.size_leaf_vector;
92 }
93};
94
105
106 RTreeNodeStat() = default;
107 RTreeNodeStat(float loss_chg, float sum_hess, float weight) :
109 bool operator==(const RTreeNodeStat& b) const {
110 return loss_chg == b.loss_chg && sum_hess == b.sum_hess &&
111 base_weight == b.base_weight && leaf_child_cnt == b.leaf_child_cnt;
112 }
113 // Swap byte order for all fields. Useful for transporting models between machines with different
114 // endianness (big endian vs little endian)
115 [[nodiscard]] RTreeNodeStat ByteSwap() const {
116 RTreeNodeStat x = *this;
117 dmlc::ByteSwap(&x.loss_chg, sizeof(x.loss_chg), 1);
118 dmlc::ByteSwap(&x.sum_hess, sizeof(x.sum_hess), 1);
119 dmlc::ByteSwap(&x.base_weight, sizeof(x.base_weight), 1);
120 dmlc::ByteSwap(&x.leaf_child_cnt, sizeof(x.leaf_child_cnt), 1);
121 return x;
122 }
123};
124
128template <typename T>
130 std::unique_ptr<T> ptr_{nullptr};
131
132 public:
133 CopyUniquePtr() = default;
134 CopyUniquePtr(CopyUniquePtr const& that) {
135 ptr_.reset(nullptr);
136 if (that.ptr_) {
137 ptr_ = std::make_unique<T>(*that);
138 }
139 }
140 T* get() const noexcept { return ptr_.get(); } // NOLINT
141
142 T& operator*() { return *ptr_; }
143 T* operator->() noexcept { return this->get(); }
144
145 T const& operator*() const { return *ptr_; }
146 T const* operator->() const noexcept { return this->get(); }
147
148 explicit operator bool() const { return static_cast<bool>(ptr_); }
149 bool operator!() const { return !ptr_; }
150 void reset(T* ptr) { ptr_.reset(ptr); } // NOLINT
151};
152
158class RegTree : public Model {
159 public:
160 using SplitCondT = bst_float;
161 static constexpr bst_node_t kInvalidNodeId{MultiTargetTree::InvalidNodeId()};
162 static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max();
163 static constexpr bst_node_t kRoot{0};
164
166 class Node {
167 public:
169 // assert compact alignment
170 static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info),
171 "Node: 64 bit align");
172 }
173 Node(int32_t cleft, int32_t cright, int32_t parent,
174 uint32_t split_ind, float split_cond, bool default_left) :
175 parent_{parent}, cleft_{cleft}, cright_{cright} {
176 this->SetParent(parent_);
177 this->SetSplit(split_ind, split_cond, default_left);
178 }
179
181 [[nodiscard]] XGBOOST_DEVICE int LeftChild() const { return this->cleft_; }
183 [[nodiscard]] XGBOOST_DEVICE int RightChild() const { return this->cright_; }
185 [[nodiscard]] XGBOOST_DEVICE int DefaultChild() const {
186 return this->DefaultLeft() ? this->LeftChild() : this->RightChild();
187 }
189 [[nodiscard]] XGBOOST_DEVICE unsigned SplitIndex() const {
190 return sindex_ & ((1U << 31) - 1U);
191 }
193 [[nodiscard]] XGBOOST_DEVICE bool DefaultLeft() const { return (sindex_ >> 31) != 0; }
195 [[nodiscard]] XGBOOST_DEVICE bool IsLeaf() const { return cleft_ == kInvalidNodeId; }
197 [[nodiscard]] XGBOOST_DEVICE float LeafValue() const { return (this->info_).leaf_value; }
199 [[nodiscard]] XGBOOST_DEVICE SplitCondT SplitCond() const { return (this->info_).split_cond; }
201 [[nodiscard]] XGBOOST_DEVICE int Parent() const { return parent_ & ((1U << 31) - 1); }
203 [[nodiscard]] XGBOOST_DEVICE bool IsLeftChild() const { return (parent_ & (1U << 31)) != 0; }
205 [[nodiscard]] XGBOOST_DEVICE bool IsDeleted() const { return sindex_ == kDeletedNodeMarker; }
207 [[nodiscard]] XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; }
213 this->cleft_ = nid;
214 }
220 this->cright_ = nid;
221 }
228 XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond,
229 bool default_left = false) {
230 if (default_left) split_index |= (1U << 31);
231 this->sindex_ = split_index;
232 (this->info_).split_cond = split_cond;
233 }
240 XGBOOST_DEVICE void SetLeaf(bst_float value, int right = kInvalidNodeId) {
241 (this->info_).leaf_value = value;
242 this->cleft_ = kInvalidNodeId;
243 this->cright_ = right;
244 }
247 this->sindex_ = kDeletedNodeMarker;
248 }
251 this->sindex_ = 0;
252 }
253 // set parent
254 XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child = true) {
255 if (is_left_child) pidx |= (1U << 31);
256 this->parent_ = pidx;
257 }
258 bool operator==(const Node& b) const {
259 return parent_ == b.parent_ && cleft_ == b.cleft_ &&
260 cright_ == b.cright_ && sindex_ == b.sindex_ &&
261 info_.leaf_value == b.info_.leaf_value;
262 }
263
264 [[nodiscard]] Node ByteSwap() const {
265 Node x = *this;
266 dmlc::ByteSwap(&x.parent_, sizeof(x.parent_), 1);
267 dmlc::ByteSwap(&x.cleft_, sizeof(x.cleft_), 1);
268 dmlc::ByteSwap(&x.cright_, sizeof(x.cright_), 1);
269 dmlc::ByteSwap(&x.sindex_, sizeof(x.sindex_), 1);
270 dmlc::ByteSwap(&x.info_, sizeof(x.info_), 1);
271 return x;
272 }
273
274 private:
279 union Info{
280 bst_float leaf_value;
281 SplitCondT split_cond;
282 };
283 // pointer to parent, highest bit is used to
284 // indicate whether it's a left child or not
285 int32_t parent_{kInvalidNodeId};
286 // pointer to left, right
287 int32_t cleft_{kInvalidNodeId}, cright_{kInvalidNodeId};
288 // split feature index, left split or right split depends on the highest bit
289 uint32_t sindex_{0};
290 // extra info
291 Info info_;
292 };
293
299 void ChangeToLeaf(int rid, bst_float value) {
300 CHECK(nodes_[nodes_[rid].LeftChild() ].IsLeaf());
301 CHECK(nodes_[nodes_[rid].RightChild()].IsLeaf());
302 this->DeleteNode(nodes_[rid].LeftChild());
303 this->DeleteNode(nodes_[rid].RightChild());
304 nodes_[rid].SetLeaf(value);
305 }
311 void CollapseToLeaf(int rid, bst_float value) {
312 if (nodes_[rid].IsLeaf()) return;
313 if (!nodes_[nodes_[rid].LeftChild() ].IsLeaf()) {
314 CollapseToLeaf(nodes_[rid].LeftChild(), 0.0f);
315 }
316 if (!nodes_[nodes_[rid].RightChild() ].IsLeaf()) {
317 CollapseToLeaf(nodes_[rid].RightChild(), 0.0f);
318 }
319 this->ChangeToLeaf(rid, value);
320 }
321
322 RegTree() {
323 param_.Init(Args{});
324 nodes_.resize(param_.num_nodes);
325 stats_.resize(param_.num_nodes);
326 split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
327 split_categories_segments_.resize(param_.num_nodes);
328 for (int i = 0; i < param_.num_nodes; i++) {
329 nodes_[i].SetLeaf(0.0f);
330 nodes_[i].SetParent(kInvalidNodeId);
331 }
332 }
336 explicit RegTree(bst_target_t n_targets, bst_feature_t n_features) : RegTree{} {
337 param_.num_feature = n_features;
338 param_.size_leaf_vector = n_targets;
339 if (n_targets > 1) {
340 this->p_mt_tree_.reset(new MultiTargetTree{&param_});
341 }
342 }
343
// Get node given nid (mutable). Plain vector indexing: nid is not range-checked.
345 Node& operator[](int nid) {
346 return nodes_[nid];
347 }
// Get node given nid (read-only). Plain vector indexing: nid is not range-checked.
349 const Node& operator[](int nid) const {
350 return nodes_[nid];
351 }
352
// Get const reference to the node storage.
354 [[nodiscard]] const std::vector<Node>& GetNodes() const { return nodes_; }
355
// Get const reference to the per-node statistics.
357 [[nodiscard]] const std::vector<RTreeNodeStat>& GetStats() const { return stats_; }
358
// Get node statistics given nid (mutable). nid is not range-checked.
360 RTreeNodeStat& Stat(int nid) {
361 return stats_[nid];
362 }
// Get node statistics given nid (read-only). nid is not range-checked.
364 [[nodiscard]] const RTreeNodeStat& Stat(int nid) const {
365 return stats_[nid];
366 }
367
372 void Load(dmlc::Stream* fi);
377 void Save(dmlc::Stream* fo) const;
378
379 void LoadModel(Json const& in) override;
380 void SaveModel(Json* out) const override;
381
382 bool operator==(const RegTree& b) const {
383 return nodes_ == b.nodes_ && stats_ == b.stats_ &&
384 deleted_nodes_ == b.deleted_nodes_ && param_ == b.param_;
385 }
386 /* \brief Iterate through all nodes in this tree.
387 *
388 * \param Function that accepts a node index, and returns false when iteration should
389 * stop, otherwise returns true.
390 */
391 template <typename Func> void WalkTree(Func func) const {
392 std::stack<bst_node_t> nodes;
393 nodes.push(kRoot);
394 auto &self = *this;
395 while (!nodes.empty()) {
396 auto nidx = nodes.top();
397 nodes.pop();
398 if (!func(nidx)) {
399 return;
400 }
401 auto left = self[nidx].LeftChild();
402 auto right = self[nidx].RightChild();
403 if (left != RegTree::kInvalidNodeId) {
404 nodes.push(left);
405 }
406 if (right != RegTree::kInvalidNodeId) {
407 nodes.push(right);
408 }
409 }
410 }
417 [[nodiscard]] bool Equal(const RegTree& b) const;
418
436 void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value,
437 bool default_left, bst_float base_weight,
438 bst_float left_leaf_weight, bst_float right_leaf_weight,
439 bst_float loss_change, float sum_hess, float left_sum,
440 float right_sum,
441 bst_node_t leaf_right_child = kInvalidNodeId);
445 void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left,
446 linalg::VectorView<float const> base_weight,
447 linalg::VectorView<float const> left_weight,
448 linalg::VectorView<float const> right_weight);
449
465 void ExpandCategorical(bst_node_t nid, bst_feature_t split_index,
466 common::Span<const uint32_t> split_cat, bool default_left,
467 bst_float base_weight, bst_float left_leaf_weight,
468 bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
469 float left_sum, float right_sum);
// Whether this tree has categorical split (the category storage is non-empty).
473 [[nodiscard]] bool HasCategoricalSplit() const { return !split_categories_.empty(); }
// Whether this is a multi-target tree (a multi-target tree object is held).
477 [[nodiscard]] bool IsMultiTarget() const { return static_cast<bool>(p_mt_tree_); }
// The size of leaf weight, read from param_.size_leaf_vector.
481 [[nodiscard]] bst_target_t NumTargets() const { return param_.size_leaf_vector; }
// Get the underlying implementation of multi-target tree.
// CHECK-fails when this is not a multi-target tree.
485 [[nodiscard]] auto GetMultiTargetTree() const {
486 CHECK(IsMultiTarget());
487 return p_mt_tree_.get();
488 }
// Get the number of features.
492 [[nodiscard]] bst_feature_t NumFeatures() const noexcept { return param_.num_feature; }
// Get the total number of nodes including deleted ones in this tree.
496 [[nodiscard]] bst_node_t NumNodes() const noexcept { return param_.num_nodes; }
// Get the total number of valid nodes in this tree (all nodes minus deleted ones).
500 [[nodiscard]] bst_node_t NumValidNodes() const noexcept {
501 return param_.num_nodes - param_.num_deleted;
502 }
// Number of extra nodes besides the root, deleted nodes excluded.
506 [[nodiscard]] bst_node_t NumExtraNodes() const noexcept {
507 return param_.num_nodes - 1 - param_.num_deleted;
508 }
509 /* \brief Count number of leaves in tree. */
510 [[nodiscard]] bst_node_t GetNumLeaves() const;
511 [[nodiscard]] bst_node_t GetNumSplitNodes() const;
512
517 [[nodiscard]] std::int32_t GetDepth(bst_node_t nid) const {
518 if (IsMultiTarget()) {
519 return this->p_mt_tree_->Depth(nid);
520 }
521 int depth = 0;
522 while (!nodes_[nid].IsRoot()) {
523 ++depth;
524 nid = nodes_[nid].Parent();
525 }
526 return depth;
527 }
532 CHECK(IsMultiTarget());
533 return this->p_mt_tree_->SetLeaf(nidx, weight);
534 }
535
540 [[nodiscard]] int MaxDepth(int nid) const {
541 if (nodes_[nid].IsLeaf()) return 0;
542 return std::max(MaxDepth(nodes_[nid].LeftChild()) + 1, MaxDepth(nodes_[nid].RightChild()) + 1);
543 }
544
548 int MaxDepth() { return MaxDepth(0); }
549
554 struct FVec {
559 void Init(size_t size);
564 void Fill(const SparsePage::Inst& inst);
565
570 void Drop();
575 [[nodiscard]] size_t Size() const;
581 [[nodiscard]] bst_float GetFvalue(size_t i) const;
587 [[nodiscard]] bool IsMissing(size_t i) const;
588 [[nodiscard]] bool HasMissing() const;
589
590
591 private:
596 union Entry {
597 bst_float fvalue;
598 int flag;
599 };
600 std::vector<Entry> data_;
601 bool has_missing_;
602 };
603
610 std::vector<float>* mean_values,
611 bst_float* out_contribs) const;
619 [[nodiscard]] std::string DumpModel(const FeatureMap& fmap, bool with_stats,
620 std::string format) const;
// Get split type for a node. Uses vector::at, so nidx is range-checked.
626 [[nodiscard]] FeatureType NodeSplitType(bst_node_t nidx) const { return split_types_.at(nidx); }
// Get split types for all nodes.
630 [[nodiscard]] std::vector<FeatureType> const& GetSplitTypes() const {
631 return split_types_;
632 }
// Read-only view over the category storage of all nodes.
633 [[nodiscard]] common::Span<uint32_t const> GetSplitCategories() const {
634 return split_categories_;
635 }
640 auto node_ptr = GetCategoriesMatrix().node_ptr;
641 auto categories = GetCategoriesMatrix().categories;
642 auto segment = node_ptr[nidx];
643 auto node_cats = categories.subspan(segment.beg, segment.size);
644 return node_cats;
645 }
646 [[nodiscard]] auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; }
647
656 struct Segment {
657 std::size_t beg{0};
658 std::size_t size{0};
659 };
663 };
664
665 [[nodiscard]] CategoricalSplitMatrix GetCategoriesMatrix() const {
667 view.split_type = common::Span<FeatureType const>(this->GetSplitTypes());
668 view.categories = this->GetSplitCategories();
669 view.node_ptr = common::Span<CategoricalSplitMatrix::Segment const>(split_categories_segments_);
670 return view;
671 }
672
673 [[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const {
674 if (IsMultiTarget()) {
675 return this->p_mt_tree_->SplitIndex(nidx);
676 }
677 return (*this)[nidx].SplitIndex();
678 }
679 [[nodiscard]] float SplitCond(bst_node_t nidx) const {
680 if (IsMultiTarget()) {
681 return this->p_mt_tree_->SplitCond(nidx);
682 }
683 return (*this)[nidx].SplitCond();
684 }
685 [[nodiscard]] bool DefaultLeft(bst_node_t nidx) const {
686 if (IsMultiTarget()) {
687 return this->p_mt_tree_->DefaultLeft(nidx);
688 }
689 return (*this)[nidx].DefaultLeft();
690 }
691 [[nodiscard]] bool IsRoot(bst_node_t nidx) const {
692 if (IsMultiTarget()) {
693 return nidx == kRoot;
694 }
695 return (*this)[nidx].IsRoot();
696 }
697 [[nodiscard]] bool IsLeaf(bst_node_t nidx) const {
698 if (IsMultiTarget()) {
699 return this->p_mt_tree_->IsLeaf(nidx);
700 }
701 return (*this)[nidx].IsLeaf();
702 }
703 [[nodiscard]] bst_node_t Parent(bst_node_t nidx) const {
704 if (IsMultiTarget()) {
705 return this->p_mt_tree_->Parent(nidx);
706 }
707 return (*this)[nidx].Parent();
708 }
709 [[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const {
710 if (IsMultiTarget()) {
711 return this->p_mt_tree_->LeftChild(nidx);
712 }
713 return (*this)[nidx].LeftChild();
714 }
715 [[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const {
716 if (IsMultiTarget()) {
717 return this->p_mt_tree_->RightChild(nidx);
718 }
719 return (*this)[nidx].RightChild();
720 }
721 [[nodiscard]] bool IsLeftChild(bst_node_t nidx) const {
722 if (IsMultiTarget()) {
723 CHECK_NE(nidx, kRoot);
724 auto p = this->p_mt_tree_->Parent(nidx);
725 return nidx == this->p_mt_tree_->LeftChild(p);
726 }
727 return (*this)[nidx].IsLeftChild();
728 }
729 [[nodiscard]] bst_node_t Size() const {
730 if (IsMultiTarget()) {
731 return this->p_mt_tree_->Size();
732 }
733 return this->nodes_.size();
734 }
735
736 private:
737 template <bool typed>
738 void LoadCategoricalSplit(Json const& in);
739 void SaveCategoricalSplit(Json* p_out) const;
741 TreeParam param_;
742 // vector of nodes
743 std::vector<Node> nodes_;
744 // free node space, used during training process
745 std::vector<int> deleted_nodes_;
746 // stats of nodes
747 std::vector<RTreeNodeStat> stats_;
748 std::vector<FeatureType> split_types_;
749
750 // Categories for each internal node.
751 std::vector<uint32_t> split_categories_;
752 // Ptr to split categories of each node.
753 std::vector<CategoricalSplitMatrix::Segment> split_categories_segments_;
754 // ptr to multi-target tree with vector leaf.
755 CopyUniquePtr<MultiTargetTree> p_mt_tree_;
756 // allocate a new node,
757 // !!!!!! NOTE: may cause BUG here, nodes.resize
758 bst_node_t AllocNode() {
759 if (param_.num_deleted != 0) {
760 int nid = deleted_nodes_.back();
761 deleted_nodes_.pop_back();
762 nodes_[nid].Reuse();
763 --param_.num_deleted;
764 return nid;
765 }
766 int nd = param_.num_nodes++;
767 CHECK_LT(param_.num_nodes, std::numeric_limits<int>::max())
768 << "number of nodes in the tree exceed 2^31";
769 nodes_.resize(param_.num_nodes);
770 stats_.resize(param_.num_nodes);
771 split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
772 split_categories_segments_.resize(param_.num_nodes);
773 return nd;
774 }
775 // delete a tree node, keep the parent field to allow trace back
776 void DeleteNode(int nid) {
777 CHECK_GE(nid, 1);
778 auto pid = (*this)[nid].Parent();
779 if (nid == (*this)[pid].LeftChild()) {
780 (*this)[pid].SetLeftChild(kInvalidNodeId);
781 } else {
782 (*this)[pid].SetRightChild(kInvalidNodeId);
783 }
784
785 deleted_nodes_.push_back(nid);
786 nodes_[nid].MarkDelete();
787 ++param_.num_deleted;
788 }
789};
790
791inline void RegTree::FVec::Init(size_t size) {
792 Entry e; e.flag = -1;
793 data_.resize(size);
794 std::fill(data_.begin(), data_.end(), e);
795 has_missing_ = true;
796}
797
798inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) {
799 size_t feature_count = 0;
800 for (auto const& entry : inst) {
801 if (entry.index >= data_.size()) {
802 continue;
803 }
804 data_[entry.index].fvalue = entry.fvalue;
805 ++feature_count;
806 }
807 has_missing_ = data_.size() != feature_count;
808}
809
810inline void RegTree::FVec::Drop() {
811 Entry e{};
812 e.flag = -1;
813 std::fill_n(data_.data(), data_.size(), e);
814 has_missing_ = true;
815}
816
// Returns the size of the feature vector.
817inline size_t RegTree::FVec::Size() const {
818 return data_.size();
819}
820
// Get the i-th value. Plain indexing: i is not range-checked.
821inline bst_float RegTree::FVec::GetFvalue(size_t i) const {
822 return data_[i].fvalue;
823}
824
// Check whether the i-th entry is missing, i.e. its flag reads -1.
825inline bool RegTree::FVec::IsMissing(size_t i) const {
826 return data_[i].flag == -1;
827}
828
// Returns the missing flag last written by Init(), Fill() or Drop().
829inline bool RegTree::FVec::HasMissing() const {
830 return has_missing_;
831}
832
833// Multi-target tree not yet implemented error
834inline StringView MTNotImplemented() {
835 return " support for multi-target tree is not yet implemented.";
836}
837} // namespace xgboost
838#endif // XGBOOST_TREE_MODEL_H_
interface of stream I/O for serialization
Definition io.h:30
Helper for defining copyable data structure that contains unique pointers.
Definition tree_model.h:129
Feature map data structure to help text model dump. TODO(tqchen) consider make it even more lightweig...
Definition feature_map.h:22
Data structure representing JSON format.
Definition json.h:357
Tree structure for multi-target model.
Definition multi_target_tree_model.h:23
tree node
Definition tree_model.h:166
XGBOOST_DEVICE int Parent() const
get parent of the node
Definition tree_model.h:201
XGBOOST_DEVICE void MarkDelete()
mark that this node is deleted
Definition tree_model.h:246
XGBOOST_DEVICE bool IsRoot() const
whether current node is root
Definition tree_model.h:207
XGBOOST_DEVICE int RightChild() const
index of right child
Definition tree_model.h:183
XGBOOST_DEVICE float LeafValue() const
Definition tree_model.h:197
XGBOOST_DEVICE unsigned SplitIndex() const
feature index of split condition
Definition tree_model.h:189
XGBOOST_DEVICE void SetLeaf(bst_float value, int right=kInvalidNodeId)
set the leaf value of the node
Definition tree_model.h:240
XGBOOST_DEVICE bool IsLeftChild() const
whether current node is left child
Definition tree_model.h:203
XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond, bool default_left=false)
set split condition of current node
Definition tree_model.h:228
XGBOOST_DEVICE void SetLeftChild(int nid)
set the left child
Definition tree_model.h:212
XGBOOST_DEVICE bool IsDeleted() const
whether this node is deleted
Definition tree_model.h:205
XGBOOST_DEVICE bool IsLeaf() const
whether current node is leaf node
Definition tree_model.h:195
XGBOOST_DEVICE void Reuse()
Reuse this deleted node.
Definition tree_model.h:250
XGBOOST_DEVICE void SetRightChild(int nid)
set the right child
Definition tree_model.h:219
XGBOOST_DEVICE bool DefaultLeft() const
when feature is unknown, whether goes to left child
Definition tree_model.h:193
XGBOOST_DEVICE int LeftChild() const
index of left child
Definition tree_model.h:181
XGBOOST_DEVICE int DefaultChild() const
index of default child when feature is missing
Definition tree_model.h:185
XGBOOST_DEVICE SplitCondT SplitCond() const
Definition tree_model.h:199
define regression tree to be the most common tree model.
Definition tree_model.h:158
int MaxDepth(int nid) const
get maximum depth
Definition tree_model.h:540
void SaveModel(Json *out) const override
saves the model config to a JSON object
Definition tree_model.cc:1142
bst_target_t NumTargets() const
The size of leaf weight.
Definition tree_model.h:481
void Save(dmlc::Stream *fo) const
save model to stream
Definition tree_model.cc:892
RTreeNodeStat & Stat(int nid)
get node statistics given nid
Definition tree_model.h:360
bst_node_t NumNodes() const noexcept
Get the total number of nodes including deleted ones in this tree.
Definition tree_model.h:496
void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum, bst_node_t leaf_right_child=kInvalidNodeId)
Expands a leaf node into two additional leaf nodes.
Definition tree_model.cc:791
common::Span< uint32_t const > NodeCats(bst_node_t nidx) const
Get the bit storage for categories.
Definition tree_model.h:639
bool IsMultiTarget() const
Whether this is a multi-target tree.
Definition tree_model.h:477
bst_node_t NumExtraNodes() const noexcept
number of extra nodes besides the root
Definition tree_model.h:506
const Node & operator[](int nid) const
get node given nid
Definition tree_model.h:349
auto GetMultiTargetTree() const
Get the underlying implementation of multi-target tree.
Definition tree_model.h:485
void Load(dmlc::Stream *fi)
load model from stream
Definition tree_model.cc:857
const std::vector< RTreeNodeStat > & GetStats() const
get const reference to stats
Definition tree_model.h:357
RegTree(bst_target_t n_targets, bst_feature_t n_features)
Constructor that initializes the tree model with shape.
Definition tree_model.h:336
Node & operator[](int nid)
get node given nid
Definition tree_model.h:345
void ExpandCategorical(bst_node_t nid, bst_feature_t split_index, common::Span< const uint32_t > split_cat, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum)
Expands a leaf node with categories.
Definition tree_model.cc:837
bool Equal(const RegTree &b) const
Compares whether 2 trees are equal from a user's perspective. The equality compares only non-deleted ...
Definition tree_model.cc:748
void CollapseToLeaf(int rid, bst_float value)
collapse a non leaf node to a leaf node, delete its children
Definition tree_model.h:311
bst_node_t NumValidNodes() const noexcept
Get the total number of valid nodes in this tree.
Definition tree_model.h:500
const RTreeNodeStat & Stat(int nid) const
get node statistics given nid
Definition tree_model.h:364
const std::vector< Node > & GetNodes() const
get const reference to nodes
Definition tree_model.h:354
void ChangeToLeaf(int rid, bst_float value)
change a non leaf node to a leaf node, delete its children
Definition tree_model.h:299
void SetLeaf(bst_node_t nidx, linalg::VectorView< float const > weight)
Set the leaf weight for a multi-target tree.
Definition tree_model.h:531
void CalculateContributionsApprox(const RegTree::FVec &feat, std::vector< float > *mean_values, bst_float *out_contribs) const
calculate the approximate feature contributions for the given root
Definition tree_model.cc:1221
void LoadModel(Json const &in) override
load the model from a JSON object
Definition tree_model.cc:1085
std::string DumpModel(const FeatureMap &fmap, bool with_stats, std::string format) const
dump the model in the requested format as a text string
Definition tree_model.cc:739
FeatureType NodeSplitType(bst_node_t nidx) const
Get split type for a node.
Definition tree_model.h:626
bst_feature_t NumFeatures() const noexcept
Get the number of features.
Definition tree_model.h:492
bool HasCategoricalSplit() const
Whether this tree has categorical split.
Definition tree_model.h:473
std::vector< FeatureType > const & GetSplitTypes() const
Get split types for all nodes.
Definition tree_model.h:630
std::int32_t GetDepth(bst_node_t nid) const
get current depth
Definition tree_model.h:517
int MaxDepth()
get maximum depth
Definition tree_model.h:548
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition span.h:424
A tensor view with static type and dimension.
Definition linalg.h:293
defines serializable interface of dmlc
Provide lightweight util to do parameter setup and checking.
Feature map data structure to help visualization and model dump.
Copyright 2015-2023 by XGBoost Contributors.
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition base.h:64
Copyright 2015-2023 by XGBoost Contributors.
defines console logging options for xgboost. Use to enforce unified print behavior.
Copyright 2021-2023 by XGBoost Contributors.
Defines the abstract interface for different components in XGBoost.
void ByteSwap(void *data, size_t elem_bytes, size_t num_elems)
A generic inplace byte swapping function.
Definition endian.h:51
namespace of xgboost
Definition base.h:90
uint32_t bst_feature_t
Type for data column (feature) index.
Definition base.h:101
std::int32_t bst_node_t
Type for tree node index.
Definition base.h:112
std::uint32_t bst_target_t
Type for indexing into output targets.
Definition base.h:118
float bst_float
float type, used for storing statistics
Definition base.h:97
Definition tree_shap.h:119
Definition model.h:17
node statistics used in regression tree
Definition tree_model.h:96
bst_float loss_chg
loss change caused by current split
Definition tree_model.h:98
int leaf_child_cnt
number of children that are leaf nodes, known up to now
Definition tree_model.h:104
bst_float sum_hess
sum of hessian values, used to measure coverage of data
Definition tree_model.h:100
bst_float base_weight
weight of current node
Definition tree_model.h:102
CSR-like matrix for categorical splits.
Definition tree_model.h:655
dense feature vector that can be taken by RegTree and can be constructed from a sparse feature vector.
Definition tree_model.h:554
void Drop()
drop the trace after fill, must be called after fill.
Definition tree_model.h:810
void Fill(const SparsePage::Inst &inst)
fill the vector with sparse vector
Definition tree_model.h:798
bool IsMissing(size_t i) const
check whether i-th entry is missing
Definition tree_model.h:825
size_t Size() const
returns the size of the feature vector
Definition tree_model.h:817
bst_float GetFvalue(size_t i) const
get ith value
Definition tree_model.h:821
void Init(size_t size)
initialize the vector with the given size
Definition tree_model.h:791
Definition string_view.h:15
meta parameters of the tree
Definition tree_model.h:35
bst_feature_t num_feature
number of features used for tree construction
Definition tree_model.h:45
int num_nodes
total number of nodes
Definition tree_model.h:39
int num_deleted
number of deleted nodes
Definition tree_model.h:41
int reserved[31]
reserved part, make sure alignment works for 64bit
Definition tree_model.h:52
TreeParam()
constructor
Definition tree_model.h:54
bst_target_t size_leaf_vector
leaf vector size, used by vector trees to store more than one-dimensional information in the tree
Definition tree_model.h:50
int deprecated_num_roots
(Deprecated) number of start root
Definition tree_model.h:37
int deprecated_max_depth
maximum depth; this is a statistic of the tree
Definition tree_model.h:43