7#ifndef XGBOOST_TREE_MODEL_H_
8#define XGBOOST_TREE_MODEL_H_
18#include <xgboost/multi_target_tree_model.h>
35struct TreeParam :
public dmlc::Parameter<TreeParam> {
56 static_assert(
sizeof(
TreeParam) == (31 + 6) *
sizeof(int),
"TreeParam: 64 bit align");
62 [[nodiscard]]
TreeParam ByteSwap()
const {
64 dmlc::ByteSwap(&x.deprecated_num_roots,
sizeof(x.deprecated_num_roots), 1);
67 dmlc::ByteSwap(&x.deprecated_max_depth,
sizeof(x.deprecated_max_depth), 1);
69 dmlc::ByteSwap(&x.size_leaf_vector,
sizeof(x.size_leaf_vector), 1);
70 dmlc::ByteSwap(x.reserved,
sizeof(x.reserved[0]),
sizeof(x.reserved) /
sizeof(x.reserved[0]));
78 DMLC_DECLARE_FIELD(
num_nodes).set_lower_bound(1).set_default(1);
81 .describe(
"Number of features used in tree construction.");
86 .describe(
"Size of leaf vector, reserved for vector tree");
89 bool operator==(
const TreeParam& b)
const {
109 bool operator==(
const RTreeNodeStat& b)
const {
115 [[nodiscard]] RTreeNodeStat ByteSwap()
const {
116 RTreeNodeStat x = *
this;
130 std::unique_ptr<T> ptr_{
nullptr};
137 ptr_ = std::make_unique<T>(*that);
140 T* get()
const noexcept {
return ptr_.get(); }
142 T& operator*() {
return *ptr_; }
143 T* operator->()
noexcept {
return this->get(); }
145 T
const& operator*()
const {
return *ptr_; }
146 T
const* operator->()
const noexcept {
return this->get(); }
148 explicit operator bool()
const {
return static_cast<bool>(ptr_); }
149 bool operator!()
const {
return !ptr_; }
150 void reset(T* ptr) { ptr_.reset(ptr); }
161 static constexpr bst_node_t kInvalidNodeId{MultiTargetTree::InvalidNodeId()};
162 static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max();
170 static_assert(
sizeof(
Node) == 4 *
sizeof(
int) +
sizeof(Info),
171 "Node: 64 bit align");
173 Node(int32_t cleft, int32_t cright, int32_t parent,
174 uint32_t split_ind,
float split_cond,
bool default_left) :
175 parent_{parent}, cleft_{cleft}, cright_{cright} {
176 this->SetParent(parent_);
177 this->
SetSplit(split_ind, split_cond, default_left);
190 return sindex_ & ((1U << 31) - 1U);
229 bool default_left =
false) {
230 if (default_left) split_index |= (1U << 31);
231 this->sindex_ = split_index;
232 (this->info_).split_cond = split_cond;
241 (this->info_).leaf_value = value;
242 this->cleft_ = kInvalidNodeId;
243 this->cright_ = right;
247 this->sindex_ = kDeletedNodeMarker;
254 XGBOOST_DEVICE void SetParent(
int pidx,
bool is_left_child =
true) {
255 if (is_left_child) pidx |= (1U << 31);
256 this->parent_ = pidx;
258 bool operator==(
const Node& b)
const {
259 return parent_ == b.parent_ && cleft_ == b.cleft_ &&
260 cright_ == b.cright_ && sindex_ == b.sindex_ &&
261 info_.leaf_value == b.info_.leaf_value;
264 [[nodiscard]]
Node ByteSwap()
const {
281 SplitCondT split_cond;
285 int32_t parent_{kInvalidNodeId};
287 int32_t cleft_{kInvalidNodeId}, cright_{kInvalidNodeId};
300 CHECK(nodes_[nodes_[rid].LeftChild() ].IsLeaf());
301 CHECK(nodes_[nodes_[rid].RightChild()].IsLeaf());
302 this->DeleteNode(nodes_[rid].LeftChild());
303 this->DeleteNode(nodes_[rid].RightChild());
304 nodes_[rid].SetLeaf(value);
312 if (nodes_[rid].IsLeaf())
return;
313 if (!nodes_[nodes_[rid].LeftChild() ].IsLeaf()) {
316 if (!nodes_[nodes_[rid].RightChild() ].IsLeaf()) {
326 split_types_.resize(param_.
num_nodes, FeatureType::kNumerical);
327 split_categories_segments_.resize(param_.
num_nodes);
328 for (
int i = 0; i < param_.
num_nodes; i++) {
329 nodes_[i].SetLeaf(0.0f);
330 nodes_[i].SetParent(kInvalidNodeId);
354 [[nodiscard]]
const std::vector<Node>&
GetNodes()
const {
return nodes_; }
357 [[nodiscard]]
const std::vector<RTreeNodeStat>&
GetStats()
const {
return stats_; }
382 bool operator==(
const RegTree& b)
const {
383 return nodes_ == b.nodes_ && stats_ == b.stats_ &&
384 deleted_nodes_ == b.deleted_nodes_ && param_ == b.param_;
391 template <
typename Func>
void WalkTree(Func func)
const {
392 std::stack<bst_node_t> nodes;
395 while (!nodes.empty()) {
396 auto nidx = nodes.top();
401 auto left = self[nidx].LeftChild();
402 auto right = self[nidx].RightChild();
403 if (left != RegTree::kInvalidNodeId) {
406 if (right != RegTree::kInvalidNodeId) {
/*!
 * \brief Compares whether 2 trees are equal from a user's perspective.
 *
 * Per the generated docs, only non-deleted nodes take part in the comparison.
 * This differs from operator==, which compares the raw nodes_, stats_,
 * deleted_nodes_ and param_ members directly. Defined out of line (tree_model.cc).
 */
417 [[nodiscard]]
bool Equal(
const RegTree& b)
const;
437 bool default_left,
bst_float base_weight,
439 bst_float loss_change,
float sum_hess,
float left_sum,
441 bst_node_t leaf_right_child = kInvalidNodeId);
446 linalg::VectorView<float const> base_weight,
447 linalg::VectorView<float const> left_weight,
448 linalg::VectorView<float const> right_weight);
466 common::Span<const uint32_t> split_cat,
bool default_left,
469 float left_sum,
float right_sum);
477 [[nodiscard]]
bool IsMultiTarget()
const {
return static_cast<bool>(p_mt_tree_); }
487 return p_mt_tree_.get();
510 [[nodiscard]]
bst_node_t GetNumLeaves()
const;
511 [[nodiscard]]
bst_node_t GetNumSplitNodes()
const;
519 return this->p_mt_tree_->Depth(nid);
522 while (!nodes_[nid].IsRoot()) {
524 nid = nodes_[nid].Parent();
533 return this->p_mt_tree_->SetLeaf(nidx, weight);
541 if (nodes_[nid].IsLeaf())
return 0;
542 return std::max(
MaxDepth(nodes_[nid].LeftChild()) + 1,
MaxDepth(nodes_[nid].RightChild()) + 1);
/*! \brief initialize the vector with the given size; every entry starts out missing (flag == -1) */
559 void Init(
size_t size);
/*! \brief returns the size of the feature vector */
575 [[nodiscard]]
size_t Size()
const;
/*! \brief check whether the i-th entry is missing (its flag is -1, the value Init()/Drop() write) */
587 [[nodiscard]]
bool IsMissing(
size_t i)
const;
/*!
 * \brief Whether the last Fill() left any entry missing.
 * NOTE(review): presumably returns has_missing_, which Fill() sets when fewer
 * features than data_.size() were written — confirm against the definition.
 */
588 [[nodiscard]]
bool HasMissing()
const;
600 std::vector<Entry> data_;
610 std::vector<float>* mean_values,
620 std::string format)
const;
634 return split_categories_;
640 auto node_ptr = GetCategoriesMatrix().node_ptr;
641 auto categories = GetCategoriesMatrix().categories;
642 auto segment = node_ptr[nidx];
643 auto node_cats = categories.subspan(segment.beg, segment.size);
/*!
 * \brief Per-node segments (CategoricalSplitMatrix::Segment) into the flat
 *        split-category storage; NodeCats() slices categories with segment.beg/size.
 * NOTE(review): the return type is spelled `auto const&`, presumably because
 * CategoricalSplitMatrix is only defined later in the class — keep it deduced.
 */
646 [[nodiscard]]
auto const& GetSplitCategoriesPtr()
const {
return split_categories_segments_; }
668 view.categories = this->GetSplitCategories();
675 return this->p_mt_tree_->SplitIndex(nidx);
677 return (*
this)[nidx].SplitIndex();
679 [[nodiscard]]
float SplitCond(
bst_node_t nidx)
const {
681 return this->p_mt_tree_->SplitCond(nidx);
683 return (*
this)[nidx].SplitCond();
685 [[nodiscard]]
bool DefaultLeft(
bst_node_t nidx)
const {
687 return this->p_mt_tree_->DefaultLeft(nidx);
689 return (*
this)[nidx].DefaultLeft();
691 [[nodiscard]]
bool IsRoot(
bst_node_t nidx)
const {
693 return nidx == kRoot;
695 return (*
this)[nidx].IsRoot();
697 [[nodiscard]]
bool IsLeaf(
bst_node_t nidx)
const {
699 return this->p_mt_tree_->IsLeaf(nidx);
701 return (*
this)[nidx].IsLeaf();
705 return this->p_mt_tree_->Parent(nidx);
707 return (*
this)[nidx].Parent();
711 return this->p_mt_tree_->LeftChild(nidx);
713 return (*
this)[nidx].LeftChild();
717 return this->p_mt_tree_->RightChild(nidx);
719 return (*
this)[nidx].RightChild();
721 [[nodiscard]]
bool IsLeftChild(
bst_node_t nidx)
const {
723 CHECK_NE(nidx, kRoot);
724 auto p = this->p_mt_tree_->Parent(nidx);
725 return nidx == this->p_mt_tree_->LeftChild(p);
727 return (*
this)[nidx].IsLeftChild();
731 return this->p_mt_tree_->Size();
733 return this->nodes_.size();
737 template <
bool typed>
738 void LoadCategoricalSplit(Json
const& in);
739 void SaveCategoricalSplit(Json* p_out)
const;
743 std::vector<Node> nodes_;
745 std::vector<int> deleted_nodes_;
747 std::vector<RTreeNodeStat> stats_;
748 std::vector<FeatureType> split_types_;
751 std::vector<uint32_t> split_categories_;
753 std::vector<CategoricalSplitMatrix::Segment> split_categories_segments_;
755 CopyUniquePtr<MultiTargetTree> p_mt_tree_;
759 if (param_.num_deleted != 0) {
760 int nid = deleted_nodes_.back();
761 deleted_nodes_.pop_back();
763 --param_.num_deleted;
766 int nd = param_.num_nodes++;
767 CHECK_LT(param_.num_nodes, std::numeric_limits<int>::max())
768 <<
"number of nodes in the tree exceed 2^31";
769 nodes_.resize(param_.num_nodes);
770 stats_.resize(param_.num_nodes);
771 split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
772 split_categories_segments_.resize(param_.num_nodes);
776 void DeleteNode(
int nid) {
778 auto pid = (*this)[nid].Parent();
779 if (nid == (*
this)[pid].LeftChild()) {
780 (*this)[pid].SetLeftChild(kInvalidNodeId);
782 (*this)[pid].SetRightChild(kInvalidNodeId);
785 deleted_nodes_.push_back(nid);
786 nodes_[nid].MarkDelete();
787 ++param_.num_deleted;
792 Entry e; e.flag = -1;
794 std::fill(data_.begin(), data_.end(), e);
799 size_t feature_count = 0;
800 for (
auto const& entry : inst) {
801 if (entry.index >= data_.size()) {
804 data_[entry.index].fvalue = entry.fvalue;
807 has_missing_ = data_.size() != feature_count;
813 std::fill_n(data_.data(), data_.size(), e);
822 return data_[i].fvalue;
826 return data_[i].flag == -1;
829inline bool RegTree::FVec::HasMissing()
const {
835 return " support for multi-target tree is not yet implemented.";
interface of stream I/O for serialization
Definition io.h:30
Helper for defining a copyable data structure that contains unique pointers.
Definition tree_model.h:129
Feature map data structure to help text model dump. TODO(tqchen): consider making it even more lightweight.
Definition feature_map.h:22
Data structure representing JSON format.
Definition json.h:357
Tree structure for multi-target model.
Definition multi_target_tree_model.h:23
tree node
Definition tree_model.h:166
XGBOOST_DEVICE int Parent() const
get parent of the node
Definition tree_model.h:201
XGBOOST_DEVICE void MarkDelete()
mark that this node is deleted
Definition tree_model.h:246
XGBOOST_DEVICE bool IsRoot() const
whether current node is root
Definition tree_model.h:207
XGBOOST_DEVICE int RightChild() const
index of right child
Definition tree_model.h:183
XGBOOST_DEVICE float LeafValue() const
Definition tree_model.h:197
XGBOOST_DEVICE unsigned SplitIndex() const
feature index of split condition
Definition tree_model.h:189
XGBOOST_DEVICE void SetLeaf(bst_float value, int right=kInvalidNodeId)
set the leaf value of the node
Definition tree_model.h:240
XGBOOST_DEVICE bool IsLeftChild() const
whether current node is left child
Definition tree_model.h:203
XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond, bool default_left=false)
set split condition of current node
Definition tree_model.h:228
XGBOOST_DEVICE void SetLeftChild(int nid)
set the left child
Definition tree_model.h:212
XGBOOST_DEVICE bool IsDeleted() const
whether this node is deleted
Definition tree_model.h:205
XGBOOST_DEVICE bool IsLeaf() const
whether current node is leaf node
Definition tree_model.h:195
XGBOOST_DEVICE void Reuse()
Reuse this deleted node.
Definition tree_model.h:250
XGBOOST_DEVICE void SetRightChild(int nid)
set the right child
Definition tree_model.h:219
XGBOOST_DEVICE bool DefaultLeft() const
when feature is unknown, whether goes to left child
Definition tree_model.h:193
XGBOOST_DEVICE int LeftChild() const
index of left child
Definition tree_model.h:181
XGBOOST_DEVICE int DefaultChild() const
index of default child when feature is missing
Definition tree_model.h:185
XGBOOST_DEVICE SplitCondT SplitCond() const
Definition tree_model.h:199
define regression tree to be the most common tree model.
Definition tree_model.h:158
int MaxDepth(int nid) const
get maximum depth
Definition tree_model.h:540
void SaveModel(Json *out) const override
saves the model config to a JSON object
Definition tree_model.cc:1142
bst_target_t NumTargets() const
The size of leaf weight.
Definition tree_model.h:481
void Save(dmlc::Stream *fo) const
save model to stream
Definition tree_model.cc:892
RTreeNodeStat & Stat(int nid)
get node statistics given nid
Definition tree_model.h:360
bst_node_t NumNodes() const noexcept
Get the total number of nodes including deleted ones in this tree.
Definition tree_model.h:496
void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum, bst_node_t leaf_right_child=kInvalidNodeId)
Expands a leaf node into two additional leaf nodes.
Definition tree_model.cc:791
common::Span< uint32_t const > NodeCats(bst_node_t nidx) const
Get the bit storage for categories.
Definition tree_model.h:639
bool IsMultiTarget() const
Whether this is a multi-target tree.
Definition tree_model.h:477
bst_node_t NumExtraNodes() const noexcept
number of extra nodes besides the root
Definition tree_model.h:506
const Node & operator[](int nid) const
get node given nid
Definition tree_model.h:349
auto GetMultiTargetTree() const
Get the underlying implementation of the multi-target tree.
Definition tree_model.h:485
void Load(dmlc::Stream *fi)
load model from stream
Definition tree_model.cc:857
const std::vector< RTreeNodeStat > & GetStats() const
get const reference to stats
Definition tree_model.h:357
RegTree(bst_target_t n_targets, bst_feature_t n_features)
Constructor that initializes the tree model with shape.
Definition tree_model.h:336
Node & operator[](int nid)
get node given nid
Definition tree_model.h:345
void ExpandCategorical(bst_node_t nid, bst_feature_t split_index, common::Span< const uint32_t > split_cat, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum)
Expands a leaf node with categories.
Definition tree_model.cc:837
bool Equal(const RegTree &b) const
Compares whether 2 trees are equal from a user's perspective. The equality compares only non-deleted ...
Definition tree_model.cc:748
void CollapseToLeaf(int rid, bst_float value)
collapse a non-leaf node to a leaf node, deleting its children
Definition tree_model.h:311
bst_node_t NumValidNodes() const noexcept
Get the total number of valid nodes in this tree.
Definition tree_model.h:500
const RTreeNodeStat & Stat(int nid) const
get node statistics given nid
Definition tree_model.h:364
const std::vector< Node > & GetNodes() const
get const reference to nodes
Definition tree_model.h:354
void ChangeToLeaf(int rid, bst_float value)
change a non-leaf node to a leaf node, deleting its children
Definition tree_model.h:299
void SetLeaf(bst_node_t nidx, linalg::VectorView< float const > weight)
Set the leaf weight for a multi-target tree.
Definition tree_model.h:531
void CalculateContributionsApprox(const RegTree::FVec &feat, std::vector< float > *mean_values, bst_float *out_contribs) const
calculate the approximate feature contributions for the given root
Definition tree_model.cc:1221
void LoadModel(Json const &in) override
load the model from a JSON object
Definition tree_model.cc:1085
std::string DumpModel(const FeatureMap &fmap, bool with_stats, std::string format) const
dump the model in the requested format as a text string
Definition tree_model.cc:739
FeatureType NodeSplitType(bst_node_t nidx) const
Get split type for a node.
Definition tree_model.h:626
bst_feature_t NumFeatures() const noexcept
Get the number of features.
Definition tree_model.h:492
bool HasCategoricalSplit() const
Whether this tree has categorical split.
Definition tree_model.h:473
std::vector< FeatureType > const & GetSplitTypes() const
Get split types for all nodes.
Definition tree_model.h:630
std::int32_t GetDepth(bst_node_t nid) const
get current depth
Definition tree_model.h:517
int MaxDepth()
get maximum depth
Definition tree_model.h:548
span class implementation, based on the ISO C++20 std::span<T>. The interface should be the same.
Definition span.h:424
A tensor view with static type and dimension.
Definition linalg.h:293
defines serializable interface of dmlc
Provide lightweight util to do parameter setup and checking.
Feature map data structure to help visualization and model dump.
Copyright 2015-2023 by XGBoost Contributors.
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition base.h:64
Copyright 2015-2023 by XGBoost Contributors.
defines console logging options for xgboost. Use to enforce unified print behavior.
Copyright 2021-2023 by XGBoost Contributors.
Defines the abstract interface for different components in XGBoost.
void ByteSwap(void *data, size_t elem_bytes, size_t num_elems)
A generic inplace byte swapping function.
Definition endian.h:51
namespace of xgboost
Definition base.h:90
uint32_t bst_feature_t
Type for data column (feature) index.
Definition base.h:101
std::int32_t bst_node_t
Type for tree node index.
Definition base.h:112
std::uint32_t bst_target_t
Type for indexing into output targets.
Definition base.h:118
float bst_float
float type, used for storing statistics
Definition base.h:97
Definition tree_shap.h:119
node statistics used in regression tree
Definition tree_model.h:96
bst_float loss_chg
loss change caused by current split
Definition tree_model.h:98
int leaf_child_cnt
number of children known so far to be leaf nodes
Definition tree_model.h:104
bst_float sum_hess
sum of hessian values, used to measure coverage of data
Definition tree_model.h:100
bst_float base_weight
weight of current node
Definition tree_model.h:102
Definition tree_model.h:656
CSR-like matrix for categorical splits.
Definition tree_model.h:655
dense feature vector that can be consumed by RegTree and can be constructed from a sparse feature vector.
Definition tree_model.h:554
void Drop()
drop the trace after fill, must be called after fill.
Definition tree_model.h:810
void Fill(const SparsePage::Inst &inst)
fill the vector with sparse vector
Definition tree_model.h:798
bool IsMissing(size_t i) const
check whether i-th entry is missing
Definition tree_model.h:825
size_t Size() const
returns the size of the feature vector
Definition tree_model.h:817
bst_float GetFvalue(size_t i) const
get ith value
Definition tree_model.h:821
void Init(size_t size)
initialize the vector with the given size
Definition tree_model.h:791
Definition string_view.h:15
meta parameters of the tree
Definition tree_model.h:35
bst_feature_t num_feature
number of features used for tree construction
Definition tree_model.h:45
int num_nodes
total number of nodes
Definition tree_model.h:39
int num_deleted
number of deleted nodes
Definition tree_model.h:41
int reserved[31]
reserved part; ensures the structure stays aligned on 64-bit platforms
Definition tree_model.h:52
TreeParam()
constructor
Definition tree_model.h:54
bst_target_t size_leaf_vector
leaf vector size, used by vector trees to store more than one dimension of information per leaf
Definition tree_model.h:50
int deprecated_num_roots
(Deprecated) number of start root
Definition tree_model.h:37
int deprecated_max_depth
maximum depth; this is a statistic of the tree
Definition tree_model.h:43