Medial Code Documentation
rf.hpp
#ifndef LIGHTGBM_BOOSTING_RF_H_
#define LIGHTGBM_BOOSTING_RF_H_

#include <LightGBM/boosting.h>
#include <LightGBM/metric.h>
#include "score_updater.hpp"
#include "gbdt.h"

#include <cstdio>
#include <vector>
#include <string>
#include <fstream>
namespace LightGBM {
/*!
* \brief Random Forest implementation
*/
class RF : public GBDT {
 public:
  RF() : GBDT() {
    average_output_ = true;
  }

  ~RF() {}

  void Init(const Config* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
            const std::vector<const Metric*>& training_metrics) override {
    CHECK(config->bagging_freq > 0 && config->bagging_fraction < 1.0f && config->bagging_fraction > 0.0f);
    CHECK(config->feature_fraction <= 1.0f && config->feature_fraction > 0.0f);
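    // The two CHECKs above encode the RF requirements: every tree must be
    // trained on a bagged subsample (bagging_freq > 0 and bagging_fraction
    // strictly inside (0, 1)); feature subsampling is optional.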
    GBDT::Init(config, train_data, objective_function, training_metrics);

    if (num_init_iteration_ > 0) {
      for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
        MultiplyScore(cur_tree_id, 1.0f / num_init_iteration_);
      }
    } else {
      CHECK(train_data->metadata().init_score() == nullptr);
    }

    // no shrinkage rate for RF
    shrinkage_rate_ = 1.0f;
    // only boost once
    Boosting();
    if (is_use_subset_ && bag_data_cnt_ < num_data_) {
      tmp_grad_.resize(num_data_);
      tmp_hess_.resize(num_data_);
    }
  }

  void ResetConfig(const Config* config) override {
    CHECK(config->bagging_freq > 0 && config->bagging_fraction < 1.0f && config->bagging_fraction > 0.0f);
    CHECK(config->feature_fraction <= 1.0f && config->feature_fraction > 0.0f);
    GBDT::ResetConfig(config);
    // no shrinkage rate for RF
    shrinkage_rate_ = 1.0f;
  }

  void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
                         const std::vector<const Metric*>& training_metrics) override {
    GBDT::ResetTrainingData(train_data, objective_function, training_metrics);
    if (iter_ + num_init_iteration_ > 0) {
      for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
        train_score_updater_->MultiplyScore(1.0f / (iter_ + num_init_iteration_), cur_tree_id);
      }
    }

    // only boost once
    Boosting();
    if (is_use_subset_ && bag_data_cnt_ < num_data_) {
      tmp_grad_.resize(num_data_);
      tmp_hess_.resize(num_data_);
    }
  }

  void Boosting() override {
    if (objective_function_ == nullptr) {
      Log::Fatal("No objective function provided");
    }
    init_scores_.resize(num_tree_per_iteration_, 0.0);
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
      init_scores_[cur_tree_id] = BoostFromAverage(cur_tree_id, false);
    }
    size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
    std::vector<double> tmp_scores(total_size, 0.0);
    #pragma omp parallel for schedule(static)
    for (int j = 0; j < num_tree_per_iteration_; ++j) {
      size_t bias = static_cast<size_t>(j) * num_data_;
      for (data_size_t i = 0; i < num_data_; ++i) {
        tmp_scores[bias + i] = init_scores_[j];
      }
    }
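    // Unlike gradient boosting, RF computes gradients and hessians only once,
    // at the constant per-class init scores; every tree of every iteration is
    // then fit to these same targets.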
    GetGradients(tmp_scores.data(), gradients_.data(), hessians_.data());
  }

  bool TrainOneIter(const score_t* gradients, const score_t* hessians) override {
    // bagging logic
    Bagging(iter_);
    CHECK(gradients == nullptr);
    CHECK(hessians == nullptr);

    gradients = gradients_.data();
    hessians = hessians_.data();
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
      std::unique_ptr<Tree> new_tree(new Tree(2));
      size_t bias = static_cast<size_t>(cur_tree_id) * num_data_;
      if (class_need_train_[cur_tree_id]) {
        auto grad = gradients + bias;
        auto hess = hessians + bias;

        // need to copy gradients for the bagging subset
        if (is_use_subset_ && bag_data_cnt_ < num_data_) {
          for (int i = 0; i < bag_data_cnt_; ++i) {
            tmp_grad_[i] = grad[bag_data_indices_[i]];
            tmp_hess_[i] = hess[bag_data_indices_[i]];
          }
          grad = tmp_grad_.data();
          hess = tmp_hess_.data();
        }

        new_tree.reset(tree_learner_->Train(grad, hess, is_constant_hessian_,
                                            forced_splits_json_));
      }

      if (new_tree->num_leaves() > 1) {
        tree_learner_->RenewTreeOutput(new_tree.get(), objective_function_, init_scores_[cur_tree_id],
                                       num_data_, bag_data_indices_.data(), bag_data_cnt_);
        if (std::fabs(init_scores_[cur_tree_id]) > kEpsilon) {
          new_tree->AddBias(init_scores_[cur_tree_id]);
        }
        // update score
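        // The score updaters hold the running mean of all trees trained so
        // far. With n = iter_ + num_init_iteration_ trees averaged, multiplying
        // by n restores the sum, UpdateScore adds the new tree's output, and
        // multiplying by 1 / (n + 1) re-normalizes to the mean of n + 1 trees.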
        MultiplyScore(cur_tree_id, (iter_ + num_init_iteration_));
        UpdateScore(new_tree.get(), cur_tree_id);
        MultiplyScore(cur_tree_id, 1.0 / (iter_ + num_init_iteration_ + 1));
      } else {
        // only add the default score once
        if (models_.size() < static_cast<size_t>(num_tree_per_iteration_)) {
          double output = 0.0;
          if (!class_need_train_[cur_tree_id]) {
            if (objective_function_ != nullptr) {
              output = objective_function_->BoostFromScore(cur_tree_id);
            } else {
              output = init_scores_[cur_tree_id];
            }
          }
          new_tree->AsConstantTree(output);
          MultiplyScore(cur_tree_id, (iter_ + num_init_iteration_));
          UpdateScore(new_tree.get(), cur_tree_id);
          MultiplyScore(cur_tree_id, 1.0 / (iter_ + num_init_iteration_ + 1));
        }
      }
      // add model
      models_.push_back(std::move(new_tree));
    }
    ++iter_;
    return false;
  }

  void RollbackOneIter() override {
    if (iter_ <= 0) { return; }
    int cur_iter = iter_ + num_init_iteration_ - 1;
    // reset score
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
      auto curr_tree = cur_iter * num_tree_per_iteration_ + cur_tree_id;
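      // Shrinkage(-1.0) flips the sign of the tree's leaf outputs, so the
      // AddScore calls below subtract this tree's contribution; the scores are
      // then re-normalized to the mean of the remaining trees.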
      models_[curr_tree]->Shrinkage(-1.0);
      MultiplyScore(cur_tree_id, (iter_ + num_init_iteration_));
      train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
      for (auto& score_updater : valid_score_updater_) {
        score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
      }
      MultiplyScore(cur_tree_id, 1.0f / (iter_ + num_init_iteration_ - 1));
    }
    // remove model
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
      models_.pop_back();
    }
    --iter_;
  }

  void MultiplyScore(const int cur_tree_id, double val) {
    train_score_updater_->MultiplyScore(val, cur_tree_id);
    for (auto& score_updater : valid_score_updater_) {
      score_updater->MultiplyScore(val, cur_tree_id);
    }
  }

  void AddValidDataset(const Dataset* valid_data,
                       const std::vector<const Metric*>& valid_metrics) override {
    GBDT::AddValidDataset(valid_data, valid_metrics);
    if (iter_ + num_init_iteration_ > 0) {
      for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
        valid_score_updater_.back()->MultiplyScore(1.0f / (iter_ + num_init_iteration_), cur_tree_id);
      }
    }
  }

  bool NeedAccuratePrediction() const override {
    // No early stopping for prediction
    return true;
  }

 private:
  std::vector<score_t> tmp_grad_;
  std::vector<score_t> tmp_hess_;
  std::vector<double> init_scores_;
};

}  // namespace LightGBM
#endif  // LIGHTGBM_BOOSTING_RF_H_
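The multiply / add / multiply pattern used in TrainOneIter and RollbackOneIter keeps the training and validation scores equal to the arithmetic mean of all trees grown so far, which is what average_output_ = true promises. The following minimal, self-contained sketch (not part of rf.hpp) replays that recurrence with plain doubles standing in for the per-datum scores managed by ScoreUpdater:

#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Per-datum outputs of four hypothetical trees.
  std::vector<double> tree_outputs = {0.5, -1.0, 2.0, 0.25};
  double score = 0.0;  // running average, like train_score_updater_
  int iter = 0;        // plays the role of iter_ + num_init_iteration_
  for (double out : tree_outputs) {
    score *= iter;              // MultiplyScore(cur_tree_id, iter)
    score += out;               // UpdateScore(new_tree, cur_tree_id)
    score *= 1.0 / (iter + 1);  // MultiplyScore(cur_tree_id, 1.0 / (iter + 1))
    ++iter;
  }
  // After n iterations the score equals the mean of the n tree outputs.
  double mean = (0.5 - 1.0 + 2.0 + 0.25) / 4.0;
  assert(std::fabs(score - mean) < 1e-12);
  std::printf("running average = %f\n", score);
  return 0;
}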
The main class of dataset, which is used for training or validation.
Definition dataset.h:278
GBDT algorithm implementation, including training, prediction, and bagging.
Definition gbdt.h:26
std::vector< score_t > hessians_
Second order derivatives of the training data.
Definition gbdt.h:445
virtual void Bagging(int iter)
Implement bagging logic.
Definition gbdt.cpp:180
void AddValidDataset(const Dataset *valid_data, const std::vector< const Metric * > &valid_metrics) override
Adding a validation dataset.
Definition gbdt.cpp:117
std::vector< std::unique_ptr< Tree > > models_
Trained models (trees).
Definition gbdt.h:439
data_size_t num_data_
Number of training data.
Definition gbdt.h:453
std::unique_ptr< TreeLearner > tree_learner_
Tree learner; this class is used to learn trees.
Definition gbdt.h:419
virtual void UpdateScore(const Tree *tree, const int cur_tree_id)
Update the score after a tree is trained.
Definition gbdt.cpp:451
void Init(const Config *gbdt_config, const Dataset *train_data, const ObjectiveFunction *objective_function, const std::vector< const Metric * > &training_metrics) override
Initialization logic.
Definition gbdt.cpp:45
void ResetTrainingData(const Dataset *train_data, const ObjectiveFunction *objective_function, const std::vector< const Metric * > &training_metrics) override
Reset the training data.
Definition gbdt.cpp:622
const ObjectiveFunction * objective_function_
Objective function.
Definition gbdt.h:421
int num_class_
Number of classes.
Definition gbdt.h:457
void ResetConfig(const Config *gbdt_config) override
Reset Boosting Config.
Definition gbdt.cpp:676
std::vector< data_size_t > bag_data_indices_
Store the indices of in-bag data.
Definition gbdt.h:447
std::vector< std::unique_ptr< ScoreUpdater > > valid_score_updater_
Store and update validation data's scores.
Definition gbdt.h:427
data_size_t bag_data_cnt_
Number of in-bag data.
Definition gbdt.h:449
int num_tree_per_iteration_
Number of trees per iteration.
Definition gbdt.h:455
int iter_
Current iteration.
Definition gbdt.h:413
GBDT()
Constructor.
Definition gbdt.cpp:22
std::unique_ptr< ScoreUpdater > train_score_updater_
Store and update training data's score.
Definition gbdt.h:423
double shrinkage_rate_
Shrinkage rate for one iteration.
Definition gbdt.h:463
std::vector< score_t > gradients_
First order derivatives of the training data.
Definition gbdt.h:443
int num_init_iteration_
Number of loaded initial models.
Definition gbdt.h:465
The interface of Objective Function.
Definition objective_function.h:13
Random Forest implementation.
Definition rf.hpp:18
void ResetTrainingData(const Dataset *train_data, const ObjectiveFunction *objective_function, const std::vector< const Metric * > &training_metrics) override
Reset the training data.
Definition rf.hpp:58
void ResetConfig(const Config *config) override
Reset Boosting Config.
Definition rf.hpp:50
void AddValidDataset(const Dataset *valid_data, const std::vector< const Metric * > &valid_metrics) override
Adding a validation dataset.
Definition rf.hpp:187
void Boosting() override
Calculate the objective function.
Definition rf.hpp:75
bool NeedAccuratePrediction() const override
Whether early stopping can be used during prediction.
Definition rf.hpp:197
void RollbackOneIter() override
Rollback one iteration.
Definition rf.hpp:159
void Init(const Config *config, const Dataset *train_data, const ObjectiveFunction *objective_function, const std::vector< const Metric * > &training_metrics) override
Initialization logic.
Definition rf.hpp:26
bool TrainOneIter(const score_t *gradients, const score_t *hessians) override
Training logic.
Definition rf.hpp:96
Tree model.
Definition tree.h:20
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10
float score_t
Type of score and gradients.
Definition meta.h:26
int32_t data_size_t
Type of data size; it is better to use a signed type.
Definition meta.h:14
Definition config.h:27
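
As a usage note, this RF class is selected through LightGBM's ordinary boosting configuration; the CHECKs in RF::Init require bagging to be enabled. The sketch below uses LightGBM's public C API with parameters that satisfy those checks; the file name train.txt and the particular fraction values are illustrative placeholders, not values taken from this source:

#include <cstdio>
#include <LightGBM/c_api.h>

int main() {
  DatasetHandle train = nullptr;
  // "train.txt" is a placeholder path to a training data file.
  if (LGBM_DatasetCreateFromFile("train.txt", "", nullptr, &train) != 0) {
    std::fprintf(stderr, "failed to load dataset\n");
    return 1;
  }
  BoosterHandle booster = nullptr;
  // boosting=rf selects the RF class; bagging_freq > 0 and
  // 0 < bagging_fraction < 1 satisfy the CHECKs in RF::Init.
  const char* params =
      "boosting=rf objective=binary "
      "bagging_freq=1 bagging_fraction=0.632 feature_fraction=0.8";
  if (LGBM_BoosterCreate(train, params, &booster) != 0) {
    std::fprintf(stderr, "failed to create booster\n");
    return 1;
  }
  int is_finished = 0;
  for (int i = 0; i < 100; ++i) {  // each call grows num_tree_per_iteration_ trees
    LGBM_BoosterUpdateOneIter(booster, &is_finished);
  }
  LGBM_BoosterFree(booster);
  LGBM_DatasetFree(train);
  return 0;
}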