Medial Code Documentation
Loading...
Searching...
No Matches
predictor.hpp
1#ifndef LIGHTGBM_PREDICTOR_HPP_
2#define LIGHTGBM_PREDICTOR_HPP_
3
4#include <LightGBM/meta.h>
5#include <LightGBM/boosting.h>
6#include <LightGBM/utils/text_reader.h>
7#include <LightGBM/dataset.h>
8
9#include <LightGBM/utils/openmp_wrapper.h>
10
11#include <map>
12#include <cstring>
13#include <cstdio>
14#include <vector>
15#include <utility>
16#include <functional>
17#include <string>
18#include <memory>
19
20namespace LightGBM {
21
25class Predictor {
26public:
35 Predictor(Boosting* boosting, int num_iteration,
36 bool is_raw_score, bool predict_leaf_index, bool predict_contrib,
37 bool early_stop, int early_stop_freq, double early_stop_margin) {
39 if (early_stop && !boosting->NeedAccuratePrediction()) {
40 PredictionEarlyStopConfig pred_early_stop_config;
41 CHECK(early_stop_freq > 0);
42 CHECK(early_stop_margin >= 0);
43 pred_early_stop_config.margin_threshold = early_stop_margin;
44 pred_early_stop_config.round_period = early_stop_freq;
45 if (boosting->NumberOfClasses() == 1) {
46 early_stop_ = CreatePredictionEarlyStopInstance("binary", pred_early_stop_config);
47 } else {
48 early_stop_ = CreatePredictionEarlyStopInstance("multiclass", pred_early_stop_config);
49 }
50 }
51
52 #pragma omp parallel
53 #pragma omp master
54 {
55 num_threads_ = omp_get_num_threads();
56 }
57 boosting->InitPredict(num_iteration, predict_contrib);
58 boosting_ = boosting;
59 num_pred_one_row_ = boosting_->NumPredictOneRow(num_iteration, predict_leaf_index, predict_contrib);
60 num_feature_ = boosting_->MaxFeatureIdx() + 1;
61 predict_buf_ = std::vector<std::vector<double>>(num_threads_, std::vector<double>(num_feature_, 0.0f));
62 const int kFeatureThreshold = 100000;
63 const size_t KSparseThreshold = static_cast<size_t>(0.01 * num_feature_);
64 if (predict_leaf_index) {
65 predict_fun_ = [=](const std::vector<std::pair<int, double>>& features, double* output) {
66 int tid = omp_get_thread_num();
67 if (num_feature_ > kFeatureThreshold && features.size() < KSparseThreshold) {
68 auto buf = CopyToPredictMap(features);
69 boosting_->PredictLeafIndexByMap(buf, output);
70 } else {
71 CopyToPredictBuffer(predict_buf_[tid].data(), features);
72 // get result for leaf index
73 boosting_->PredictLeafIndex(predict_buf_[tid].data(), output);
74 ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
75 }
76 };
77 } else if (predict_contrib) {
78 predict_fun_ = [=](const std::vector<std::pair<int, double>>& features, double* output) {
79 int tid = omp_get_thread_num();
80 CopyToPredictBuffer(predict_buf_[tid].data(), features);
81 // get result for leaf index
82 boosting_->PredictContrib(predict_buf_[tid].data(), output, &early_stop_);
83 ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
84 };
85 } else {
86 if (is_raw_score) {
87 predict_fun_ = [=](const std::vector<std::pair<int, double>>& features, double* output) {
88 int tid = omp_get_thread_num();
89 if (num_feature_ > kFeatureThreshold && features.size() < KSparseThreshold) {
90 auto buf = CopyToPredictMap(features);
91 boosting_->PredictRawByMap(buf, output, &early_stop_);
92 } else {
93 CopyToPredictBuffer(predict_buf_[tid].data(), features);
94 boosting_->PredictRaw(predict_buf_[tid].data(), output, &early_stop_);
95 ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
96 }
97 };
98 } else {
99 predict_fun_ = [=](const std::vector<std::pair<int, double>>& features, double* output) {
100 int tid = omp_get_thread_num();
101 if (num_feature_ > kFeatureThreshold && features.size() < KSparseThreshold) {
102 auto buf = CopyToPredictMap(features);
103 boosting_->PredictByMap(buf, output, &early_stop_);
104 } else {
105 CopyToPredictBuffer(predict_buf_[tid].data(), features);
106 boosting_->Predict(predict_buf_[tid].data(), output, &early_stop_);
107 ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
108 }
109 };
110 }
111 }
112 }
113
118 }
119
120 inline const PredictFunction& GetPredictFunction() const {
121 return predict_fun_;
122 }
123
129 void Predict(const char* data_filename, const char* result_filename, bool header) {
130 auto writer = VirtualFileWriter::Make(result_filename);
131 if (!writer->Init()) {
132 Log::Fatal("Prediction results file %s cannot be found", result_filename);
133 }
134 auto parser = std::unique_ptr<Parser>(Parser::CreateParser(data_filename, header, boosting_->MaxFeatureIdx() + 1, boosting_->LabelIdx()));
135
136 if (parser == nullptr) {
137 Log::Fatal("Could not recognize the data format of data file %s", data_filename);
138 }
139
140 TextReader<data_size_t> predict_data_reader(data_filename, header);
141 std::unordered_map<int, int> feature_names_map_;
142 bool need_adjust = false;
143 if (header) {
144 std::string first_line = predict_data_reader.first_line();
145 std::vector<std::string> header_words = Common::Split(first_line.c_str(), "\t,");
146 header_words.erase(header_words.begin() + boosting_->LabelIdx());
147 for (int i = 0; i < static_cast<int>(header_words.size()); ++i) {
148 for (int j = 0; j < static_cast<int>(boosting_->FeatureNames().size()); ++j) {
149 if (header_words[i] == boosting_->FeatureNames()[j]) {
150 feature_names_map_[i] = j;
151 break;
152 }
153 }
154 }
155 for (auto s : feature_names_map_) {
156 if (s.first != s.second) {
157 need_adjust = true;
158 break;
159 }
160 }
161 }
162 // function for parse data
163 std::function<void(const char*, std::vector<std::pair<int, double>>*)> parser_fun;
164 double tmp_label;
165 parser_fun = [&]
166 (const char* buffer, std::vector<std::pair<int, double>>* feature) {
167 parser->ParseOneLine(buffer, feature, &tmp_label);
168 if (need_adjust) {
169 int i = 0, j = static_cast<int>(feature->size());
170 while (i < j) {
171 if (feature_names_map_.find((*feature)[i].first) != feature_names_map_.end()) {
172 (*feature)[i].first = feature_names_map_[(*feature)[i].first];
173 ++i;
174 } else {
175 // move the non-used features to the end of the feature vector
176 std::swap((*feature)[i], (*feature)[--j]);
177 }
178 }
179 feature->resize(i);
180 }
181 };
182
183 std::function<void(data_size_t, const std::vector<std::string>&)> process_fun = [&]
184 (data_size_t, const std::vector<std::string>& lines) {
185 std::vector<std::pair<int, double>> oneline_features;
186 std::vector<std::string> result_to_write(lines.size());
187 OMP_INIT_EX();
188 #pragma omp parallel for schedule(static) firstprivate(oneline_features)
189 for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
190 OMP_LOOP_EX_BEGIN();
191 oneline_features.clear();
192 // parser
193 parser_fun(lines[i].c_str(), &oneline_features);
194 // predict
195 std::vector<double> result(num_pred_one_row_);
196 predict_fun_(oneline_features, result.data());
197 auto str_result = Common::Join<double>(result, "\t");
198 result_to_write[i] = str_result;
199 OMP_LOOP_EX_END();
200 }
201 OMP_THROW_EX();
202 for (data_size_t i = 0; i < static_cast<data_size_t>(result_to_write.size()); ++i) {
203 writer->Write(result_to_write[i].c_str(), result_to_write[i].size());
204 writer->Write("\n", 1);
205 }
206 };
207 predict_data_reader.ReadAllAndProcessParallel(process_fun);
208 }
209
210private:
211 void CopyToPredictBuffer(double* pred_buf, const std::vector<std::pair<int, double>>& features) {
212 int loop_size = static_cast<int>(features.size());
213 for (int i = 0; i < loop_size; ++i) {
214 if (features[i].first < num_feature_) {
215 pred_buf[features[i].first] = features[i].second;
216 }
217 }
218 }
219
220 void ClearPredictBuffer(double* pred_buf, size_t buf_size, const std::vector<std::pair<int, double>>& features) {
221 if (features.size() > static_cast<size_t>(buf_size / 2)) {
222 std::memset(pred_buf, 0, sizeof(double)*(buf_size));
223 } else {
224 int loop_size = static_cast<int>(features.size());
225 for (int i = 0; i < loop_size; ++i) {
226 if (features[i].first < num_feature_) {
227 pred_buf[features[i].first] = 0.0f;
228 }
229 }
230 }
231 }
232
233 std::unordered_map<int, double> CopyToPredictMap(const std::vector<std::pair<int, double>>& features) {
234 std::unordered_map<int, double> buf;
235 int loop_size = static_cast<int>(features.size());
236 for (int i = 0; i < loop_size; ++i) {
237 if (features[i].first < num_feature_) {
238 buf[features[i].first] = features[i].second;
239 }
240 }
241 return buf;
242 }
243
245 const Boosting* boosting_;
247 PredictFunction predict_fun_;
248 PredictionEarlyStopInstance early_stop_;
249 int num_feature_;
250 int num_pred_one_row_;
251 int num_threads_;
252 std::vector<std::vector<double>> predict_buf_;
253};
254
255} // namespace LightGBM
256
257#endif // LIGHTGBM_PREDICTOR_HPP_
The interface for Boosting.
Definition boosting.h:22
virtual void Predict(const double *features, double *output, const PredictionEarlyStopInstance *early_stop) const =0
Prediction for one record, sigmoid transformation will be used if needed.
virtual std::vector< std::string > FeatureNames() const =0
Get feature names of this model.
virtual int NumberOfClasses() const =0
Get number of classes.
virtual void InitPredict(int num_iteration, bool is_pred_contrib)=0
Initial work for the prediction.
virtual int MaxFeatureIdx() const =0
Get max feature index of this model.
virtual void PredictLeafIndex(const double *features, double *output) const =0
Prediction for one record with leaf index.
virtual void PredictContrib(const double *features, double *output, const PredictionEarlyStopInstance *early_stop) const =0
Feature contributions for the model's prediction of one record.
virtual bool NeedAccuratePrediction() const =0
The prediction should be accurate or not. True will disable early stopping for prediction.
virtual void PredictRaw(const double *features, double *output, const PredictionEarlyStopInstance *early_stop) const =0
Prediction for one record, without the sigmoid transformation.
virtual int LabelIdx() const =0
Get index of label column.
static Parser * CreateParser(const char *filename, bool header, int num_features, int label_idx)
Creates a parser object; the format is chosen automatically based on the file.
Definition parser.cpp:87
Used to predict data with input model.
Definition predictor.hpp:25
void Predict(const char *data_filename, const char *result_filename, bool header)
predicting on data, then saving result to disk
Definition predictor.hpp:129
~Predictor()
Destructor.
Definition predictor.hpp:117
Predictor(Boosting *boosting, int num_iteration, bool is_raw_score, bool predict_leaf_index, bool predict_contrib, bool early_stop, int early_stop_freq, double early_stop_margin)
Constructor.
Definition predictor.hpp:35
Read text data from file.
Definition text_reader.h:21
std::string first_line()
return first line of data
Definition text_reader.h:74
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10
LIGHTGBM_EXPORT PredictionEarlyStopInstance CreatePredictionEarlyStopInstance(const std::string &type, const PredictionEarlyStopConfig &config)
Create an early stopping algorithm of type type, with given round_period and margin threshold.
Definition prediction_early_stop.cpp:76
int32_t data_size_t
Type used for data sizes; a signed type is preferred.
Definition meta.h:14
NLOHMANN_BASIC_JSON_TPL_DECLARATION void swap(nlohmann::NLOHMANN_BASIC_JSON_TPL &j1, nlohmann::NLOHMANN_BASIC_JSON_TPL &j2) noexcept(//NOLINT(readability-inconsistent-declaration-parameter-name, cert-dcl58-cpp) is_nothrow_move_constructible< nlohmann::NLOHMANN_BASIC_JSON_TPL >::value &&//NOLINT(misc-redundant-expression) is_nothrow_move_assignable< nlohmann::NLOHMANN_BASIC_JSON_TPL >::value)
exchanges the values of two JSON objects
Definition json.hpp:24418
Definition prediction_early_stop.h:21
static std::unique_ptr< VirtualFileWriter > Make(const std::string &filename)
Create appropriate writer for filename.
Definition file_io.cpp:161