Medial Code Documentation
Loading...
Searching...
No Matches
MedLightGBM.h
1#pragma once
2
3#ifndef __MED_LIGHT_GBM__
4#define __MED_LIGHT_GBM__
5
6//========================================================================================================
7// MedLightGBM
8//
9// Wrapping the LightGBM package into a MedPredictor.
10// The package code is in Libs/External/LightGBM.
11//
12//========================================================================================================
13
15#include <LightGBM/application.h>
16#include <LightGBM/dataset.h>
17#include <LightGBM/boosting.h>
18#include <LightGBM/objective_function.h>
19#include <LightGBM/metric.h>
20#include <LightGBM/config.h>
21#include <LightGBM/../../src/boosting/gbdt.h>
22#include <LightGBM/../../src/boosting/dart.hpp>
23#include <LightGBM/../../src/boosting/goss.hpp>
24
25
26//==================================================================
27// Wrapper for LightGBM::Application to handle our special cases of
28// loading data from memory, etc...
29//==================================================================
30namespace LightGBM {
31 class DatasetLoader;
32 class Dataset;
33 class Boosting;
34 class ObjectiveFunction;
35 class Metric;
36 class Predictor;
37
38 class MemApp : public Application {
39 public:
40 bool is_silent;
41 bool only_fatal;
42 MemApp(int argc, char **argv) : Application::Application(argc, argv) { is_silent = false; only_fatal = false; };
43 MemApp() : Application::Application(0, NULL) { is_silent = false; }
44 //~MemApp() { Application::~Application(); };
45
46 int init(map<string, string>& initialization_map) { return set_params(initialization_map); };
47 int set_params(map<string, string>& initialization_map);
48
49 // train
50 int InitTrain(float *xdata, float *ydata, const float *weight, int nrows, int ncols);
51 void Train();
52
53 // predict
54 void Predict(float *x, int nrows, int ncols, float *&preds) const;
55 void PredictShap(float *x, int nrows, int ncols, float *&shap_vals) const;
56
57 // initializing the train_data_ object from a float c matrix
58 int InitTrainData(float *xdata, float *ydata, const float *weight, int nrows, int ncols);
59
60 // string serializations
61 int serialize_to_string(string &str) const { str = boosting_->SaveModelToString(-1); return 0; }
62 int deserialize_from_string(string &str) {
63 std::unique_ptr<Boosting> ret;
64 string type = config_.boosting; // use boosting_type for older lightgbm version
65 if (type == std::string("gbdt")) ret.reset(new GBDT());
66 else if (type == std::string("dart")) ret.reset(new DART());
67 else if (type == std::string("goss")) ret.reset(new GOSS());
68 else { fprintf(stderr, "deserialize MedLightGBM ERROR: unknown boosting type %s\n", type.c_str()); return -1; }
69 if (!ret.get()->LoadModelFromString(str)) return -1;
70 boosting_.reset(ret.release());
71 return 0;
72 }
73
74 std::string get_boosting_type() { return config_.boosting; };
75 void fetch_boosting(LightGBM::Boosting *&res);
76 void fetch_early_stop(LightGBM::PredictionEarlyStopInstance &early_stop_);
77
78 // n_preds
79 int n_preds_per_sample() const {
80
81 int num_preb_in_one_row = config_.num_class; // In older lightgbm: config_.boosting_config.num_class;
82 //int is_pred_leaf = config_.io_config.is_predict_leaf_index ? 1 : 0; // In older lightgbm
83 int is_pred_leaf = config_.predict_leaf_index ? 1 : 0; // new lightgbm
84 int num_iteration = config_.num_iterations; // Older lightgbm : config_.boosting_config.num_iterations;
85 if (is_pred_leaf) {
86 int max_iteration = num_iteration;
87 if (num_iteration > 0) {
88 num_preb_in_one_row *= static_cast<int>(std::min(max_iteration, num_iteration));
89 }
90 else {
91 num_preb_in_one_row *= max_iteration;
92 }
93 }
94 return num_preb_in_one_row;
95 }
96
97 void calc_feature_importance(vector<float> &features_importance_scores,
98 const string &general_params, int max_feature_idx_);
99
100 };
101
102 class GBDT_Accessor : public GBDT {
103 public:
105 string mdl = booster->SaveModelToString(-1);
107 }
108
109 vector<float> FeatureImportanceTrick(const string &method = "gain") {
110 /*
111 vector<pair<size_t, string>> res = FeatureImportance(method);
112
113 vector<float> final_res(MaxFeatureIdx() + 1);
114 for (size_t i = 0; i < res.size(); ++i) {
115 int index = stoi(boost::replace_all_copy(res[i].second, "Column_", ""));
116 if (index >= final_res.size() || index < 0)
117 throw out_of_range("index is out of range: " + to_string(index) + " max=" +
118 to_string(final_res.size()));
119 final_res[index] = (float)res[i].first;
120 }
121 */
122
123 vector<double> imps = FeatureImportance(-1, method == "gain" ? 1 : 0);
124
125 vector<float> final_res;
126 for (auto d : imps)
127 final_res.push_back((float)d);
128 return final_res;
129 }
130 };
131
132};
133
134
135
136//=============================================
137// MedLightGBM
138//=============================================
139
141
143
144 defaults = "";
145 defaults += "boosting_type=gbdt;";
146 defaults += "objective=binary;";
147 defaults += "metric=binary_logloss,auc;";
148 defaults += "metric_freq=1;";
149 defaults += "is_training_metric=true;";
150 defaults += "max_bin=255;";
151 defaults += "num_trees=200;";
152 defaults += "learning_rate=0.05;";
153 defaults += "tree_learner=serial;";
154 defaults += "num_threads=12;";
155 defaults += "min_data_in_leaf=50;";
156 defaults += "min_sum_hessian_in_leaf=5.0;";
157 defaults += "is_enable_sparse=false;";
158 defaults += "num_machines=1;";
159 defaults += "feature_fraction=0.8;";
160 defaults += "bagging_fraction=0.25;";
161 defaults += "bagging_freq=4;";
162 defaults += "is_unbalance=true;";
163 defaults += "num_leaves=80"; // keep last param without a ; at the end
164
165 }
166
167 string defaults = "";
168 string user_params = "";
169
170 ADD_CLASS_NAME(MedLightGBMParams)
171 ADD_SERIALIZATION_FUNCS(defaults, user_params);
172};
173
174class MedLightGBM : public MedPredictor {
175public:
176 MedLightGBMParams params;
177
178 LightGBM::MemApp mem_app;
179
181 int init(map<string, string>& initialization_map) { return mem_app.init(initialization_map); }
182 int set_params(map<string, string>& initialization_map) { return mem_app.set_params(initialization_map); }
183
184 int init_from_string(string init_str) {
185 params.user_params += init_str;
186 string init = params.defaults + ";" + params.user_params;
187 //fprintf(stderr, "Calling MedLightGBM init with :\ninit_str %s\n user_params %s\n all %s\n", init_str.c_str(), params.user_params.c_str(), init.c_str());
188 map<string, string> init_map;
189 MedSerialize::initialization_text_to_map(init, init_map);
190 mem_app.is_silent = init_str.empty();
191 return MedLightGBM::init(init_map);
192 }
193
194 // Function
195 MedLightGBM() {
197 normalize_for_learn = false; //true;
198 normalize_for_predict = false; //true;
199 normalize_y_for_learn = false;
200 transpose_for_learn = false;
201 transpose_for_predict = false;
202 init_from_string("");
203 _mark_learn_done = false;
204 prepared_single = false;
205 };
206 ~MedLightGBM() {};
207
208 // learn predict
209 int Learn(float *x, float *y, const float *w, int nsamples, int nftrs) {
210 if (!mem_app.is_silent)
211 global_logger.log(LOG_MEDALGO, LOG_DEF_LEVEL, "Starting a LightGBM train session...\n");
212 mem_app.InitTrain(x, y, w, nsamples, nftrs);
213 mem_app.Train();
214 _mark_learn_done = true;
215 return 0;
216 }
217 int Learn(float *x, float *y, int nsamples, int nftrs) { return Learn(x, y, NULL, nsamples, nftrs); }
218 int Predict(float *x, float *&preds, int nsamples, int nftrs) const {
219 //mem_app.InitPredict(x, nsamples, nftrs);
220 mem_app.Predict(x, nsamples, nftrs, preds);
221 return 0;
222 }
223
224 void calc_feature_importance(vector<float> &features_importance_scores,
225 const string &general_params, const MedFeatures *features);
226
227
229
230 int n_preds_per_sample() const { return mem_app.n_preds_per_sample(); }
231
232 void prepare_predict_single();
233 void predict_single(const vector<float> &x, vector<float> &preds) const;
234 void predict_single(const vector<double> &x, vector<double> &preds) const;
235
236 void export_predictor(const string & output_fname);
237
238 // serializations
239 void pre_serialization() {
240 model_as_string = "";
241 if (_mark_learn_done) {
242 if (mem_app.serialize_to_string(model_as_string) < 0)
243 global_logger.log(LOG_MEDALGO, MAX_LOG_LEVEL, "MedLightGBM::serialize() failed moving model to string\n");
244 }
245 }
246
247 void post_deserialization() {
248 init_from_string("");
249 if (_mark_learn_done) {
250 if (mem_app.deserialize_from_string(model_as_string) < 0)
251 global_logger.log(LOG_MEDALGO, MAX_LOG_LEVEL, "MedLightGBM::deserialize() failed moving model to string\n");
252 model_as_string = "";
253 }
254 }
255
256 void print(FILE *fp, const string& prefix, int level = 0) const;
257
258 ADD_CLASS_NAME(MedLightGBM)
259 ADD_SERIALIZATION_FUNCS(classifier_type, params, model_as_string, model_features, features_count, _mark_learn_done)
260
261private:
262 bool _mark_learn_done = false;
263 string model_as_string;
264 bool prepared_single;
265
266 int num_preds;
267 LightGBM::Boosting *_boosting;
269};
270
273
274
275#endif
MedAlgo - APIs to different algorithms: Linear Models, RF, GBM, KNN, and more.
@ MODEL_LIGHTGBM
to_use:"lightgbm" the celebrated LightGBM algorithm - creates MedLightGBM
Definition MedAlgo.h:57
#define ADD_SERIALIZATION_FUNCS(...)
Definition SerializableObject.h:122
#define MEDSERIALIZE_SUPPORT(Type)
Definition SerializableObject.h:108
The main entrance of LightGBM. this application has two tasks: Train and Predict. Train task will tra...
Definition application.h:25
The interface for Boosting.
Definition boosting.h:22
std::string SaveModelToString(int num_iterations)
Save model to string.
Definition boosting.h:208
DART algorithm implementation. including Training, prediction, bagging.
Definition dart.hpp:17
Definition MedLightGBM.h:102
GBDT algorithm implementation. including Training, prediction, bagging.
Definition gbdt.h:26
std::vector< double > FeatureImportance(int num_iteration, int importance_type) const override
Calculate feature importances.
Definition gbdt_model_text.cpp:513
bool LoadModelFromString(std::string str)
Restore from a serialized buffer.
Definition gbdt.h:282
Definition goss.hpp:26
Definition MedLightGBM.h:38
A class for holding features data as a virtual matrix
Definition MedFeatures.h:47
Definition MedLightGBM.h:174
int n_preds_per_sample() const
Number of predictions per sample. typically 1 - but some models return several per sample (for exampl...
Definition MedLightGBM.h:230
int Learn(float *x, float *y, const float *w, int nsamples, int nftrs)
Learn should be implemented for each model.
Definition MedLightGBM.h:209
int Predict(float *x, float *&preds, int nsamples, int nftrs) const
Predict should be implemented for each model.
Definition MedLightGBM.h:218
void calc_feature_contribs(MedMat< float > &x, MedMat< float > &contribs)
Feature contributions explain the prediction for each sample (aka BUT_WHY)
Definition MedLightGBM.cpp:464
int init(map< string, string > &initialization_map)
Please refer to KeyAliasTransform in LightGBM::ParameterAlias
Definition MedLightGBM.h:181
Definition MedMat.h:63
Base Interface for predictor.
Definition MedAlgo.h:78
bool normalize_for_learn
True if need to normalize before learn.
Definition MedAlgo.h:87
bool transpose_for_predict
True if need to transpose before predict.
Definition MedAlgo.h:90
bool normalize_for_predict
True if need to normalize before predict.
Definition MedAlgo.h:91
int features_count
The number of model features used in Learn, used for validation when calling Predict.
Definition MedAlgo.h:96
bool normalize_y_for_learn
True if need to normalize labels before learn.
Definition MedAlgo.h:88
MedPredictorTypes classifier_type
The predictor enum type.
Definition MedAlgo.h:80
vector< string > model_features
The model features used in Learn, used for validation when calling Predict.
Definition MedAlgo.h:93
bool transpose_for_learn
True if need to transpose before learn.
Definition MedAlgo.h:86
Definition SerializableObject.h:32
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10
Definition prediction_early_stop.h:11
Definition MedLightGBM.h:140