Medial Code Documentation
dart.hpp
#ifndef LIGHTGBM_BOOSTING_DART_H_
#define LIGHTGBM_BOOSTING_DART_H_

#include <LightGBM/boosting.h>
#include "score_updater.hpp"
#include "gbdt.h"

#include <algorithm>
#include <cstdio>
#include <vector>
#include <string>
#include <fstream>

namespace LightGBM {
/*!
* \brief DART algorithm implementation, including training, prediction, bagging.
*/
class DART: public GBDT {
 public:
  /*! \brief Constructor */
  DART() : GBDT() { }

  /*! \brief Destructor */
  ~DART() { }

  /*! \brief Initialization logic */
  void Init(const Config* config, const Dataset* train_data,
            const ObjectiveFunction* objective_function,
            const std::vector<const Metric*>& training_metrics) override {
    GBDT::Init(config, train_data, objective_function, training_metrics);
    random_for_drop_ = Random(config_->drop_seed);
    sum_weight_ = 0.0f;
  }

  /*! \brief Reset boosting config */
  void ResetConfig(const Config* config) override {
    GBDT::ResetConfig(config);
    random_for_drop_ = Random(config_->drop_seed);
    sum_weight_ = 0.0f;
  }

  /*! \brief One training iteration */
  bool TrainOneIter(const score_t* gradient, const score_t* hessian) override {
    // DroppingTrees() runs inside GBDT::TrainOneIter(), triggered by the
    // GetTrainingScore() override below, so reset the per-iteration flag first
    is_update_score_cur_iter_ = false;
    bool ret = GBDT::TrainOneIter(gradient, hessian);
    if (ret) {
      return ret;
    }
    // normalize
    Normalize();
    if (!config_->uniform_drop) {
      tree_weight_.push_back(shrinkage_rate_);
      sum_weight_ += shrinkage_rate_;
    }
    return false;
  }

  /*! \brief Get current training score */
  const double* GetTrainingScore(int64_t* out_len) override {
    if (!is_update_score_cur_iter_) {
      // only drop one time in one iteration
      DroppingTrees();
      is_update_score_cur_iter_ = true;
    }
    *out_len = static_cast<int64_t>(train_score_updater_->num_data()) * num_class_;
    return train_score_updater_->score();
  }

  /*! \brief Print eval result and check early stopping */
  bool EvalAndCheckEarlyStopping() override {
    GBDT::OutputMetric(iter_);
    return false;
  }

 private:
  /*! \brief Select the dropping trees and remove them from the training score */
  void DroppingTrees() {
    drop_index_.clear();
    bool is_skip = random_for_drop_.NextFloat() < config_->skip_drop;
    // select dropping tree indices based on drop_rate and tree weights
    if (!is_skip) {
      double drop_rate = config_->drop_rate;
      if (!config_->uniform_drop) {
        double inv_average_weight = static_cast<double>(tree_weight_.size()) / sum_weight_;
        if (config_->max_drop > 0) {
          drop_rate = std::min(drop_rate, config_->max_drop * inv_average_weight / sum_weight_);
        }
        for (int i = 0; i < iter_; ++i) {
          if (random_for_drop_.NextFloat() < drop_rate * tree_weight_[i] * inv_average_weight) {
            drop_index_.push_back(num_init_iteration_ + i);
            if (drop_index_.size() >= static_cast<size_t>(config_->max_drop)) {
              break;
            }
          }
        }
      } else {
        if (config_->max_drop > 0) {
          drop_rate = std::min(drop_rate, config_->max_drop / static_cast<double>(iter_));
        }
        for (int i = 0; i < iter_; ++i) {
          if (random_for_drop_.NextFloat() < drop_rate) {
            drop_index_.push_back(num_init_iteration_ + i);
            if (drop_index_.size() >= static_cast<size_t>(config_->max_drop)) {
              break;
            }
          }
        }
      }
    }
    // drop trees
    for (auto i : drop_index_) {
      for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
        auto curr_tree = i * num_tree_per_iteration_ + cur_tree_id;
        models_[curr_tree]->Shrinkage(-1.0);
        train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
      }
    }
    if (!config_->xgboost_dart_mode) {
      shrinkage_rate_ = config_->learning_rate / (1.0f + static_cast<double>(drop_index_.size()));
    } else {
      if (drop_index_.empty()) {
        shrinkage_rate_ = config_->learning_rate;
      } else {
        shrinkage_rate_ = config_->learning_rate / (config_->learning_rate + static_cast<double>(drop_index_.size()));
      }
    }
  }

  /*!
  * \brief Normalize the dropped trees and the newly added tree.
  * With k dropped trees and learning rate lr, the new tree was added with
  * shrinkage_rate_ = lr / (1 + k) (or lr / (lr + k) in xgboost_dart_mode):
  *   step 1 (in DroppingTrees): tree scaled by -1 -> removed from training score
  *   step 2: tree scaled to -1/(k+1); adding it to the validation scores
  *           leaves the tree at weight k/(k+1) there
  *   step 3: tree scaled to k/(k+1); adding it back to the training score
  *   end with tree weight = (k / (k + 1)) * old weight
  */
  void Normalize() {
    double k = static_cast<double>(drop_index_.size());
    if (!config_->xgboost_dart_mode) {
      for (auto i : drop_index_) {
        for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
          auto curr_tree = i * num_tree_per_iteration_ + cur_tree_id;
          // update validation score
          models_[curr_tree]->Shrinkage(1.0f / (k + 1.0f));
          for (auto& score_updater : valid_score_updater_) {
            score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
          }
          // update training score
          models_[curr_tree]->Shrinkage(-k);
          train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
        }
        if (!config_->uniform_drop) {
          sum_weight_ -= tree_weight_[i] * (1.0f / (k + 1.0f));
          tree_weight_[i] *= (k / (k + 1.0f));
        }
      }
    } else {
      for (auto i : drop_index_) {
        for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
          auto curr_tree = i * num_tree_per_iteration_ + cur_tree_id;
          // update validation score
          models_[curr_tree]->Shrinkage(shrinkage_rate_);
          for (auto& score_updater : valid_score_updater_) {
            score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
          }
          // update training score
          models_[curr_tree]->Shrinkage(-k / config_->learning_rate);
          train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
        }
        if (!config_->uniform_drop) {
          sum_weight_ -= tree_weight_[i] * (1.0f / (k + config_->learning_rate));
          tree_weight_[i] *= (k / (k + config_->learning_rate));
        }
      }
    }
  }

  /*! \brief Weights of all trees, used to select the dropping trees */
  std::vector<double> tree_weight_;
  /*! \brief Sum of the weights of all trees */
  double sum_weight_;
  /*! \brief Indices of the dropping trees */
  std::vector<int> drop_index_;
  /*! \brief Random generator, used to select the dropping trees */
  Random random_for_drop_;
  /*! \brief Whether the score has already been updated in the current iteration */
  bool is_update_score_cur_iter_;
};

}  // namespace LightGBM
#endif  // LIGHTGBM_BOOSTING_DART_H_
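To make the weight bookkeeping above easier to check, here is the algebra of the default (non-xgboost_dart_mode) path, restated directly from the code rather than from any external source. With k dropped trees, learning rate lr, and a dropped tree whose original contribution to the score is v, the new tree enters with shrinkage_rate_ = lr / (1 + k), and each dropped tree passes through three Shrinkage() calls:

$$ v \;\xrightarrow{\times(-1)}\; -v \;\xrightarrow{\times\frac{1}{k+1}}\; -\frac{v}{k+1} \;\xrightarrow{\times(-k)}\; \frac{k}{k+1}\,v $$

DroppingTrees() adds the first form to the training score (removing the tree there), while Normalize() adds the second form to the validation scores and the third form back to the training score, so both score sets end at the same weight:

$$ \text{train: } 0 + \frac{k}{k+1}v = \frac{k}{k+1}v, \qquad \text{valid: } v - \frac{v}{k+1} = \frac{k}{k+1}v. $$

In xgboost_dart_mode the same three steps run with shrinkage_rate_ = lr / (lr + k), leaving each dropped tree at weight k / (k + lr).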
Referenced symbols defined in other files:

- GBDT: GBDT algorithm implementation, including training, prediction, bagging (gbdt.h:26)
- GBDT::Init(): initialization logic (gbdt.cpp:45)
- GBDT::ResetConfig(): reset boosting config (gbdt.cpp:676)
- GBDT::TrainOneIter(): training logic (gbdt.cpp:333)
- GBDT::OutputMetric(): print the metric result of the current iteration (gbdt.cpp:476)
- GBDT::models_: trained models (trees) (gbdt.h:439)
- GBDT::config_: config of GBDT (gbdt.h:417)
- GBDT::iter_: current iteration (gbdt.h:413)
- GBDT::num_class_: number of classes (gbdt.h:457)
- GBDT::num_tree_per_iteration_: number of trees per iteration (gbdt.h:455)
- GBDT::num_init_iteration_: number of loaded initial models (gbdt.h:465)
- GBDT::shrinkage_rate_: shrinkage rate for one iteration (gbdt.h:463)
- GBDT::train_score_updater_: stores and updates the training data's score (gbdt.h:423)
- GBDT::valid_score_updater_: stores and updates the validation data's scores (gbdt.h:427)
- Dataset: the main dataset class, used for training or validation (dataset.h:278)
- ObjectiveFunction: the interface of objective functions (objective_function.h:13)
- Random: a wrapper around a random generator (random.h:15)
- Random::NextFloat(): generates random float data (random.h:56)
- score_t: type of scores and gradients, float (meta.h:26)
- Config (config.h:27)
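For orientation, here is a minimal sketch of training a DART booster through LightGBM's public C API. It is not part of this file; the input path "train.svm" and the concrete parameter values are illustrative assumptions, while the parameter names (boosting, learning_rate, drop_rate, max_drop, skip_drop, uniform_drop, xgboost_dart_mode, drop_seed) are the Config fields read by the DART class above. Each LGBM_BoosterUpdateOneIter() call ends up in DART::TrainOneIter().

#include <LightGBM/c_api.h>
#include <cstdio>

int main() {
  DatasetHandle train = nullptr;
  BoosterHandle booster = nullptr;
  // space-separated key=value pairs; boosting=dart selects the class above
  const char* params =
      "objective=binary boosting=dart learning_rate=0.1 "
      "drop_rate=0.1 max_drop=50 skip_drop=0.5 "
      "uniform_drop=false xgboost_dart_mode=false drop_seed=4";
  if (LGBM_DatasetCreateFromFile("train.svm", params, nullptr, &train) != 0) {
    std::fprintf(stderr, "failed to load dataset\n");
    return 1;
  }
  if (LGBM_BoosterCreate(train, params, &booster) != 0) {
    std::fprintf(stderr, "failed to create booster\n");
    LGBM_DatasetFree(train);
    return 1;
  }
  // each update runs one DART iteration: select and drop trees, train a
  // new tree on the reduced scores, then normalize the dropped trees
  for (int iter = 0; iter < 100; ++iter) {
    int is_finished = 0;
    LGBM_BoosterUpdateOneIter(booster, &is_finished);
    if (is_finished) break;
  }
  LGBM_BoosterFree(booster);
  LGBM_DatasetFree(train);
  return 0;
}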