Medial Code Documentation
Loading...
Searching...
No Matches
MedGDLM.h
1#pragma once
3
4//==============================================================================================
5// Linear Models2: Linear regression (with Ridge and/or Lasso), using Gradient Descent variants
6//==============================================================================================
8
9 // Required params
10 int max_iter;
12 int max_times_err_grows;
13 string method;
15 float rate;
16 float rate_decay;
17 float momentum;
18
19 int last_is_bias;
20 int print_model;
21 bool verbose_learn;
22
23 // Optional params
24 float l_ridge;
25 vector<float> ls_ridge;
26 float l_lasso;
27 vector<float> ls_lasso;;
28
31
32 int normalize = 0;
33
35 max_iter = 500; stop_at_err = (float)1e-4; max_times_err_grows = 20; method = "logistic_sgd"; batch_size = 512; rate = (float)0.01; rate_decay = (float)1.0; momentum = (float)0.95; last_is_bias = 0;
36 l_ridge = (float)0; l_lasso = (float)0; nthreads = 0; err_freq = 10; normalize = 0; verbose_learn = true;
37 }
38
39 ADD_CLASS_NAME(MedGDLMParams)
40 ADD_SERIALIZATION_FUNCS(method, last_is_bias, max_iter, stop_at_err, max_times_err_grows, batch_size, rate, rate_decay, l_ridge, l_lasso, ls_lasso, ls_ridge, nthreads, err_freq)
41};
42
// MedGDLM: predictor class derived from MedPredictor. Per the banner comment of this
// header it covers linear regression (with Ridge and/or Lasso) trained with
// gradient-descent variants; the Learn_* members below are the per-method entry points.
// NOTE(review): the numeric prefix on each line is a listing line number carried over
// from the documentation export, not part of the declaration itself.
43class MedGDLM : public MedPredictor {
44public:
45 // Model
46 int n_ftrs; // presumably the number of features the model was fitted on - TODO confirm in MedGDLM.cpp
47 vector<float> b; // presumably one coefficient per feature - TODO confirm
48 float b0; // presumably the intercept (bias) term - TODO confirm
49
50 // Parameters
51 MedGDLMParams params; // learning configuration (max_iter, method, rate, l_ridge, l_lasso, ...)
52
53 // Function
54 MedGDLM();
55 ~MedGDLM() {};
56 MedGDLM(void *params); // construct from an untyped parameter pointer; presumably points at a MedGDLMParams - TODO confirm
57 MedGDLM(MedGDLMParams& params); // construct from an explicit parameter object
60 int set_params(map<string, string>& mapper); // mapper holds the parsed fields of the init command (name -> value strings)
61 int init(void *params);
62 void init_defaults();
63
64 //int learn(MedMat<float> &x, MedMat<float> &y) {return (MedPredictor::learn(x,y));}; // Special case - un-normalized Y
65
// Learn overloads: x and y are raw float buffers sized by nsamples / nftrs.
// NOTE(review): the memory layout of x and the meaning of w (presumably per-sample
// weights) are not visible in this header - confirm against MedGDLM.cpp.
66 int Learn(float *x, float *y, int nsamples, int nftrs);
67 int Learn(float *x, float *y, const float *w, int nsamples, int nftrs);
68
// Predict overloads: preds is passed as a reference to a pointer, so the callee is able
// to reseat it. NOTE(review): who allocates preds and what transposed_flag selects are
// not shown here - confirm against MedGDLM.cpp.
69 int Predict(float *x, float *&preds, int nsamples, int nftrs) const;
70 int Predict(float *x, float *&preds, int nsamples, int nftrs, int transposed_flag) const;
71
// NOTE(review): "lavel_avg" looks like a typo for "label_avg" (compare label_std).
72 int denormalize_model(float *f_avg, float *f_std, float lavel_avg, float label_std);
73
74 void print(FILE *fp, const string& prefix, int level = 0) const; // writes a description of the model to fp; level defaults to 0
75
77
78 ADD_CLASS_NAME(MedGDLM)
80
81
82 // actual computation functions
// All five share the weighted Learn signature; presumably Learn dispatches to one of
// them according to params.method - TODO confirm in MedGDLM.cpp.
83 int Learn_full(float *x, float *y, const float *w, int nsamples, int nftrs); // full non-iterative solution, not supporting lasso
84 int Learn_gd(float *x, float *y, const float *w, int nsamples, int nftrs);
85 int Learn_sgd(float *x, float *y, const float *w, int nsamples, int nftrs);
86 int Learn_logistic_sgd(float *x, float *y, const float *w, int nsamples, int nftrs);
87 int Learn_logistic_sgd_threaded(float *x, float *y, const float *w, int nsamples, int nftrs);
88private:
89 void set_eigen_threads() const; // presumably applies params.nthreads to the Eigen library - TODO confirm
90 void calc_feature_importance(vector<float> &features_importance_scores, const string &general_params, const MedFeatures *features); // results are returned through features_importance_scores
91};
92
MedAlgo - APIs to different algorithms: Linear Models, RF, GBM, KNN, and more.
#define ADD_SERIALIZATION_FUNCS(...)
Definition SerializableObject.h:122
#define MEDSERIALIZE_SUPPORT(Type)
Definition SerializableObject.h:108
A class for holding features data as a virtual matrix
Definition MedFeatures.h:47
Definition MedGDLM.h:43
int Predict(float *x, float *&preds, int nsamples, int nftrs) const
Predict should be implemented for each model.
Definition MedGDLM.cpp:207
int set_params(map< string, string > &mapper)
The parsed fields from init command.
Definition MedGDLM.cpp:88
void calc_feature_contribs(MedMat< float > &x, MedMat< float > &contribs)
Feature contributions explain the prediction on each sample (aka BUT_WHY)
Definition MedGDLM.cpp:76
Definition MedMat.h:63
Base Interface for predictor.
Definition MedAlgo.h:78
int features_count
The model features count used in Learn, to validate when calling predict.
Definition MedAlgo.h:96
MedPredictorTypes classifier_type
The Predictor enum type.
Definition MedAlgo.h:80
vector< string > model_features
The model features used in Learn, to validate when calling predict.
Definition MedAlgo.h:93
Definition SerializableObject.h:32
Definition MedGDLM.h:7
float l_lasso
lambda for lasso
Definition MedGDLM.h:26
int err_freq
the frequency at which the stopping error on the loss is tested; recommended > 10
Definition MedGDLM.h:30
float stop_at_err
stop criteria
Definition MedGDLM.h:11
int batch_size
for sgd
Definition MedGDLM.h:14
string method
gd or sgd
Definition MedGDLM.h:13
int nthreads
lambdas for lasso
Definition MedGDLM.h:29
float l_ridge
lambda for ridge
Definition MedGDLM.h:24
vector< float > ls_ridge
lambdas for ridge
Definition MedGDLM.h:25