Medial Code Documentation
Loading...
Searching...
No Matches
gbdt.h
1#ifndef LIGHTGBM_BOOSTING_GBDT_H_
2#define LIGHTGBM_BOOSTING_GBDT_H_
3
4#include <LightGBM/boosting.h>
5#include <LightGBM/objective_function.h>
6#include <LightGBM/prediction_early_stop.h>
7#include <LightGBM/json11.hpp>
8
9#include "score_updater.hpp"
10
11#include <cstdio>
12#include <vector>
13#include <string>
14#include <fstream>
15#include <memory>
16#include <mutex>
17#include <map>
18
19using namespace json11;
20
21namespace LightGBM {
22
26class GBDT : public GBDTBase {
27public:
31 GBDT();
32
36 ~GBDT();
37
45 void Init(const Config* gbdt_config, const Dataset* train_data,
46 const ObjectiveFunction* objective_function,
47 const std::vector<const Metric*>& training_metrics) override;
48
53 void MergeFrom(const Boosting* other) override {
54 auto other_gbdt = reinterpret_cast<const GBDT*>(other);
55 // tmp move to other vector
56 auto original_models = std::move(models_);
57 models_ = std::vector<std::unique_ptr<Tree>>();
58 // push model from other first
59 for (const auto& tree : other_gbdt->models_) {
60 auto new_tree = std::unique_ptr<Tree>(new Tree(*(tree.get())));
61 models_.push_back(std::move(new_tree));
62 }
63 num_init_iteration_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
64 // push model in current object
65 for (const auto& tree : original_models) {
66 auto new_tree = std::unique_ptr<Tree>(new Tree(*(tree.get())));
67 models_.push_back(std::move(new_tree));
68 }
69 num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
70 }
71
72 void ShuffleModels(int start_iter, int end_iter) override {
73 int total_iter = static_cast<int>(models_.size()) / num_tree_per_iteration_;
74 start_iter = std::max(0, start_iter);
75 if (end_iter <= 0) {
76 end_iter = total_iter;
77 }
78 end_iter = std::min(total_iter, end_iter);
79 auto original_models = std::move(models_);
80 std::vector<int> indices(total_iter);
81 for (int i = 0; i < total_iter; ++i) {
82 indices[i] = i;
83 }
84 Random tmp_rand(17);
85 for (int i = start_iter; i < end_iter - 1; ++i) {
86 int j = tmp_rand.NextShort(i + 1, end_iter);
87 std::swap(indices[i], indices[j]);
88 }
89 models_ = std::vector<std::unique_ptr<Tree>>();
90 for (int i = 0; i < total_iter; ++i) {
91 for (int j = 0; j < num_tree_per_iteration_; ++j) {
92 int tree_idx = indices[i] * num_tree_per_iteration_ + j;
93 auto new_tree = std::unique_ptr<Tree>(new Tree(*(original_models[tree_idx].get())));
94 models_.push_back(std::move(new_tree));
95 }
96 }
97 }
98
105 void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
106 const std::vector<const Metric*>& training_metrics) override;
107
112 void ResetConfig(const Config* gbdt_config) override;
113
119 void AddValidDataset(const Dataset* valid_data,
120 const std::vector<const Metric*>& valid_metrics) override;
121
127 void Train(int snapshot_freq, const std::string& model_output_path) override;
128
129 void RefitTree(const std::vector<std::vector<int>>& tree_leaf_prediction) override;
130
137 virtual bool TrainOneIter(const score_t* gradients, const score_t* hessians) override;
138
142 void RollbackOneIter() override;
143
147 int GetCurrentIteration() const override { return static_cast<int>(models_.size()) / num_tree_per_iteration_; }
148
153 bool NeedAccuratePrediction() const override {
154 if (objective_function_ == nullptr) {
155 return true;
156 } else {
158 }
159 }
160
166 std::vector<double> GetEvalAt(int data_idx) const override;
167
173 virtual const double* GetTrainingScore(int64_t* out_len) override;
174
180 virtual int64_t GetNumPredictAt(int data_idx) const override {
181 CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));
182 data_size_t num_data = train_data_->num_data();
183 if (data_idx > 0) {
184 num_data = valid_score_updater_[data_idx - 1]->num_data();
185 }
186 return num_data * num_class_;
187 }
188
195 void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) override;
196
204 inline int NumPredictOneRow(int num_iteration, bool is_pred_leaf, bool is_pred_contrib) const override {
205 int num_preb_in_one_row = num_class_;
206 if (is_pred_leaf) {
207 int max_iteration = GetCurrentIteration();
208 if (num_iteration > 0) {
209 num_preb_in_one_row *= static_cast<int>(std::min(max_iteration, num_iteration));
210 } else {
211 num_preb_in_one_row *= max_iteration;
212 }
213 } else if (is_pred_contrib) {
214 num_preb_in_one_row = num_tree_per_iteration_ * (max_feature_idx_ + 2); // +1 for 0-based indexing, +1 for baseline
215 }
216 return num_preb_in_one_row;
217 }
218
219 void PredictRaw(const double* features, double* output,
220 const PredictionEarlyStopInstance* earlyStop) const override;
221
222 void PredictRawByMap(const std::unordered_map<int, double>& features, double* output,
223 const PredictionEarlyStopInstance* early_stop) const override;
224
225 void Predict(const double* features, double* output,
226 const PredictionEarlyStopInstance* earlyStop) const override;
227
228 void PredictByMap(const std::unordered_map<int, double>& features, double* output,
229 const PredictionEarlyStopInstance* early_stop) const override;
230
231 void PredictLeafIndex(const double* features, double* output) const override;
232
233 void PredictLeafIndexByMap(const std::unordered_map<int, double>& features, double* output) const override;
234
235 void PredictContrib(const double* features, double* output,
236 const PredictionEarlyStopInstance* earlyStop) const override;
237
244 std::string DumpModel(int start_iteration, int num_iteration) const override;
245
251 std::string ModelToIfElse(int num_iteration) const override;
252
259 bool SaveModelToIfElse(int num_iteration, const char* filename) const override;
260
268 virtual bool SaveModelToFile(int start_iteration, int num_iterations, const char* filename) const override;
269
276 std::string SaveModelToString(int num_iterations) { return SaveModelToString(0, num_iterations); } // ADDED by Medial to allow backward compatability with MedLightGBM wrapper
277 virtual std::string SaveModelToString(int start_iteration, int num_iterations) const override;
278
282 bool LoadModelFromString(std::string str) { return LoadModelFromString(str.c_str(), str.length()); } // ADDED by Medial for backward compatability
283 bool LoadModelFromString(const char* buffer, size_t len) override;
284
291 std::vector<double> FeatureImportance(int num_iteration, int importance_type) const override;
292
297 inline int MaxFeatureIdx() const override { return max_feature_idx_; }
298
303 inline std::vector<std::string> FeatureNames() const override { return feature_names_; }
304
309 inline int LabelIdx() const override { return label_idx_; }
310
315 inline int NumberOfTotalModel() const override { return static_cast<int>(models_.size()); }
316
321 inline int NumModelPerIteration() const override { return num_tree_per_iteration_; }
322
327 inline int NumberOfClasses() const override { return num_class_; }
328
329 inline void InitPredict(int num_iteration, bool is_pred_contrib) override {
330 num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
331 if (num_iteration > 0) {
332 num_iteration_for_pred_ = std::min(num_iteration, num_iteration_for_pred_);
333 }
334 if (is_pred_contrib) {
335 #pragma omp parallel for schedule(static)
336 for (int i = 0; i < static_cast<int>(models_.size()); ++i) {
337 models_[i]->RecomputeMaxDepth();
338 }
339 }
340 }
341
342 inline double GetLeafValue(int tree_idx, int leaf_idx) const override {
343 CHECK(tree_idx >= 0 && static_cast<size_t>(tree_idx) < models_.size());
344 CHECK(leaf_idx >= 0 && leaf_idx < models_[tree_idx]->num_leaves());
345 return models_[tree_idx]->LeafOutput(leaf_idx);
346 }
347
348 inline void SetLeafValue(int tree_idx, int leaf_idx, double val) override {
349 CHECK(tree_idx >= 0 && static_cast<size_t>(tree_idx) < models_.size());
350 CHECK(leaf_idx >= 0 && leaf_idx < models_[tree_idx]->num_leaves());
351 models_[tree_idx]->SetLeafOutput(leaf_idx, val);
352 }
353
357 virtual const char* SubModelName() const override { return "tree"; }
358
359protected:
363 virtual bool EvalAndCheckEarlyStopping();
364
368 void ResetBaggingConfig(const Config* config, bool is_change_dataset);
369
374 virtual void Bagging(int iter);
375
383 data_size_t BaggingHelper(Random& cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer);
384
388 virtual void Boosting();
389
395 virtual void UpdateScore(const Tree* tree, const int cur_tree_id);
396
401 virtual std::vector<double> EvalOneMetric(const Metric* metric, const double* score) const;
402
408 std::string OutputMetric(int iter);
409
410 double BoostFromAverage(int class_id, bool update_scorer);
411
413 int iter_;
417 std::unique_ptr<Config> config_;
419 std::unique_ptr<TreeLearner> tree_learner_;
423 std::unique_ptr<ScoreUpdater> train_score_updater_;
425 std::vector<const Metric*> training_metrics_;
427 std::vector<std::unique_ptr<ScoreUpdater>> valid_score_updater_;
429 std::vector<std::vector<const Metric*>> valid_metrics_;
433 std::vector<std::vector<int>> best_iter_;
435 std::vector<std::vector<double>> best_score_;
437 std::vector<std::vector<std::string>> best_msg_;
439 std::vector<std::unique_ptr<Tree>> models_;
443 std::vector<score_t> gradients_;
445 std::vector<score_t> hessians_;
447 std::vector<data_size_t> bag_data_indices_;
451 std::vector<data_size_t> tmp_indices_;
467 std::vector<std::string> feature_names_;
468 std::vector<std::string> feature_infos_;
472 std::vector<data_size_t> offsets_buf_;
474 std::vector<data_size_t> left_cnts_buf_;
476 std::vector<data_size_t> right_cnts_buf_;
478 std::vector<data_size_t> left_write_pos_buf_;
480 std::vector<data_size_t> right_write_pos_buf_;
481 std::unique_ptr<Dataset> tmp_subset_;
482 bool is_use_subset_;
483 std::vector<bool> class_need_train_;
484 bool is_constant_hessian_;
485 std::unique_ptr<ObjectiveFunction> loaded_objective_;
486 bool average_output_;
487 bool need_re_bagging_;
488 std::string loaded_parameter_;
489
490 Json forced_splits_json_;
491};
492
493} // namespace LightGBM
494#endif // LIGHTGBM_BOOSTING_GBDT_H_
The interface for Boosting.
Definition boosting.h:22
The main class for a data set, which is used for training or validation.
Definition dataset.h:278
data_size_t num_data() const
Get Number of data.
Definition dataset.h:577
Definition boosting.h:298
GBDT algorithm implementation. including Training, prediction, bagging.
Definition gbdt.h:26
int NumModelPerIteration() const override
Get number of tree per iteration.
Definition gbdt.h:321
std::vector< score_t > hessians_
Second order derivative of training data.
Definition gbdt.h:445
virtual void Bagging(int iter)
Implement bagging logic.
Definition gbdt.cpp:180
void PredictContrib(const double *features, double *output, const PredictionEarlyStopInstance *earlyStop) const override
Feature contributions for the model's prediction of one record.
Definition gbdt.cpp:564
void AddValidDataset(const Dataset *valid_data, const std::vector< const Metric * > &valid_metrics) override
Adding a validation dataset.
Definition gbdt.cpp:117
int num_iteration_for_pred_
number of used model
Definition gbdt.h:461
std::vector< const Metric * > training_metrics_
Metrics for training data.
Definition gbdt.h:425
int LabelIdx() const override
Get index of label column.
Definition gbdt.h:309
void InitPredict(int num_iteration, bool is_pred_contrib) override
Initial work for the prediction.
Definition gbdt.h:329
std::vector< std::unique_ptr< Tree > > models_
Trained models(trees)
Definition gbdt.h:439
data_size_t BaggingHelper(Random &cur_rand, data_size_t start, data_size_t cnt, data_size_t *buffer)
Helper function for bagging, used for multi-threading optimization.
Definition gbdt.cpp:159
void ResetBaggingConfig(const Config *config, bool is_change_dataset)
reset config for bagging
Definition gbdt.cpp:689
int GetCurrentIteration() const override
Get current iteration.
Definition gbdt.h:147
data_size_t num_data_
Number of training data.
Definition gbdt.h:453
std::string DumpModel(int start_iteration, int num_iteration) const override
Dump model to json format string.
Definition gbdt_model_text.cpp:15
~GBDT()
Destructor.
Definition gbdt.cpp:42
virtual int64_t GetNumPredictAt(int data_idx) const override
Get size of prediction at data_idx data.
Definition gbdt.h:180
std::unique_ptr< TreeLearner > tree_learner_
Tree learner, will use this class to learn trees.
Definition gbdt.h:419
virtual void UpdateScore(const Tree *tree, const int cur_tree_id)
updating score after tree was trained
Definition gbdt.cpp:451
bool NeedAccuratePrediction() const override
Can use early stopping for prediction or not.
Definition gbdt.h:153
void Init(const Config *gbdt_config, const Dataset *train_data, const ObjectiveFunction *objective_function, const std::vector< const Metric * > &training_metrics) override
Initialization logic.
Definition gbdt.cpp:45
int num_threads_
number of threads
Definition gbdt.h:470
std::string ModelToIfElse(int num_iteration) const override
Translate model to if-else statement.
Definition gbdt_model_text.cpp:60
virtual std::vector< double > EvalOneMetric(const Metric *metric, const double *score) const
eval results for one metric
Definition gbdt.cpp:472
std::vector< data_size_t > tmp_indices_
Store the indices of in-bag data.
Definition gbdt.h:451
std::vector< data_size_t > right_write_pos_buf_
Buffer for multi-threading bagging.
Definition gbdt.h:480
virtual const char * SubModelName() const override
Get Type name of this boosting object.
Definition gbdt.h:357
std::vector< data_size_t > left_write_pos_buf_
Buffer for multi-threading bagging.
Definition gbdt.h:478
std::vector< std::vector< std::string > > best_msg_
output message of best iteration
Definition gbdt.h:437
std::vector< double > FeatureImportance(int num_iteration, int importance_type) const override
Calculate feature importances.
Definition gbdt_model_text.cpp:513
int early_stopping_round_
Number of rounds for early stopping.
Definition gbdt.h:431
void ResetTrainingData(const Dataset *train_data, const ObjectiveFunction *objective_function, const std::vector< const Metric * > &training_metrics) override
Reset the training data.
Definition gbdt.cpp:622
void RollbackOneIter() override
Rollback one iteration.
Definition gbdt.cpp:414
std::vector< std::vector< double > > best_score_
Best score(s) for early stopping.
Definition gbdt.h:435
int NumberOfClasses() const override
Get number of classes.
Definition gbdt.h:327
const ObjectiveFunction * objective_function_
Objective function.
Definition gbdt.h:421
int num_class_
Number of class.
Definition gbdt.h:457
void ResetConfig(const Config *gbdt_config) override
Reset Boosting Config.
Definition gbdt.cpp:676
virtual void Boosting()
Calculate the objective function.
Definition gbdt.cpp:149
virtual bool TrainOneIter(const score_t *gradients, const score_t *hessians) override
Training logic.
Definition gbdt.cpp:333
std::vector< data_size_t > bag_data_indices_
Store the indices of in-bag data.
Definition gbdt.h:447
std::vector< std::vector< const Metric * > > valid_metrics_
Metric for validation data.
Definition gbdt.h:429
std::vector< std::vector< int > > best_iter_
Best iteration(s) for early stopping.
Definition gbdt.h:433
std::vector< std::unique_ptr< ScoreUpdater > > valid_score_updater_
Store and update validation data's scores.
Definition gbdt.h:427
std::unique_ptr< Config > config_
Config of gbdt.
Definition gbdt.h:417
data_size_t bag_data_cnt_
Number of in-bag data.
Definition gbdt.h:449
int num_tree_per_iteration_
Number of trees per iterations.
Definition gbdt.h:455
std::string OutputMetric(int iter)
Print metric result of current iteration.
Definition gbdt.cpp:476
bool SaveModelToIfElse(int num_iteration, const char *filename) const override
Translate model to if-else statement.
Definition gbdt_model_text.cpp:219
virtual bool EvalAndCheckEarlyStopping()
Print eval result and check early stopping.
Definition gbdt.cpp:432
int iter_
current iteration
Definition gbdt.h:413
void PredictLeafIndex(const double *features, double *output) const override
Prediction for one record with leaf index.
Definition gbdt_prediction.cpp:73
int max_feature_idx_
Max feature index of training data.
Definition gbdt.h:441
void MergeFrom(const Boosting *other) override
Merge model from other boosting object. Will insert to the front of current boosting object.
Definition gbdt.h:53
bool LoadModelFromString(std::string str)
Restore from a serialized buffer.
Definition gbdt.h:282
GBDT()
Constructor.
Definition gbdt.cpp:22
void PredictRaw(const double *features, double *output, const PredictionEarlyStopInstance *earlyStop) const override
Prediction for one record, not sigmoid transform.
Definition gbdt_prediction.cpp:9
void Predict(const double *features, double *output, const PredictionEarlyStopInstance *earlyStop) const override
Prediction for one record, sigmoid transformation will be used if needed.
Definition gbdt_prediction.cpp:49
std::vector< std::string > FeatureNames() const override
Get feature names of this model.
Definition gbdt.h:303
const Dataset * train_data_
Pointer to training data.
Definition gbdt.h:415
std::vector< data_size_t > right_cnts_buf_
Buffer for multi-threading bagging.
Definition gbdt.h:476
std::unique_ptr< ScoreUpdater > train_score_updater_
Store and update training data's score.
Definition gbdt.h:423
int NumPredictOneRow(int num_iteration, bool is_pred_leaf, bool is_pred_contrib) const override
Get number of prediction for one data.
Definition gbdt.h:204
double shrinkage_rate_
Shrinkage rate for one iteration.
Definition gbdt.h:463
std::vector< score_t > gradients_
First order derivative of training data.
Definition gbdt.h:443
data_size_t label_idx_
Index of label column.
Definition gbdt.h:459
void RefitTree(const std::vector< std::vector< int > > &tree_leaf_prediction) override
Update the tree output by new training data.
Definition gbdt.cpp:263
void ShuffleModels(int start_iter, int end_iter) override
Shuffle Existing Models.
Definition gbdt.h:72
virtual const double * GetTrainingScore(int64_t *out_len) override
Get current training score.
Definition gbdt.cpp:559
std::vector< data_size_t > offsets_buf_
Buffer for multi-threading bagging.
Definition gbdt.h:472
void GetPredictAt(int data_idx, double *out_result, int64_t *out_len) override
Get prediction result at data_idx data.
Definition gbdt.cpp:585
int num_init_iteration_
Number of loaded initial models.
Definition gbdt.h:465
std::vector< std::string > feature_names_
Feature names.
Definition gbdt.h:467
int NumberOfTotalModel() const override
Get number of weak sub-models.
Definition gbdt.h:315
void Train(int snapshot_freq, const std::string &model_output_path) override
Perform a full training procedure.
Definition gbdt.cpp:243
std::vector< data_size_t > left_cnts_buf_
Buffer for multi-threading bagging.
Definition gbdt.h:474
virtual bool SaveModelToFile(int start_iteration, int num_iterations, const char *filename) const override
Save model to file.
Definition gbdt_model_text.cpp:332
std::vector< double > GetEvalAt(int data_idx) const override
Get evaluation result at data_idx data.
Definition gbdt.cpp:536
std::string SaveModelToString(int num_iterations)
Save model to string.
Definition gbdt.h:276
int MaxFeatureIdx() const override
Get max feature index of this model.
Definition gbdt.h:297
The interface of metric. Metric is used to calculate metric result.
Definition metric.h:20
The interface of Objective Function.
Definition objective_function.h:13
virtual bool NeedAccuratePrediction() const
The prediction should be accurate or not. True will disable early stopping for prediction.
Definition objective_function.h:63
A wrapper for random generator.
Definition random.h:15
int NextShort(int lower_bound, int upper_bound)
Generate random integer, int16 range. [0, 65536].
Definition random.h:38
Tree model.
Definition tree.h:20
Definition json11.hpp:79
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10
float score_t
Type of score, and gradients.
Definition meta.h:26
int32_t data_size_t
Type of data size, it is better to use signed type.
Definition meta.h:14
NLOHMANN_BASIC_JSON_TPL_DECLARATION void swap(nlohmann::NLOHMANN_BASIC_JSON_TPL &j1, nlohmann::NLOHMANN_BASIC_JSON_TPL &j2) noexcept(//NOLINT(readability-inconsistent-declaration-parameter-name, cert-dcl58-cpp) is_nothrow_move_constructible< nlohmann::NLOHMANN_BASIC_JSON_TPL >::value &&//NOLINT(misc-redundant-expression) is_nothrow_move_assignable< nlohmann::NLOHMANN_BASIC_JSON_TPL >::value)
exchanges the values of two JSON objects
Definition json.hpp:24418
Definition config.h:27
Definition prediction_early_stop.h:11