Medial Code Documentation
Loading...
Searching...
No Matches
gbtree_model.h
Go to the documentation of this file.
1
5#ifndef XGBOOST_GBM_GBTREE_MODEL_H_
6#define XGBOOST_GBM_GBTREE_MODEL_H_
7
8#include <dmlc/io.h>
9#include <dmlc/parameter.h>
10#include <xgboost/context.h>
11#include <xgboost/learner.h>
12#include <xgboost/model.h>
13#include <xgboost/parameter.h>
14#include <xgboost/tree_model.h>
15
16#include <memory>
17#include <string>
18#include <utility>
19#include <vector>
20
21#include "../common/threading_utils.h"
22
23namespace xgboost {
24
25class Json;
26
27namespace gbm {
31using TreesOneGroup = std::vector<std::unique_ptr<RegTree>>;
35using TreesOneIter = std::vector<TreesOneGroup>;
36
38struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
39 public:
43 std::int32_t num_trees;
47 std::int32_t num_parallel_tree;
49 int32_t reserved[38];
50
53 std::memset(this, 0, sizeof(GBTreeModelParam)); // FIXME(trivialfis): Why?
54 static_assert(sizeof(GBTreeModelParam) == (4 + 2 + 2 + 32) * sizeof(int32_t),
55 "64/32 bit compatibility issue");
57 }
58
59 // declare parameters, only declare those that need to be set.
60 DMLC_DECLARE_PARAMETER(GBTreeModelParam) {
61 DMLC_DECLARE_FIELD(num_trees)
62 .set_lower_bound(0)
63 .set_default(0)
64 .describe("Number of features used for training and prediction.");
65 DMLC_DECLARE_FIELD(num_parallel_tree)
66 .set_default(1)
67 .set_lower_bound(1)
68 .describe(
69 "Number of parallel trees constructed during each iteration."
70 " This option is used to support boosted random forest.");
71 }
72
73 // Swap byte order for all fields. Useful for transporting models between machines with different
74 // endianness (big endian vs little endian)
75 GBTreeModelParam ByteSwap() const {
76 GBTreeModelParam x = *this;
77 dmlc::ByteSwap(&x.num_trees, sizeof(x.num_trees), 1);
78 dmlc::ByteSwap(&x.num_parallel_tree, sizeof(x.num_parallel_tree), 1);
79 dmlc::ByteSwap(x.reserved, sizeof(x.reserved[0]), sizeof(x.reserved) / sizeof(x.reserved[0]));
80 return x;
81 }
82};
83
84struct GBTreeModel : public Model {
85 public:
86 explicit GBTreeModel(LearnerModelParam const* learner_model, Context const* ctx)
87 : learner_model_param{learner_model}, ctx_{ctx} {}
88 void Configure(const Args& cfg) {
89 // initialize model parameters if not yet been initialized.
90 if (trees.size() == 0) {
91 param.UpdateAllowUnknown(cfg);
92 }
93 }
94
95 void InitTreesToUpdate() {
96 if (trees_to_update.size() == 0u) {
97 for (auto & tree : trees) {
98 trees_to_update.push_back(std::move(tree));
99 }
100 trees.clear();
101 param.num_trees = 0;
102 tree_info.clear();
103
104 iteration_indptr.clear();
105 iteration_indptr.push_back(0);
106 }
107 }
108
109 void Load(dmlc::Stream* fi);
110 void Save(dmlc::Stream* fo) const;
111
112 void SaveModel(Json* p_out) const override;
113 void LoadModel(Json const& p_out) override;
114
115 [[nodiscard]] std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,
116 int32_t n_threads, std::string format) const {
117 std::vector<std::string> dump(trees.size());
118 common::ParallelFor(trees.size(), n_threads,
119 [&](size_t i) { dump[i] = trees[i]->DumpModel(fmap, with_stats, format); });
120 return dump;
121 }
128
129 void CommitModelGroup(std::vector<std::unique_ptr<RegTree>>&& new_trees, bst_target_t group_idx) {
130 for (auto& new_tree : new_trees) {
131 trees.push_back(std::move(new_tree));
132 tree_info.push_back(group_idx);
133 }
134 param.num_trees += static_cast<int>(new_trees.size());
135 }
136
137 [[nodiscard]] std::int32_t BoostedRounds() const {
138 if (trees.empty()) {
139 CHECK_EQ(iteration_indptr.size(), 1);
140 }
141 return static_cast<std::int32_t>(iteration_indptr.size() - 1);
142 }
143
144 // base margin
145 LearnerModelParam const* learner_model_param;
146 // model parameter
147 GBTreeModelParam param;
149 std::vector<std::unique_ptr<RegTree> > trees;
151 std::vector<std::unique_ptr<RegTree> > trees_to_update;
155 std::vector<int> tree_info;
159 std::vector<bst_tree_t> iteration_indptr{0};
160
161 private:
165 Context const* ctx_;
166};
167} // namespace gbm
168} // namespace xgboost
169
170#endif // XGBOOST_GBM_GBTREE_MODEL_H_
interface of stream I/O for serialization
Definition io.h:30
Feature map data structure to help text model dump. TODO(tqchen): consider making it even more lightweig...
Definition feature_map.h:22
Data structure representing JSON format.
Definition json.h:357
Copyright 2014-2023, XGBoost Contributors.
defines serializable interface of dmlc
Provide lightweight util to do parameter setup and checking.
macro for using C++11 enum class as DMLC parameter
Copyright 2015-2023 by XGBoost Contributors.
Defines the abstract interface for different components in XGBoost.
void ByteSwap(void *data, size_t elem_bytes, size_t num_elems)
A generic inplace byte swapping function.
Definition endian.h:51
std::vector< TreesOneGroup > TreesOneIter
Container for all trees built (not update) for one iteration.
Definition gbtree_model.h:35
std::vector< std::unique_ptr< RegTree > > TreesOneGroup
Container for all trees built (not update) for one group.
Definition gbtree_model.h:31
namespace of xgboost
Definition base.h:90
std::int32_t bst_tree_t
Type for indexing trees.
Definition base.h:126
std::uint32_t bst_target_t
Type for indexing into output targets.
Definition base.h:118
Runtime context for XGBoost.
Definition context.h:84
Basic model parameters, used to describe the booster.
Definition learner.h:291
Definition model.h:17
model parameters
Definition gbtree_model.h:38
std::int32_t num_parallel_tree
Number of trees for a forest.
Definition gbtree_model.h:47
std::int32_t num_trees
number of trees
Definition gbtree_model.h:43
GBTreeModelParam()
constructor
Definition gbtree_model.h:52
int32_t reserved[38]
reserved parameters
Definition gbtree_model.h:49
Definition gbtree_model.h:84
bst_tree_t CommitModel(TreesOneIter &&new_trees)
Add trees to the model.
Definition gbtree_model.cc:180
std::vector< int > tree_info
Group index for trees.
Definition gbtree_model.h:155
void LoadModel(Json const &p_out) override
load the model from a JSON object
Definition gbtree_model.cc:138
std::vector< bst_tree_t > iteration_indptr
Number of trees accumulated for each iteration.
Definition gbtree_model.h:159
std::vector< std::unique_ptr< RegTree > > trees_to_update
for the update process, a place to keep the initial trees
Definition gbtree_model.h:151
void SaveModel(Json *p_out) const override
saves the model config to a JSON object
Definition gbtree_model.cc:109
std::vector< std::unique_ptr< RegTree > > trees
vector of trees stored in the model
Definition gbtree_model.h:149
Copyright 2014-2023 by Contributors.