Medial Code Documentation
Loading...
Searching...
No Matches
goss.hpp
#ifndef LIGHTGBM_BOOSTING_GOSS_H_
#define LIGHTGBM_BOOSTING_GOSS_H_

#include <LightGBM/boosting.h>
#include <LightGBM/utils/array_args.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <string>
#include <vector>

#include "gbdt.h"
#include "score_updater.hpp"

19namespace LightGBM {
20
#ifdef TIMETAG
// Accumulated wall-clock timers for the TIMETAG instrumentation build,
// reported from ~GOSS(). Declared 'static' so each translation unit that
// includes this header gets its own copy — a plain global defined in a
// header would violate the one-definition rule and fail to link if the
// header were included from more than one TU.
static std::chrono::duration<double, std::milli> subset_time;
static std::chrono::duration<double, std::milli> re_init_tree_time;
#endif
25
26class GOSS: public GBDT {
27public:
31 GOSS() : GBDT() {
32 }
33
34 ~GOSS() {
35 #ifdef TIMETAG
36 Log::Info("GOSS::subset costs %f", subset_time * 1e-3);
37 Log::Info("GOSS::re_init_tree costs %f", re_init_tree_time * 1e-3);
38 #endif
39 }
40
41 void Init(const Config* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
42 const std::vector<const Metric*>& training_metrics) override {
43 GBDT::Init(config, train_data, objective_function, training_metrics);
44 ResetGoss();
45 }
46
47 void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
48 const std::vector<const Metric*>& training_metrics) override {
49 GBDT::ResetTrainingData(train_data, objective_function, training_metrics);
50 ResetGoss();
51 }
52
53 void ResetConfig(const Config* config) override {
54 GBDT::ResetConfig(config);
55 ResetGoss();
56 }
57
58 void ResetGoss() {
59 CHECK(config_->top_rate + config_->other_rate <= 1.0f);
60 CHECK(config_->top_rate > 0.0f && config_->other_rate > 0.0f);
61 if (config_->bagging_freq > 0 && config_->bagging_fraction != 1.0f) {
62 Log::Fatal("Cannot use bagging in GOSS");
63 }
64 Log::Info("Using GOSS");
65
67 tmp_indices_.resize(num_data_);
68 tmp_indice_right_.resize(num_data_);
74
75 is_use_subset_ = false;
76 if (config_->top_rate + config_->other_rate <= 0.5) {
77 auto bag_data_cnt = static_cast<data_size_t>((config_->top_rate + config_->other_rate) * num_data_);
78 bag_data_cnt = std::max(1, bag_data_cnt);
79 tmp_subset_.reset(new Dataset(bag_data_cnt));
80 tmp_subset_->CopyFeatureMapperFrom(train_data_);
81 is_use_subset_ = true;
82 }
83 // flag to not bagging first
85 }
86
87 data_size_t BaggingHelper(Random& cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer, data_size_t* buffer_right) {
88 if (cnt <= 0) {
89 return 0;
90 }
91 std::vector<score_t> tmp_gradients(cnt, 0.0f);
92 for (data_size_t i = 0; i < cnt; ++i) {
93 for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
94 size_t idx = static_cast<size_t>(cur_tree_id) * num_data_ + start + i;
95 tmp_gradients[i] += std::fabs(gradients_[idx] * hessians_[idx]);
96 }
97 }
98 data_size_t top_k = static_cast<data_size_t>(cnt * config_->top_rate);
99 data_size_t other_k = static_cast<data_size_t>(cnt * config_->other_rate);
100 top_k = std::max(1, top_k);
101 ArrayArgs<score_t>::ArgMaxAtK(&tmp_gradients, 0, static_cast<int>(tmp_gradients.size()), top_k - 1);
102 score_t threshold = tmp_gradients[top_k - 1];
103
104 score_t multiply = static_cast<score_t>(cnt - top_k) / other_k;
105 data_size_t cur_left_cnt = 0;
106 data_size_t cur_right_cnt = 0;
107 data_size_t big_weight_cnt = 0;
108 for (data_size_t i = 0; i < cnt; ++i) {
109 score_t grad = 0.0f;
110 for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
111 size_t idx = static_cast<size_t>(cur_tree_id) * num_data_ + start + i;
112 grad += std::fabs(gradients_[idx] * hessians_[idx]);
113 }
114 if (grad >= threshold) {
115 buffer[cur_left_cnt++] = start + i;
116 ++big_weight_cnt;
117 } else {
118 data_size_t sampled = cur_left_cnt - big_weight_cnt;
119 data_size_t rest_need = other_k - sampled;
120 data_size_t rest_all = (cnt - i) - (top_k - big_weight_cnt);
121 double prob = (rest_need) / static_cast<double>(rest_all);
122 if (cur_rand.NextFloat() < prob) {
123 buffer[cur_left_cnt++] = start + i;
124 for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
125 size_t idx = static_cast<size_t>(cur_tree_id) * num_data_ + start + i;
126 gradients_[idx] *= multiply;
127 hessians_[idx] *= multiply;
128 }
129 } else {
130 buffer_right[cur_right_cnt++] = start + i;
131 }
132 }
133 }
134 return cur_left_cnt;
135 }
136
137 void Bagging(int iter) override {
139 // not subsample for first iterations
140 if (iter < static_cast<int>(1.0f / config_->learning_rate)) { return; }
141
142 const data_size_t min_inner_size = 100;
143 data_size_t inner_size = (num_data_ + num_threads_ - 1) / num_threads_;
144 if (inner_size < min_inner_size) { inner_size = min_inner_size; }
145 OMP_INIT_EX();
146 #pragma omp parallel for schedule(static, 1)
147 for (int i = 0; i < num_threads_; ++i) {
148 OMP_LOOP_EX_BEGIN();
149 left_cnts_buf_[i] = 0;
150 right_cnts_buf_[i] = 0;
151 data_size_t cur_start = i * inner_size;
152 if (cur_start > num_data_) { continue; }
153 data_size_t cur_cnt = inner_size;
154 if (cur_start + cur_cnt > num_data_) { cur_cnt = num_data_ - cur_start; }
155 Random cur_rand(config_->bagging_seed + iter * num_threads_ + i);
156 data_size_t cur_left_count = BaggingHelper(cur_rand, cur_start, cur_cnt,
157 tmp_indices_.data() + cur_start, tmp_indice_right_.data() + cur_start);
158 offsets_buf_[i] = cur_start;
159 left_cnts_buf_[i] = cur_left_count;
160 right_cnts_buf_[i] = cur_cnt - cur_left_count;
161 OMP_LOOP_EX_END();
162 }
163 OMP_THROW_EX();
164 data_size_t left_cnt = 0;
165 left_write_pos_buf_[0] = 0;
167 for (int i = 1; i < num_threads_; ++i) {
170 }
172
173 #pragma omp parallel for schedule(static, 1)
174 for (int i = 0; i < num_threads_; ++i) {
175 OMP_LOOP_EX_BEGIN();
176 if (left_cnts_buf_[i] > 0) {
177 std::memcpy(bag_data_indices_.data() + left_write_pos_buf_[i],
178 tmp_indices_.data() + offsets_buf_[i], left_cnts_buf_[i] * sizeof(data_size_t));
179 }
180 if (right_cnts_buf_[i] > 0) {
181 std::memcpy(bag_data_indices_.data() + left_cnt + right_write_pos_buf_[i],
182 tmp_indice_right_.data() + offsets_buf_[i], right_cnts_buf_[i] * sizeof(data_size_t));
183 }
184 OMP_LOOP_EX_END();
185 }
186 OMP_THROW_EX();
187 bag_data_cnt_ = left_cnt;
188 // set bagging data to tree learner
189 if (!is_use_subset_) {
190 tree_learner_->SetBaggingData(bag_data_indices_.data(), bag_data_cnt_);
191 } else {
192 // get subset
193 #ifdef TIMETAG
194 auto start_time = std::chrono::steady_clock::now();
195 #endif
196 tmp_subset_->ReSize(bag_data_cnt_);
197 tmp_subset_->CopySubset(train_data_, bag_data_indices_.data(), bag_data_cnt_, false);
198 #ifdef TIMETAG
199 subset_time += std::chrono::steady_clock::now() - start_time;
200 #endif
201 #ifdef TIMETAG
202 start_time = std::chrono::steady_clock::now();
203 #endif
204 tree_learner_->ResetTrainingData(tmp_subset_.get());
205 #ifdef TIMETAG
206 re_init_tree_time += std::chrono::steady_clock::now() - start_time;
207 #endif
208 }
209 }
210
211private:
212 std::vector<data_size_t> tmp_indice_right_;
213};
214
215} // namespace LightGBM
216#endif // LIGHTGBM_BOOSTING_GOSS_H_
The main class of dataset, which is used for training or validation.
Definition dataset.h:278
GBDT algorithm implementation. including Training, prediction, bagging.
Definition gbdt.h:26
std::vector< score_t > hessians_
Second order derivative of training data.
Definition gbdt.h:445
data_size_t num_data_
Number of training data.
Definition gbdt.h:453
std::unique_ptr< TreeLearner > tree_learner_
Tree learner, will use this class to learn trees.
Definition gbdt.h:419
void Init(const Config *gbdt_config, const Dataset *train_data, const ObjectiveFunction *objective_function, const std::vector< const Metric * > &training_metrics) override
Initialization logic.
Definition gbdt.cpp:45
int num_threads_
number of threads
Definition gbdt.h:470
std::vector< data_size_t > tmp_indices_
Store the indices of in-bag data.
Definition gbdt.h:451
std::vector< data_size_t > right_write_pos_buf_
Buffer for multi-threading bagging.
Definition gbdt.h:480
std::vector< data_size_t > left_write_pos_buf_
Buffer for multi-threading bagging.
Definition gbdt.h:478
void ResetTrainingData(const Dataset *train_data, const ObjectiveFunction *objective_function, const std::vector< const Metric * > &training_metrics) override
Reset the training data.
Definition gbdt.cpp:622
void ResetConfig(const Config *gbdt_config) override
Reset Boosting Config.
Definition gbdt.cpp:676
std::vector< data_size_t > bag_data_indices_
Store the indices of in-bag data.
Definition gbdt.h:447
std::unique_ptr< Config > config_
Config of gbdt.
Definition gbdt.h:417
data_size_t bag_data_cnt_
Number of in-bag data.
Definition gbdt.h:449
int num_tree_per_iteration_
Number of trees per iteration.
Definition gbdt.h:455
const Dataset * train_data_
Pointer to training data.
Definition gbdt.h:415
std::vector< data_size_t > right_cnts_buf_
Buffer for multi-threading bagging.
Definition gbdt.h:476
std::vector< score_t > gradients_
First order derivative of training data.
Definition gbdt.h:443
std::vector< data_size_t > offsets_buf_
Buffer for multi-threading bagging.
Definition gbdt.h:472
std::vector< data_size_t > left_cnts_buf_
Buffer for multi-threading bagging.
Definition gbdt.h:474
Definition goss.hpp:26
void Bagging(int iter) override
Implement bagging logic.
Definition goss.hpp:137
void Init(const Config *config, const Dataset *train_data, const ObjectiveFunction *objective_function, const std::vector< const Metric * > &training_metrics) override
Initialization logic.
Definition goss.hpp:41
void ResetTrainingData(const Dataset *train_data, const ObjectiveFunction *objective_function, const std::vector< const Metric * > &training_metrics) override
Reset the training data.
Definition goss.hpp:47
void ResetConfig(const Config *config) override
Reset Boosting Config.
Definition goss.hpp:53
GOSS()
Constructor.
Definition goss.hpp:31
The interface of Objective Function.
Definition objective_function.h:13
A wrapper for random generator.
Definition random.h:15
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10
float score_t
Type of score, and gradients.
Definition meta.h:26
int32_t data_size_t
Type of data size, it is better to use signed type.
Definition meta.h:14
Definition config.h:27