1#ifndef LIGHTGBM_BOOSTING_RF_H_
2#define LIGHTGBM_BOOSTING_RF_H_
4#include <LightGBM/boosting.h>
5#include <LightGBM/metric.h>
6#include "score_updater.hpp"
21 average_output_ =
true;
27 const std::vector<const Metric*>& training_metrics)
override {
28 CHECK(config->bagging_freq > 0 && config->bagging_fraction < 1.0f && config->bagging_fraction > 0.0f);
29 CHECK(config->feature_fraction <= 1.0f && config->feature_fraction > 0.0f);
30 GBDT::Init(config, train_data, objective_function, training_metrics);
37 CHECK(train_data->metadata().init_score() ==
nullptr);
51 CHECK(config->bagging_freq > 0 && config->bagging_fraction < 1.0f && config->bagging_fraction > 0.0f);
52 CHECK(config->feature_fraction <= 1.0f && config->feature_fraction > 0.0f);
59 const std::vector<const Metric*>& training_metrics)
override {
77 Log::Fatal(
"No object function provided");
81 init_scores_[cur_tree_id] = BoostFromAverage(cur_tree_id,
false);
84 std::vector<double> tmp_scores(total_size, 0.0f);
85 #pragma omp parallel for schedule(static)
87 size_t bias =
static_cast<size_t>(j)*
num_data_;
89 tmp_scores[bias + i] = init_scores_[j];
99 CHECK(gradients ==
nullptr);
100 CHECK(hessians ==
nullptr);
105 std::unique_ptr<Tree> new_tree(
new Tree(2));
106 size_t bias =
static_cast<size_t>(cur_tree_id)*
num_data_;
107 if (class_need_train_[cur_tree_id]) {
108 auto grad = gradients + bias;
109 auto hess = hessians + bias;
117 grad = tmp_grad_.data();
118 hess = tmp_hess_.data();
121 new_tree.reset(
tree_learner_->Train(grad, hess, is_constant_hessian_,
122 forced_splits_json_));
125 if (new_tree->num_leaves() > 1) {
128 if (std::fabs(init_scores_[cur_tree_id]) > kEpsilon) {
129 new_tree->AddBias(init_scores_[cur_tree_id]);
139 if (!class_need_train_[cur_tree_id]) {
143 output = init_scores_[cur_tree_id];
146 new_tree->AsConstantTree(output);
153 models_.push_back(std::move(new_tree));
160 if (
iter_ <= 0) {
return; }
165 models_[curr_tree]->Shrinkage(-1.0);
169 score_updater->AddScore(
models_[curr_tree].get(), cur_tree_id);
180 void MultiplyScore(
const int cur_tree_id,
double val) {
183 score_updater->MultiplyScore(val, cur_tree_id);
188 const std::vector<const Metric*>& valid_metrics)
override {
203 std::vector<score_t> tmp_grad_;
204 std::vector<score_t> tmp_hess_;
205 std::vector<double> init_scores_;
The main class of the data set, which is used for training or validation.
Definition dataset.h:278
GBDT algorithm implementation. including Training, prediction, bagging.
Definition gbdt.h:26
std::vector< score_t > hessians_
Second order derivative of training data.
Definition gbdt.h:445
virtual void Bagging(int iter)
Implement bagging logic.
Definition gbdt.cpp:180
void AddValidDataset(const Dataset *valid_data, const std::vector< const Metric * > &valid_metrics) override
Adding a validation dataset.
Definition gbdt.cpp:117
std::vector< std::unique_ptr< Tree > > models_
Trained models(trees)
Definition gbdt.h:439
data_size_t num_data_
Number of training data.
Definition gbdt.h:453
std::unique_ptr< TreeLearner > tree_learner_
Tree learner, will use this class to learn trees.
Definition gbdt.h:419
virtual void UpdateScore(const Tree *tree, const int cur_tree_id)
Update the score after a tree has been trained.
Definition gbdt.cpp:451
void Init(const Config *gbdt_config, const Dataset *train_data, const ObjectiveFunction *objective_function, const std::vector< const Metric * > &training_metrics) override
Initialization logic.
Definition gbdt.cpp:45
void ResetTrainingData(const Dataset *train_data, const ObjectiveFunction *objective_function, const std::vector< const Metric * > &training_metrics) override
Reset the training data.
Definition gbdt.cpp:622
const ObjectiveFunction * objective_function_
Objective function.
Definition gbdt.h:421
int num_class_
Number of classes.
Definition gbdt.h:457
void ResetConfig(const Config *gbdt_config) override
Reset Boosting Config.
Definition gbdt.cpp:676
std::vector< data_size_t > bag_data_indices_
Store the indices of in-bag data.
Definition gbdt.h:447
std::vector< std::unique_ptr< ScoreUpdater > > valid_score_updater_
Store and update validation data's scores.
Definition gbdt.h:427
data_size_t bag_data_cnt_
Number of in-bag data.
Definition gbdt.h:449
int num_tree_per_iteration_
Number of trees per iteration.
Definition gbdt.h:455
int iter_
current iteration
Definition gbdt.h:413
GBDT()
Constructor.
Definition gbdt.cpp:22
std::unique_ptr< ScoreUpdater > train_score_updater_
Store and update training data's score.
Definition gbdt.h:423
double shrinkage_rate_
Shrinkage rate for one iteration.
Definition gbdt.h:463
std::vector< score_t > gradients_
First order derivative of training data.
Definition gbdt.h:443
int num_init_iteration_
Number of loaded initial models.
Definition gbdt.h:465
The interface of Objective Function.
Definition objective_function.h:13
Random Forest implementation.
Definition rf.hpp:18
void ResetTrainingData(const Dataset *train_data, const ObjectiveFunction *objective_function, const std::vector< const Metric * > &training_metrics) override
Reset the training data.
Definition rf.hpp:58
void ResetConfig(const Config *config) override
Reset Boosting Config.
Definition rf.hpp:50
void AddValidDataset(const Dataset *valid_data, const std::vector< const Metric * > &valid_metrics) override
Adding a validation dataset.
Definition rf.hpp:187
void Boosting() override
Calculate the objective function.
Definition rf.hpp:75
bool NeedAccuratePrediction() const override
Can use early stopping for prediction or not.
Definition rf.hpp:197
void RollbackOneIter() override
Rollback one iteration.
Definition rf.hpp:159
void Init(const Config *config, const Dataset *train_data, const ObjectiveFunction *objective_function, const std::vector< const Metric * > &training_metrics) override
Initialization logic.
Definition rf.hpp:26
bool TrainOneIter(const score_t *gradients, const score_t *hessians) override
Training logic.
Definition rf.hpp:96
Tree model.
Definition tree.h:20
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10
float score_t
Type of score, and gradients.
Definition meta.h:26
int32_t data_size_t
Type of data size; it is better to use a signed type.
Definition meta.h:14