Medial Code Documentation
Loading...
Searching...
No Matches
MedLightGBM.h
1#pragma once
2
3#ifndef __MED_LIGHT_GBM__
4#define __MED_LIGHT_GBM__
5
7#include <LightGBM/c_api.h>
8
9#include <vector>
10#include <string>
11#include <map>
12
13//=============================================
14// MedLightGBM Params
15//=============================================
16
18{
	// Parameter holder for MedLightGBM: a built-in default parameter string plus
	// a user-supplied parameter string, both registered with
	// ADD_SERIALIZATION_FUNCS below.
	// NOTE(review): this is a Doxygen-scraped listing. The declaration line
	// (source line 17) and the constructor signature (source line 19) were
	// dropped, so the "20 {" line below opens the constructor body.
20 {
	// Build the default LightGBM configuration as "key=value" pairs separated by
	// ';'. The last pair has no trailing ';'. The metric value itself contains a
	// ',' ("binary_logloss,auc"), so the consumer presumably splits on ';' first
	// and converts to LightGBM's space-separated C API format
	// (likely in build_c_api_param_string) — TODO confirm in MedLightGBM.cpp.
21 defaults = "";
22 defaults += "boosting_type=gbdt;";
23 defaults += "objective=binary;";
24 defaults += "metric=binary_logloss,auc;";
25 defaults += "metric_freq=1;";
26 defaults += "is_training_metric=true;";
27 defaults += "max_bin=255;";
28 defaults += "num_iterations=200;";
29 defaults += "learning_rate=0.05;";
30 defaults += "tree_learner=serial;";
	// NOTE(review): thread count is hard-coded to 12 regardless of the host's
	// core count; presumably overridable through user_params — confirm.
31 defaults += "num_threads=12;";
32 defaults += "min_data_in_leaf=50;";
33 defaults += "min_sum_hessian_in_leaf=5.0;";
34 defaults += "is_enable_sparse=false;";
35 defaults += "num_machines=1;";
36 defaults += "feature_fraction=0.8;";
37 defaults += "bagging_fraction=0.25;";
38 defaults += "bagging_freq=4;";
	// Binary objective with class re-weighting enabled by default.
39 defaults += "is_unbalance=true;";
40 defaults += "num_leaves=80";
41 }
42
	// Built-in defaults; assigned by the constructor above.
43 string defaults = "";
	// User-provided parameters in the same ';'-separated format (presumably
	// applied on top of `defaults` — TODO confirm in MedLightGBM.cpp).
44 string user_params = "";
45
46 ADD_CLASS_NAME(MedLightGBMParams)
47 ADD_SERIALIZATION_FUNCS(defaults, user_params);
48};
49
50//=============================================
51// MedLightGBM Predictor
52//=============================================
53
55{
	// LightGBM-backed predictor that wraps the LightGBM C API.
	// NOTE(review): Doxygen-scraped listing. The class declaration (source
	// line 54) and source lines 64-65 were dropped. classifier_type,
	// model_features and features_count are not declared here; per the Doxygen
	// cross-references they are defined in MedAlgo.h (presumably the predictor
	// base class).
56public:
57 MedLightGBMParams params;
58
59 // Initialization methods
60 int init(map<string, string> &initialization_map);
61 int set_params(map<string, string> &initialization_map);
62 int init_from_string(string init_str);
63
	// NOTE(review): source lines 64-65 are missing from this listing
	// (possibly constructor/destructor declarations) — check the real header.
66
67 // Core ML functions
	// x: nsamples * nftrs feature values, y: labels, w: optional per-sample
	// weights (layout assumed row-major — TODO confirm in MedLightGBM.cpp).
68 int Learn(float *x, float *y, const float *w, int nsamples, int nftrs);
	// Unweighted overload: forwards to the weighted Learn with w = NULL.
69 int Learn(float *x, float *y, int nsamples, int nftrs) { return Learn(x, y, NULL, nsamples, nftrs); }
70
	// `preds` is a reference to a pointer, so the callee can set or allocate it;
	// the ownership/allocation convention lives in MedLightGBM.cpp — TODO confirm.
71 int Predict(float *x, float *&preds, int nsamples, int nftrs) const;
	// Single-sample prediction overloads (float and double).
72 void predict_single(const vector<float> &x, vector<float> &preds) const;
73 void predict_single(const vector<double> &x, vector<double> &preds) const;
	// Non-const setup step. Presumably must run before predict_single (see the
	// prepared_single flag below) — TODO confirm.
74 void prepare_predict_single();
75
76 // Feature Importance & SHAP
77 void calc_feature_importance(vector<float> &features_importance_scores, const string &general_params, const MedFeatures *features);
78 void calc_feature_contribs(MedMat<float> &x, MedMat<float> &contribs) const;
79
80 int n_preds_per_sample() const;
81 void export_predictor(const string &output_fname);
82
83 // Serialization
	// The serialized field list below contains model_as_string but not
	// booster_handle_. Presumably pre_serialization() dumps the booster into
	// model_as_string, and post_deserialization() rebuilds the handle from it —
	// TODO confirm.
84 void pre_serialization();
85 void post_deserialization();
86 void print(FILE *fp, const string &prefix, int level = 0) const;
87 string get_model_str();
88
89 ADD_CLASS_NAME(MedLightGBM)
90 ADD_SERIALIZATION_FUNCS(classifier_type, params, model_as_string, model_features, features_count, _mark_learn_done)
91
92private:
93 bool _mark_learn_done = false;
94 string model_as_string;
	// NOTE(review): prepared_single, num_preds and booster_handle_ have no
	// in-class initializer. Unless the constructor (not visible in this
	// listing) sets them, they hold indeterminate values; consider `= false`,
	// `= 0` and `= nullptr`.
95 bool prepared_single;
96
97 int num_preds;
98
99 // The core C API Handle
	// Opaque handle type from LightGBM/c_api.h.
	// NOTE(review): if the destructor frees this handle, the implicitly
	// generated copy operations would double-free it. Copy should be deleted or
	// deep-copied, and move should transfer ownership (Rule of Five) — verify
	// against the special members, which are not shown here.
100 BoosterHandle booster_handle_;
101
102 // Helper to generate LightGBM C API parameter strings
103 string build_c_api_param_string() const;
104 int get_num_iterations() const;
105};
106
109
110#endif
MedAlgo - APIs for different algorithms: linear models, RF, GBM, KNN, and more.
#define ADD_SERIALIZATION_FUNCS(...)
Definition SerializableObject.h:156
#define MEDSERIALIZE_SUPPORT(Type)
Definition SerializableObject.h:142
A class for holding features data as a virtual matrix
Definition MedFeatures.h:47
Definition MedLightGBM.h:55
int n_preds_per_sample() const
Number of predictions per sample. Typically 1, but some models return several per sample (for example, ...).
Definition MedLightGBM.cpp:251
int Learn(float *x, float *y, const float *w, int nsamples, int nftrs)
Learn should be implemented for each model.
Definition MedLightGBM.cpp:97
int Predict(float *x, float *&preds, int nsamples, int nftrs) const
Predict should be implemented for each model.
Definition MedLightGBM.cpp:139
int init(map< string, string > &initialization_map)
Virtual method that initializes the object from parsed fields.
Definition MedLightGBM.cpp:73
Definition MedMat.h:63
Base Interface for predictor.
Definition MedAlgo.h:72
int features_count
The number of model features used in Learn, used for validation when calling Predict.
Definition MedAlgo.h:90
MedPredictorTypes classifier_type
The Predicotr enum type.
Definition MedAlgo.h:74
vector< string > model_features
The model features used in Learn, used for validation when calling Predict.
Definition MedAlgo.h:87
Definition SerializableObject.h:33
Definition MedLightGBM.h:18