Medial Code Documentation
Loading...
Searching...
No Matches
tree_learner.h
1#ifndef LIGHTGBM_TREE_LEARNER_H_
2#define LIGHTGBM_TREE_LEARNER_H_
3
4
5#include <LightGBM/meta.h>
6#include <LightGBM/config.h>
7#include <LightGBM/json11.hpp>
8
9#include <vector>
10
11using namespace json11;
12
13namespace LightGBM {
14
16class Tree;
17class Dataset;
18class ObjectiveFunction;
19
24public:
26 virtual ~TreeLearner() {}
27
33 virtual void Init(const Dataset* train_data, bool is_constant_hessian) = 0;
34
35 virtual void ResetTrainingData(const Dataset* train_data) = 0;
36
41 virtual void ResetConfig(const Config* config) = 0;
42
50 virtual Tree* Train(const score_t* gradients, const score_t* hessians, bool is_constant_hessian,
51 Json& forced_split_json) = 0;
52
56 virtual Tree* FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t* hessians) const = 0;
57
58 virtual Tree* FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred,
59 const score_t* gradients, const score_t* hessians) = 0;
60
66 virtual void SetBaggingData(const data_size_t* used_indices,
67 data_size_t num_data) = 0;
68
73 virtual void AddPredictionToScore(const Tree* tree, double* out_score) const = 0;
74
75 virtual void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, const double* prediction,
76 data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const = 0;
77
78 virtual void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, double prediction,
79 data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const = 0;
80
81 TreeLearner() = default;
85 TreeLearner(const TreeLearner&) = delete;
86
93 static TreeLearner* CreateTreeLearner(const std::string& learner_type,
94 const std::string& device_type,
95 const Config* config);
96};
97
98} // namespace LightGBM
99
100#endif // LIGHTGBM_TREE_LEARNER_H_
The main class for a data set, which is used for training or validation.
Definition dataset.h:278
The interface of Objective Function.
Definition objective_function.h:13
Interface for tree learner.
Definition tree_learner.h:23
virtual void Init(const Dataset *train_data, bool is_constant_hessian)=0
Initialize tree learner with training dataset.
TreeLearner & operator=(const TreeLearner &)=delete
Disable copy assignment.
virtual Tree * Train(const score_t *gradients, const score_t *hessians, bool is_constant_hessian, Json &forced_split_json)=0
Train a tree model on the dataset.
virtual Tree * FitByExistingTree(const Tree *old_tree, const score_t *gradients, const score_t *hessians) const =0
Use an existing tree to fit the new gradients and hessians.
virtual void SetBaggingData(const data_size_t *used_indices, data_size_t num_data)=0
Set bagging data.
static TreeLearner * CreateTreeLearner(const std::string &learner_type, const std::string &device_type, const Config *config)
Create object of tree learner.
Definition tree_learner.cpp:9
virtual void ResetConfig(const Config *config)=0
Reset tree configs.
virtual ~TreeLearner()
virtual destructor
Definition tree_learner.h:26
TreeLearner(const TreeLearner &)=delete
Disable copy construction.
virtual void AddPredictionToScore(const Tree *tree, double *out_score) const =0
Use the last trained tree to predict scores, then add them to out_score.
Tree model.
Definition tree.h:20
Definition json11.hpp:79
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10
float score_t
Type of scores and gradients.
Definition meta.h:26
int32_t data_size_t
Type of data size; it is better to use a signed type.
Definition meta.h:14
Definition config.h:27