Medial Code Documentation
Loading...
Searching...
No Matches
boosting.h
1#ifndef LIGHTGBM_BOOSTING_H_
2#define LIGHTGBM_BOOSTING_H_
3
#include <LightGBM/meta.h>
#include <LightGBM/config.h>

#include <map>
#include <string>
#include <unordered_map>
#include <vector>
10
11namespace LightGBM {
12
14class Dataset;
15class ObjectiveFunction;
16class Metric;
17struct PredictionEarlyStopInstance;
18
22class LIGHTGBM_EXPORT Boosting {
23public:
25 virtual ~Boosting() {}
26
34 virtual void Init(
35 const Config* config,
36 const Dataset* train_data,
37 const ObjectiveFunction* objective_function,
38 const std::vector<const Metric*>& training_metrics) = 0;
39
45 virtual void MergeFrom(const Boosting* other) = 0;
46
50 virtual void ShuffleModels(int start_iter, int end_iter) = 0;
51
52 virtual void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
53 const std::vector<const Metric*>& training_metrics) = 0;
54
55 virtual void ResetConfig(const Config* config) = 0;
56
57
58
64 virtual void AddValidDataset(const Dataset* valid_data,
65 const std::vector<const Metric*>& valid_metrics) = 0;
66
67 virtual void Train(int snapshot_freq, const std::string& model_output_path) = 0;
68
72 virtual void RefitTree(const std::vector<std::vector<int>>& tree_leaf_prediction) = 0;
73
80 virtual bool TrainOneIter(const score_t* gradients, const score_t* hessians) = 0;
81
85 virtual void RollbackOneIter() = 0;
86
90 virtual int GetCurrentIteration() const = 0;
91
97 virtual std::vector<double> GetEvalAt(int data_idx) const = 0;
98
104 virtual const double* GetTrainingScore(int64_t* out_len) = 0;
105
111 virtual int64_t GetNumPredictAt(int data_idx) const = 0;
112
119 virtual void GetPredictAt(int data_idx, double* result, int64_t* out_len) = 0;
120
121 virtual int NumPredictOneRow(int num_iteration, bool is_pred_leaf, bool is_pred_contrib) const = 0;
122
129 virtual void PredictRaw(const double* features, double* output,
130 const PredictionEarlyStopInstance* early_stop) const = 0;
131
132 virtual void PredictRawByMap(const std::unordered_map<int, double>& features, double* output,
133 const PredictionEarlyStopInstance* early_stop) const = 0;
134
135
142 virtual void Predict(const double* features, double* output,
143 const PredictionEarlyStopInstance* early_stop) const = 0;
144
145 virtual void PredictByMap(const std::unordered_map<int, double>& features, double* output,
146 const PredictionEarlyStopInstance* early_stop) const = 0;
147
148
154 virtual void PredictLeafIndex(
155 const double* features, double* output) const = 0;
156
157 virtual void PredictLeafIndexByMap(
158 const std::unordered_map<int, double>& features, double* output) const = 0;
159
166 virtual void PredictContrib(const double* features, double* output,
167 const PredictionEarlyStopInstance* early_stop) const = 0;
168
175 virtual std::string DumpModel(int start_iteration, int num_iteration) const = 0;
176
182 virtual std::string ModelToIfElse(int num_iteration) const = 0;
183
190 virtual bool SaveModelToIfElse(int num_iteration, const char* filename) const = 0;
191
200 virtual bool SaveModelToFile(int start_iteration, int num_iterations, const char* filename) const = 0;
201
208 std::string SaveModelToString(int num_iterations) { return SaveModelToString(0, num_iterations); } // ADDED by Medial for backward compatibility with MedLightGBM
209 virtual std::string SaveModelToString(int start_iteration, int num_iterations) const = 0;
210
217 bool LoadModelFromString(std::string str) { return LoadModelFromString(str.c_str(), str.length()); } // ADDED by Medial for backward compatability with MedLightGBM
218 virtual bool LoadModelFromString(const char* buffer, size_t len) = 0;
219
226 virtual std::vector<double> FeatureImportance(int num_iteration, int importance_type) const = 0;
227
232 virtual int MaxFeatureIdx() const = 0;
233
238 virtual std::vector<std::string> FeatureNames() const = 0;
239
244 virtual int LabelIdx() const = 0;
245
250 virtual int NumberOfTotalModel() const = 0;
251
256 virtual int NumModelPerIteration() const = 0;
257
262 virtual int NumberOfClasses() const = 0;
263
265 virtual bool NeedAccuratePrediction() const = 0;
266
272 virtual void InitPredict(int num_iteration, bool is_pred_contrib) = 0;
273
277 virtual const char* SubModelName() const = 0;
278
279 Boosting() = default;
281 Boosting& operator=(const Boosting&) = delete;
283 Boosting(const Boosting&) = delete;
284
285 static bool LoadFileToBoosting(Boosting* boosting, const char* filename);
286
295 static Boosting* CreateBoosting(const std::string& type, const char* filename);
296};
297
298class GBDTBase : public Boosting {
299public:
300 virtual double GetLeafValue(int tree_idx, int leaf_idx) const = 0;
301 virtual void SetLeafValue(int tree_idx, int leaf_idx, double val) = 0;
302};
303
304} // namespace LightGBM
305
306#endif // LightGBM_BOOSTING_H_
The interface for Boosting.
Definition boosting.h:22
virtual void RefitTree(const std::vector< std::vector< int > > &tree_leaf_prediction)=0
Update the tree output by new training data.
Boosting(const Boosting &)=delete
Disable copy.
virtual void Predict(const double *features, double *output, const PredictionEarlyStopInstance *early_stop) const =0
Prediction for one record, sigmoid transformation will be used if needed.
virtual std::vector< std::string > FeatureNames() const =0
Get feature names of this model.
virtual bool SaveModelToIfElse(int num_iteration, const char *filename) const =0
Translate model to if-else statement and save it to a file.
virtual int NumberOfTotalModel() const =0
Get number of weak sub-models.
virtual bool TrainOneIter(const score_t *gradients, const score_t *hessians)=0
Training logic.
virtual int NumberOfClasses() const =0
Get number of classes.
virtual int64_t GetNumPredictAt(int data_idx) const =0
Get prediction result at data_idx data.
virtual void ShuffleModels(int start_iter, int end_iter)=0
Shuffle existing models.
virtual std::vector< double > FeatureImportance(int num_iteration, int importance_type) const =0
Calculate feature importances.
virtual void MergeFrom(const Boosting *other)=0
Merge model from other boosting object. Will insert to the front of current boosting object.
virtual void Init(const Config *config, const Dataset *train_data, const ObjectiveFunction *objective_function, const std::vector< const Metric * > &training_metrics)=0
Initialization logic.
virtual void InitPredict(int num_iteration, bool is_pred_contrib)=0
Initial work for the prediction.
virtual int MaxFeatureIdx() const =0
Get max feature index of this model.
Boosting & operator=(const Boosting &)=delete
Disable copy.
virtual void PredictLeafIndex(const double *features, double *output) const =0
Prediction for one record with leaf index.
virtual const double * GetTrainingScore(int64_t *out_len)=0
Get current training score.
virtual std::string ModelToIfElse(int num_iteration) const =0
Translate model to if-else statement.
std::string SaveModelToString(int num_iterations)
Save model to string.
Definition boosting.h:208
virtual void PredictContrib(const double *features, double *output, const PredictionEarlyStopInstance *early_stop) const =0
Feature contributions for the model's prediction of one record.
virtual std::string DumpModel(int start_iteration, int num_iteration) const =0
Dump model to json format string.
virtual std::vector< double > GetEvalAt(int data_idx) const =0
Get evaluation result at data_idx data.
virtual bool NeedAccuratePrediction() const =0
Whether the prediction should be accurate. True will disable early stopping for prediction.
virtual int NumModelPerIteration() const =0
Get number of models per iteration.
virtual void RollbackOneIter()=0
Rollback one iteration.
virtual void PredictRaw(const double *features, double *output, const PredictionEarlyStopInstance *early_stop) const =0
Prediction for one record, without sigmoid transformation.
virtual const char * SubModelName() const =0
Name of submodel.
bool LoadModelFromString(std::string str)
Restore from a serialized string.
Definition boosting.h:217
virtual void GetPredictAt(int data_idx, double *result, int64_t *out_len)=0
Get prediction result at data_idx data.
virtual bool SaveModelToFile(int start_iteration, int num_iterations, const char *filename) const =0
Save model to file.
virtual ~Boosting()
Virtual destructor.
Definition boosting.h:25
virtual int GetCurrentIteration() const =0
Return current iteration.
virtual void AddValidDataset(const Dataset *valid_data, const std::vector< const Metric * > &valid_metrics)=0
Add a validation dataset.
virtual int LabelIdx() const =0
Get index of label column.
The main class of data set, which is used for training or validation.
Definition dataset.h:278
Definition boosting.h:298
The interface of Objective Function.
Definition objective_function.h:13
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10
float score_t
Type of score and gradients.
Definition meta.h:26
Definition config.h:27
Definition prediction_early_stop.h:11