1#ifndef LIGHTGBM_BOOSTING_DART_H_
2#define LIGHTGBM_BOOSTING_DART_H_
4#include <LightGBM/boosting.h>
5#include "score_updater.hpp"
37 const std::vector<const Metric*>& training_metrics)
override {
38 GBDT::Init(config, train_data, objective_function, training_metrics);
43 void ResetConfig(
const Config* config)
override {
53 is_update_score_cur_iter_ =
false;
73 if (!is_update_score_cur_iter_) {
76 is_update_score_cur_iter_ =
true;
91 void DroppingTrees() {
96 double drop_rate =
config_->drop_rate;
98 double inv_average_weight =
static_cast<double>(tree_weight_.size()) / sum_weight_;
100 drop_rate = std::min(drop_rate,
config_->max_drop * inv_average_weight / sum_weight_);
102 for (
int i = 0; i <
iter_; ++i) {
103 if (random_for_drop_.
NextFloat() < drop_rate * tree_weight_[i] * inv_average_weight) {
105 if (drop_index_.size() >=
static_cast<size_t>(
config_->max_drop)) {
112 drop_rate = std::min(drop_rate,
config_->max_drop /
static_cast<double>(
iter_));
114 for (
int i = 0; i <
iter_; ++i) {
115 if (random_for_drop_.
NextFloat() < drop_rate) {
117 if (drop_index_.size() >=
static_cast<size_t>(
config_->max_drop)) {
125 for (
auto i : drop_index_) {
128 models_[curr_tree]->Shrinkage(-1.0);
132 if (!
config_->xgboost_dart_mode) {
135 if (drop_index_.empty()) {
153 double k =
static_cast<double>(drop_index_.size());
154 if (!
config_->xgboost_dart_mode) {
155 for (
auto i : drop_index_) {
159 models_[curr_tree]->Shrinkage(1.0f / (k + 1.0f));
161 score_updater->AddScore(
models_[curr_tree].get(), cur_tree_id);
164 models_[curr_tree]->Shrinkage(-k);
168 sum_weight_ -= tree_weight_[i] * (1.0f / (k + 1.0f));
169 tree_weight_[i] *= (k / (k + 1.0f));
173 for (
auto i : drop_index_) {
179 score_updater->AddScore(
models_[curr_tree].get(), cur_tree_id);
186 sum_weight_ -= tree_weight_[i] * (1.0f / (k +
config_->learning_rate));;
187 tree_weight_[i] *= (k / (k +
config_->learning_rate));
193 std::vector<double> tree_weight_;
197 std::vector<int> drop_index_;
199 Random random_for_drop_;
201 bool is_update_score_cur_iter_;
DART algorithm implementation, including training, prediction, and bagging.
Definition dart.hpp:17
void Init(const Config *config, const Dataset *train_data, const ObjectiveFunction *objective_function, const std::vector< const Metric * > &training_metrics) override
Initialization logic.
Definition dart.hpp:35
DART()
Constructor.
Definition dart.hpp:22
bool EvalAndCheckEarlyStopping() override
Print eval result and check early stopping.
Definition dart.hpp:82
~DART()
Destructor.
Definition dart.hpp:26
bool TrainOneIter(const score_t *gradient, const score_t *hessian) override
One training iteration.
Definition dart.hpp:52
const double * GetTrainingScore(int64_t *out_len) override
Get current training score.
Definition dart.hpp:72
The main dataset class, which is used for training or validation.
Definition dataset.h:278
GBDT algorithm implementation, including training, prediction, and bagging.
Definition gbdt.h:26
std::vector< std::unique_ptr< Tree > > models_
Trained models (trees).
Definition gbdt.h:439
void Init(const Config *gbdt_config, const Dataset *train_data, const ObjectiveFunction *objective_function, const std::vector< const Metric * > &training_metrics) override
Initialization logic.
Definition gbdt.cpp:45
int num_class_
Number of classes.
Definition gbdt.h:457
void ResetConfig(const Config *gbdt_config) override
Reset Boosting Config.
Definition gbdt.cpp:676
virtual bool TrainOneIter(const score_t *gradients, const score_t *hessians) override
Training logic.
Definition gbdt.cpp:333
std::vector< std::unique_ptr< ScoreUpdater > > valid_score_updater_
Store and update validation data's scores.
Definition gbdt.h:427
std::unique_ptr< Config > config_
Config of gbdt.
Definition gbdt.h:417
int num_tree_per_iteration_
Number of trees per iterations.
Definition gbdt.h:455
std::string OutputMetric(int iter)
Print metric result of current iteration.
Definition gbdt.cpp:476
int iter_
Current iteration.
Definition gbdt.h:413
std::unique_ptr< ScoreUpdater > train_score_updater_
Store and update training data's score.
Definition gbdt.h:423
double shrinkage_rate_
Shrinkage rate for one iteration.
Definition gbdt.h:463
int num_init_iteration_
Number of loaded initial models.
Definition gbdt.h:465
The interface of Objective Function.
Definition objective_function.h:13
A wrapper for random generator.
Definition random.h:15
float NextFloat()
Generate random float data.
Definition random.h:56
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10
float score_t
Type of score, and gradients.
Definition meta.h:26