1#ifndef LIGHTGBM_BOOSTING_H_
2#define LIGHTGBM_BOOSTING_H_
4#include <LightGBM/meta.h>
5#include <LightGBM/config.h>
15class ObjectiveFunction;
17struct PredictionEarlyStopInstance;
38 const std::vector<const Metric*>& training_metrics) = 0;
53 const std::vector<const Metric*>& training_metrics) = 0;
55 virtual void ResetConfig(
const Config* config) = 0;
65 const std::vector<const Metric*>& valid_metrics) = 0;
67 virtual void Train(
int snapshot_freq,
const std::string& model_output_path) = 0;
72 virtual void RefitTree(
const std::vector<std::vector<int>>& tree_leaf_prediction) = 0;
97 virtual std::vector<double>
GetEvalAt(
int data_idx)
const = 0;
119 virtual void GetPredictAt(
int data_idx,
double* result, int64_t* out_len) = 0;
121 virtual int NumPredictOneRow(
int num_iteration,
bool is_pred_leaf,
bool is_pred_contrib)
const = 0;
129 virtual void PredictRaw(
const double* features,
double* output,
132 virtual void PredictRawByMap(
const std::unordered_map<int, double>& features,
double* output,
142 virtual void Predict(
const double* features,
double* output,
145 virtual void PredictByMap(
const std::unordered_map<int, double>& features,
double* output,
155 const double* features,
double* output)
const = 0;
157 virtual void PredictLeafIndexByMap(
158 const std::unordered_map<int, double>& features,
double* output)
const = 0;
175 virtual std::string
DumpModel(
int start_iteration,
int num_iteration)
const = 0;
200 virtual bool SaveModelToFile(
int start_iteration,
int num_iterations,
const char* filename)
const = 0;
209 virtual std::string SaveModelToString(
int start_iteration,
int num_iterations)
const = 0;
218 virtual bool LoadModelFromString(
const char* buffer,
size_t len) = 0;
226 virtual std::vector<double>
FeatureImportance(
int num_iteration,
int importance_type)
const = 0;
272 virtual void InitPredict(
int num_iteration,
bool is_pred_contrib) = 0;
285 static bool LoadFileToBoosting(
Boosting* boosting,
const char* filename);
295 static Boosting* CreateBoosting(
const std::string& type,
const char* filename);
300 virtual double GetLeafValue(
int tree_idx,
int leaf_idx)
const = 0;
301 virtual void SetLeafValue(
int tree_idx,
int leaf_idx,
double val) = 0;
The interface for Boosting.
Definition boosting.h:22
virtual void RefitTree(const std::vector< std::vector< int > > &tree_leaf_prediction)=0
Update the tree output by new training data.
Boosting(const Boosting &)=delete
Disable copy.
virtual void Predict(const double *features, double *output, const PredictionEarlyStopInstance *early_stop) const =0
Prediction for one record, sigmoid transformation will be used if needed.
virtual std::vector< std::string > FeatureNames() const =0
Get feature names of this model.
virtual bool SaveModelToIfElse(int num_iteration, const char *filename) const =0
Translate model to if-else statement.
virtual int NumberOfTotalModel() const =0
Get number of weak sub-models.
virtual bool TrainOneIter(const score_t *gradients, const score_t *hessians)=0
Training logic.
virtual int NumberOfClasses() const =0
Get number of classes.
virtual int64_t GetNumPredictAt(int data_idx) const =0
Get prediction result at data_idx data.
virtual void ShuffleModels(int start_iter, int end_iter)=0
Shuffle Existing Models.
virtual std::vector< double > FeatureImportance(int num_iteration, int importance_type) const =0
Calculate feature importances.
virtual void MergeFrom(const Boosting *other)=0
Merge model from other boosting object Will insert to the front of current boosting object.
virtual void Init(const Config *config, const Dataset *train_data, const ObjectiveFunction *objective_function, const std::vector< const Metric * > &training_metrics)=0
Initialization logic.
virtual void InitPredict(int num_iteration, bool is_pred_contrib)=0
Initial work for the prediction.
virtual int MaxFeatureIdx() const =0
Get max feature index of this model.
Boosting & operator=(const Boosting &)=delete
Disable copy.
virtual void PredictLeafIndex(const double *features, double *output) const =0
Prediction for one record with leaf index.
virtual const double * GetTrainingScore(int64_t *out_len)=0
Get current training score.
virtual std::string ModelToIfElse(int num_iteration) const =0
Translate model to if-else statement.
std::string SaveModelToString(int num_iterations)
Save model to string.
Definition boosting.h:208
virtual void PredictContrib(const double *features, double *output, const PredictionEarlyStopInstance *early_stop) const =0
Feature contributions for the model's prediction of one record.
virtual std::string DumpModel(int start_iteration, int num_iteration) const =0
Dump model to json format string.
virtual std::vector< double > GetEvalAt(int data_idx) const =0
Get evaluation result at data_idx data.
virtual bool NeedAccuratePrediction() const =0
The prediction should be accurate or not. True will disable early stopping for prediction.
virtual int NumModelPerIteration() const =0
Get number of models per iteration.
virtual void RollbackOneIter()=0
Rollback one iteration.
virtual void PredictRaw(const double *features, double *output, const PredictionEarlyStopInstance *early_stop) const =0
Prediction for one record, not sigmoid transform.
virtual const char * SubModelName() const =0
Name of submodel.
bool LoadModelFromString(std::string str)
Restore from a serialized string.
Definition boosting.h:217
virtual void GetPredictAt(int data_idx, double *result, int64_t *out_len)=0
Get prediction result at data_idx data.
virtual bool SaveModelToFile(int start_iteration, int num_iterations, const char *filename) const =0
Save model to file.
virtual ~Boosting()
virtual destructor
Definition boosting.h:25
virtual int GetCurrentIteration() const =0
return current iteration
virtual void AddValidDataset(const Dataset *valid_data, const std::vector< const Metric * > &valid_metrics)=0
Add a validation data.
virtual int LabelIdx() const =0
Get index of label column.
The main class of data set, which are used to traning or validation.
Definition dataset.h:278
Definition boosting.h:298
The interface of Objective Function.
Definition objective_function.h:13
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10
float score_t
Type of score, and gradients.
Definition meta.h:26
Definition prediction_early_stop.h:11