Medial Code Documentation
Loading...
Searching...
No Matches
multi_target_tree_model.h
1
6#ifndef XGBOOST_MULTI_TARGET_TREE_MODEL_H_
7#define XGBOOST_MULTI_TARGET_TREE_MODEL_H_
8#include <xgboost/base.h> // for bst_node_t, bst_target_t, bst_feature_t
9#include <xgboost/context.h> // for Context
10#include <xgboost/linalg.h> // for VectorView
11#include <xgboost/model.h> // for Model
12#include <xgboost/span.h> // for Span
13
14#include <cinttypes> // for uint8_t
15#include <cstddef> // for size_t
16#include <vector> // for vector
17
18namespace xgboost {
19struct TreeParam;
23class MultiTargetTree : public Model {
24 public:
25 static bst_node_t constexpr InvalidNodeId() { return -1; }
26
27 private:
28 TreeParam const* param_;
29 std::vector<bst_node_t> left_;
30 std::vector<bst_node_t> right_;
31 std::vector<bst_node_t> parent_;
32 std::vector<bst_feature_t> split_index_;
33 std::vector<std::uint8_t> default_left_;
34 std::vector<float> split_conds_;
35 std::vector<float> weights_;
36
37 [[nodiscard]] linalg::VectorView<float const> NodeWeight(bst_node_t nidx) const {
38 auto beg = nidx * this->NumTarget();
39 auto v = common::Span<float const>{weights_}.subspan(beg, this->NumTarget());
40 return linalg::MakeTensorView(Context::kCpuId, v, v.size());
41 }
42 [[nodiscard]] linalg::VectorView<float> NodeWeight(bst_node_t nidx) {
43 auto beg = nidx * this->NumTarget();
44 auto v = common::Span<float>{weights_}.subspan(beg, this->NumTarget());
45 return linalg::MakeTensorView(Context::kCpuId, v, v.size());
46 }
47
48 public:
49 explicit MultiTargetTree(TreeParam const* param);
57 void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left,
61
62 [[nodiscard]] bool IsLeaf(bst_node_t nidx) const { return left_[nidx] == InvalidNodeId(); }
63 [[nodiscard]] bst_node_t Parent(bst_node_t nidx) const { return parent_.at(nidx); }
64 [[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const { return left_.at(nidx); }
65 [[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const { return right_.at(nidx); }
66
67 [[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const { return split_index_[nidx]; }
68 [[nodiscard]] float SplitCond(bst_node_t nidx) const { return split_conds_[nidx]; }
69 [[nodiscard]] bool DefaultLeft(bst_node_t nidx) const { return default_left_[nidx]; }
70 [[nodiscard]] bst_node_t DefaultChild(bst_node_t nidx) const {
71 return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx);
72 }
73
74 [[nodiscard]] bst_target_t NumTarget() const;
75
76 [[nodiscard]] std::size_t Size() const;
77
78 [[nodiscard]] bst_node_t Depth(bst_node_t nidx) const {
79 bst_node_t depth{0};
80 while (Parent(nidx) != InvalidNodeId()) {
81 ++depth;
82 nidx = Parent(nidx);
83 }
84 return depth;
85 }
86
87 [[nodiscard]] linalg::VectorView<float const> LeafValue(bst_node_t nidx) const {
88 CHECK(IsLeaf(nidx));
89 return this->NodeWeight(nidx);
90 }
91
92 void LoadModel(Json const& in) override;
93 void SaveModel(Json* out) const override;
94};
95} // namespace xgboost
96#endif // XGBOOST_MULTI_TARGET_TREE_MODEL_H_
Data structure representing JSON format.
Definition json.h:357
Tree structure for multi-target model.
Definition multi_target_tree_model.h:23
void SaveModel(Json *out) const override
Saves the model config to a JSON object.
Definition multi_target_tree_model.cc:98
void SetLeaf(bst_node_t nidx, linalg::VectorView< float const > weight)
Set the weight for a leaf.
Definition multi_target_tree_model.cc:156
void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left, linalg::VectorView< float const > base_weight, linalg::VectorView< float const > left_weight, linalg::VectorView< float const > right_weight)
Expand a leaf into split node.
Definition multi_target_tree_model.cc:167
void LoadModel(Json const &in) override
Loads the model from a JSON object.
Definition multi_target_tree_model.cc:78
Span class implementation, based on the ISO C++20 span<T>. The interface should be the same.
Definition span.h:424
XGBOOST_DEVICE auto subspan() const -> Span< element_type, detail::ExtentValue< Extent, Offset, Count >::value >
Definition span.h:596
A tensor view with static type and dimension.
Definition linalg.h:293
Copyright 2014-2023, XGBoost Contributors.
Copyright 2015-2023 by XGBoost Contributors.
Copyright 2021-2023 by XGBoost Contributors.
Defines the abstract interface for different components in XGBoost.
auto MakeTensorView(Context const *ctx, Container &data, S &&...shape)
Constructor for automatic type deduction.
Definition linalg.h:576
namespace of xgboost
Definition base.h:90
uint32_t bst_feature_t
Type for data column (feature) index.
Definition base.h:101
std::int32_t bst_node_t
Type for tree node index.
Definition base.h:112
std::uint32_t bst_target_t
Type for indexing into output targets.
Definition base.h:118
Definition model.h:17
meta parameters of the tree
Definition tree_model.h:35