3#ifndef __MED_LIGHT_GBM__
4#define __MED_LIGHT_GBM__
15#include <LightGBM/application.h>
16#include <LightGBM/dataset.h>
17#include <LightGBM/boosting.h>
18#include <LightGBM/objective_function.h>
19#include <LightGBM/metric.h>
20#include <LightGBM/config.h>
21#include <LightGBM/../../src/boosting/gbdt.h>
22#include <LightGBM/../../src/boosting/dart.hpp>
23#include <LightGBM/../../src/boosting/goss.hpp>
34 class ObjectiveFunction;
42 MemApp(
int argc,
char **argv) : Application::Application(argc, argv) { is_silent =
false; only_fatal =
false; };
43 MemApp() : Application::Application(0, NULL) { is_silent =
false; }
46 int init(map<string, string>& initialization_map) {
return set_params(initialization_map); };
47 int set_params(map<string, string>& initialization_map);
50 int InitTrain(
float *xdata,
float *ydata,
const float *weight,
int nrows,
int ncols);
54 void Predict(
float *x,
int nrows,
int ncols,
float *&preds)
const;
55 void PredictShap(
float *x,
int nrows,
int ncols,
float *&shap_vals)
const;
58 int InitTrainData(
float *xdata,
float *ydata,
const float *weight,
int nrows,
int ncols);
61 int serialize_to_string(
string &str)
const { str = boosting_->SaveModelToString(-1);
return 0; }
62 int deserialize_from_string(
string &str) {
63 std::unique_ptr<Boosting> ret;
64 string type = config_.boosting;
65 if (type == std::string(
"gbdt")) ret.reset(
new GBDT());
66 else if (type == std::string(
"dart")) ret.reset(
new DART());
67 else if (type == std::string(
"goss")) ret.reset(
new GOSS());
68 else { fprintf(stderr,
"deserialize MedLightGBM ERROR: unknown boosting type %s\n", type.c_str());
return -1; }
69 if (!ret.get()->LoadModelFromString(str))
return -1;
70 boosting_.reset(ret.release());
74 std::string get_boosting_type() {
return config_.boosting; };
79 int n_preds_per_sample()
const {
81 int num_preb_in_one_row = config_.num_class;
83 int is_pred_leaf = config_.predict_leaf_index ? 1 : 0;
84 int num_iteration = config_.num_iterations;
86 int max_iteration = num_iteration;
87 if (num_iteration > 0) {
88 num_preb_in_one_row *=
static_cast<int>(std::min(max_iteration, num_iteration));
91 num_preb_in_one_row *= max_iteration;
94 return num_preb_in_one_row;
97 void calc_feature_importance(vector<float> &features_importance_scores,
98 const string &general_params,
int max_feature_idx_);
109 vector<float> FeatureImportanceTrick(
const string &method =
"gain") {
125 vector<float> final_res;
127 final_res.push_back((
float)d);
145 defaults +=
"boosting_type=gbdt;";
146 defaults +=
"objective=binary;";
147 defaults +=
"metric=binary_logloss,auc;";
148 defaults +=
"metric_freq=1;";
149 defaults +=
"is_training_metric=true;";
150 defaults +=
"max_bin=255;";
151 defaults +=
"num_trees=200;";
152 defaults +=
"learning_rate=0.05;";
153 defaults +=
"tree_learner=serial;";
154 defaults +=
"num_threads=12;";
155 defaults +=
"min_data_in_leaf=50;";
156 defaults +=
"min_sum_hessian_in_leaf=5.0;";
157 defaults +=
"is_enable_sparse=false;";
158 defaults +=
"num_machines=1;";
159 defaults +=
"feature_fraction=0.8;";
160 defaults +=
"bagging_fraction=0.25;";
161 defaults +=
"bagging_freq=4;";
162 defaults +=
"is_unbalance=true;";
163 defaults +=
"num_leaves=80";
167 string defaults =
"";
168 string user_params =
"";
181 int init(map<string, string>& initialization_map) {
return mem_app.init(initialization_map); }
182 int set_params(map<string, string>& initialization_map) {
return mem_app.set_params(initialization_map); }
184 int init_from_string(
string init_str) {
185 params.user_params += init_str;
186 string init = params.defaults +
";" + params.user_params;
188 map<string, string> init_map;
189 MedSerialize::initialization_text_to_map(
init, init_map);
190 mem_app.is_silent = init_str.empty();
202 init_from_string(
"");
203 _mark_learn_done =
false;
204 prepared_single =
false;
209 int Learn(
float *x,
float *y,
const float *w,
int nsamples,
int nftrs) {
210 if (!mem_app.is_silent)
211 global_logger.log(LOG_MEDALGO, LOG_DEF_LEVEL,
"Starting a LightGBM train session...\n");
212 mem_app.InitTrain(x, y, w, nsamples, nftrs);
214 _mark_learn_done =
true;
217 int Learn(
float *x,
float *y,
int nsamples,
int nftrs) {
return Learn(x, y, NULL, nsamples, nftrs); }
218 int Predict(
float *x,
float *&preds,
int nsamples,
int nftrs)
const {
220 mem_app.Predict(x, nsamples, nftrs, preds);
224 void calc_feature_importance(vector<float> &features_importance_scores,
225 const string &general_params,
const MedFeatures *features);
232 void prepare_predict_single();
233 void predict_single(
const vector<float> &x, vector<float> &preds)
const;
234 void predict_single(
const vector<double> &x, vector<double> &preds)
const;
236 void export_predictor(
const string & output_fname);
239 void pre_serialization() {
240 model_as_string =
"";
241 if (_mark_learn_done) {
242 if (mem_app.serialize_to_string(model_as_string) < 0)
243 global_logger.log(LOG_MEDALGO, MAX_LOG_LEVEL,
"MedLightGBM::serialize() failed moving model to string\n");
247 void post_deserialization() {
248 init_from_string(
"");
249 if (_mark_learn_done) {
250 if (mem_app.deserialize_from_string(model_as_string) < 0)
251 global_logger.log(LOG_MEDALGO, MAX_LOG_LEVEL,
"MedLightGBM::deserialize() failed moving model to string\n");
252 model_as_string =
"";
256 void print(FILE *fp,
const string& prefix,
int level = 0)
const;
262 bool _mark_learn_done =
false;
263 string model_as_string;
264 bool prepared_single;
MedAlgo - APIs to different algorithms: Linear Models, RF, GBM, KNN, and more.
@ MODEL_LIGHTGBM
to_use:"lightgbm" the celebrated LightGBM algorithm - creates MedLightGBM
Definition MedAlgo.h:57
#define ADD_SERIALIZATION_FUNCS(...)
Definition SerializableObject.h:122
#define MEDSERIALIZE_SUPPORT(Type)
Definition SerializableObject.h:108
The main entrance of LightGBM. this application has two tasks: Train and Predict. Train task will tra...
Definition application.h:25
The interface for Boosting.
Definition boosting.h:22
std::string SaveModelToString(int num_iterations)
Save model to string.
Definition boosting.h:208
DART algorithm implementation. including Training, prediction, bagging.
Definition dart.hpp:17
Definition MedLightGBM.h:102
GBDT algorithm implementation. including Training, prediction, bagging.
Definition gbdt.h:26
std::vector< double > FeatureImportance(int num_iteration, int importance_type) const override
Calculate feature importances.
Definition gbdt_model_text.cpp:513
bool LoadModelFromString(std::string str)
Restore from a serialized buffer.
Definition gbdt.h:282
Definition MedLightGBM.h:38
A class for holding features data as a virtual matrix
Definition MedFeatures.h:47
Definition MedLightGBM.h:174
int n_preds_per_sample() const
Number of predictions per sample. typically 1 - but some models return several per sample (for exampl...
Definition MedLightGBM.h:230
int Learn(float *x, float *y, const float *w, int nsamples, int nftrs)
Learn should be implemented for each model.
Definition MedLightGBM.h:209
int Predict(float *x, float *&preds, int nsamples, int nftrs) const
Predict should be implemented for each model.
Definition MedLightGBM.h:218
void calc_feature_contribs(MedMat< float > &x, MedMat< float > &contribs)
Feature contributions explains the prediction on each sample (aka BUT_WHY)
Definition MedLightGBM.cpp:464
int init(map< string, string > &initialization_map)
Please refer to KeyAliasTransform in LightGBM::ParameterAlias
Definition MedLightGBM.h:181
Base Interface for predictor.
Definition MedAlgo.h:78
bool normalize_for_learn
True if need to normalize before learn.
Definition MedAlgo.h:87
bool transpose_for_predict
True if need to transpose before predict.
Definition MedAlgo.h:90
bool normalize_for_predict
True if need to normalize before predict.
Definition MedAlgo.h:91
int features_count
The model features count used in Learn, to validate when calling predict.
Definition MedAlgo.h:96
bool normalize_y_for_learn
True if need to normalize labels before learn.
Definition MedAlgo.h:88
MedPredictorTypes classifier_type
The Predictor enum type.
Definition MedAlgo.h:80
vector< string > model_features
The model features used in Learn, to validate when calling predict.
Definition MedAlgo.h:93
bool transpose_for_learn
True if need to transpose before learn.
Definition MedAlgo.h:86
Definition SerializableObject.h:32
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10
Definition prediction_early_stop.h:11
Definition MedLightGBM.h:140