1#ifndef LIGHTGBM_TREE_H_
2#define LIGHTGBM_TREE_H_
4#include <LightGBM/meta.h>
5#include <LightGBM/dataset.h>
14#define kCategoricalMask (1)
15#define kDefaultLeftMask (2)
/*!
* \brief Constructor taking the maximum number of leaves.
* \param max_leaves Maximum number of leaves for this tree
*        (presumably used to size the per-node storage; the definition is not in this header, confirm there)
*/
26 explicit Tree(
int max_leaves);
/*!
* \brief Constructor taking a text representation of a tree.
* \param str Character buffer to build the tree from
* \param used_len Output pointer; presumably receives how many characters of str were consumed
*        (TODO confirm against the definition)
*/
33 Tree(
const char* str,
size_t* used_len);
/*!
* \brief Performing a split on tree leaves (numerical threshold variant).
* \param leaf Index of the leaf to be split
* \param feature Inner (dataset-local) index of the feature used for the split
* \param real_feature Index of the feature in the original data
* \param threshold_bin Split threshold expressed as a bin index
* \param threshold_double Split threshold expressed as a raw feature value
* \param left_value Output value of the left child
* \param right_value Output value of the right child
* \param left_cnt Number of data points in the left child
* \param right_cnt Number of data points in the right child
* \param gain Gain of this split
* \param missing_type How missing values are treated at this split
* \param default_left Whether values treated as missing are routed to the left child
* \return Presumably the index of the newly created leaf; defined out of line, confirm there
*/
53 int Split(
int leaf,
int feature,
int real_feature, uint32_t threshold_bin,
54 double threshold_double,
double left_value,
double right_value,
55 int left_cnt,
int right_cnt,
float gain, MissingType missing_type,
bool default_left);
/*!
* \brief Performing a split on tree leaves, with a categorical feature.
* \param leaf Index of the leaf to be split
* \param feature Inner (dataset-local) index of the feature used for the split
* \param real_feature Index of the feature in the original data
* \param threshold_bin Array of num_threshold_bin words describing the split in bin space
*        (presumably a bitset of bins, matching the FindInBitset lookups in this class; confirm)
* \param num_threshold_bin Number of elements in threshold_bin
* \param threshold Array of num_threshold words describing the split in raw category space
*        (presumably a bitset of category values; confirm)
* \param num_threshold Number of elements in threshold
* \param left_value Output value of the left child
* \param right_value Output value of the right child
* \param left_cnt Number of data points in the left child
* \param right_cnt Number of data points in the right child
* \param gain Gain of this split
* \param missing_type How missing values are treated at this split
* \return Presumably the index of the newly created leaf; defined out of line, confirm there
*/
73 int SplitCategorical(
int leaf,
int feature,
int real_feature,
const uint32_t* threshold_bin,
int num_threshold_bin,
74 const uint32_t* threshold,
int num_threshold,
double left_value,
double right_value,
75 int left_cnt,
int right_cnt,
float gain, MissingType missing_type);
78 inline double LeafOutput(
int leaf)
const {
return leaf_value_[leaf]; }
82 leaf_value_[leaf] = output;
/*!
* \brief Prediction on one record.
* \param feature_values Feature values of the record, indexed by real feature index
* \return Output value of the leaf the record falls into
*/
111 inline double Predict(
const double* feature_values)
const;
/*!
* \brief Prediction on one record given as a sparse map.
* \param feature_values Map from real feature index to value; features absent from the map are read as 0.0
* \return Output value of the leaf the record falls into
*/
112 inline double PredictByMap(
const std::unordered_map<int, double>& feature_values)
const;
/*!
* \brief Find the leaf that one record falls into.
* \param feature_values Feature values of the record, indexed by real feature index
* \return Leaf index for the record
*/
114 inline int PredictLeafIndex(
const double* feature_values)
const;
/*!
* \brief Find the leaf that one record falls into, with the record given as a sparse map.
* \param feature_values Map from real feature index to value; features absent from the map are read as 0.0
* \return Leaf index for the record
*/
115 inline int PredictLeafIndexByMap(
const std::unordered_map<int, double>& feature_values)
const;
118 inline void PredictContrib(
const double* feature_values,
int num_features,
double* output);
124 inline int leaf_depth(
int leaf_idx)
const {
return leaf_depth_[leaf_idx]; }
127 inline int split_feature(
int split_idx)
const {
return split_feature_[split_idx]; }
129 inline double split_gain(
int split_idx)
const {
return split_gain_[split_idx]; }
132 inline int data_count(
int node)
const {
return node >= 0 ? internal_count_[node] : leaf_count_[~node]; }
140 #pragma omp parallel for schedule(static, 1024) if (num_leaves_ >= 2048)
141 for (
int i = 0; i < num_leaves_; ++i) {
142 leaf_value_[i] *= rate;
147 inline double shrinkage()
const {
151 inline void AddBias(
double val) {
152 #pragma omp parallel for schedule(static, 1024) if (num_leaves_ >= 2048)
153 for (
int i = 0; i < num_leaves_; ++i) {
154 leaf_value_[i] = val + leaf_value_[i];
160 inline void AsConstantTree(
double val) {
163 leaf_value_[0] = val;
/*! \brief Serialize this object to json. */
170 std::string
ToJSON()
const;
/*!
* \brief Serialize this object to if-else statement.
* \param index Index passed through to the generated code
*        (presumably the tree's position in the model; confirm against the definition)
* \param predict_leaf_index Presumably selects code that yields the leaf index instead of the leaf output; confirm
*/
173 std::string
ToIfElse(
int index,
bool predict_leaf_index)
const;
175 inline static bool IsZero(
double fval) {
176 if (fval > -kZeroThreshold && fval <= kZeroThreshold) {
183 inline static bool GetDecisionType(int8_t decision_type, int8_t mask) {
184 return (decision_type & mask) > 0;
187 inline static void SetDecisionType(int8_t* decision_type,
bool input, int8_t mask) {
189 (*decision_type) |= mask;
191 (*decision_type) &= (127 - mask);
195 inline static int8_t GetMissingType(int8_t decision_type) {
196 return (decision_type >> 2) & 3;
199 inline static void SetMissingType(int8_t* decision_type, int8_t input) {
200 (*decision_type) &= 3;
201 (*decision_type) |= (input << 2);
204 void RecomputeMaxDepth();
207 std::string NumericalDecisionIfElse(
int node)
const;
209 std::string CategoricalDecisionIfElse(
int node)
const;
211 inline int NumericalDecision(
double fval,
int node)
const {
212 uint8_t missing_type = GetMissingType(decision_type_[node]);
213 if (std::isnan(fval)) {
214 if (missing_type != 2) {
218 if ((missing_type == 1 && IsZero(fval))
219 || (missing_type == 2 && std::isnan(fval))) {
220 if (GetDecisionType(decision_type_[node], kDefaultLeftMask)) {
221 return left_child_[node];
223 return right_child_[node];
226 if (fval <= threshold_[node]) {
227 return left_child_[node];
229 return right_child_[node];
233 inline int NumericalDecisionInner(uint32_t fval,
int node, uint32_t default_bin, uint32_t max_bin)
const {
234 uint8_t missing_type = GetMissingType(decision_type_[node]);
235 if ((missing_type == 1 && fval == default_bin)
236 || (missing_type == 2 && fval == max_bin)) {
237 if (GetDecisionType(decision_type_[node], kDefaultLeftMask)) {
238 return left_child_[node];
240 return right_child_[node];
243 if (fval <= threshold_in_bin_[node]) {
244 return left_child_[node];
246 return right_child_[node];
// Routes a raw feature value through the categorical split stored at `node`
// and returns the child to follow (left_child_[node] or right_child_[node]).
// NOTE(review): this listing omits some lines of the body; the comments below
// describe only the statements that are visible.
250 inline int CategoricalDecision(
double fval,
int node)
const {
251 uint8_t missing_type = GetMissingType(decision_type_[node]);
// NOTE(review): fval is converted to int here, before the std::isnan(fval)
// test further down. Converting a NaN to an integer type is undefined
// behavior in C++, so the NaN test should probably come first; confirm.
252 int int_fval =
static_cast<int>(fval);
// NOTE(review): the doubled ';' below is a harmless empty statement.
254 return right_child_[node];;
255 }
else if (std::isnan(fval)) {
// Missing type 2 sends NaN to the right child.
257 if (missing_type == 2) {
258 return right_child_[node];
// For a categorical split, threshold_[node] is an index into cat_boundaries_.
262 int cat_idx = int(threshold_[node]);
// Look the category up in this split's slice of cat_threshold_:
// found -> left child, otherwise -> right child.
263 if (Common::FindInBitset(cat_threshold_.data() + cat_boundaries_[cat_idx],
264 cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx], int_fval)) {
265 return left_child_[node];
267 return right_child_[node];
270 inline int CategoricalDecisionInner(uint32_t fval,
int node)
const {
271 int cat_idx = int(threshold_in_bin_[node]);
272 if (Common::FindInBitset(cat_threshold_inner_.data() + cat_boundaries_inner_[cat_idx],
273 cat_boundaries_inner_[cat_idx + 1] - cat_boundaries_inner_[cat_idx], fval)) {
274 return left_child_[node];
276 return right_child_[node];
279 inline int Decision(
double fval,
int node)
const {
280 if (GetDecisionType(decision_type_[node], kCategoricalMask)) {
281 return CategoricalDecision(fval, node);
283 return NumericalDecision(fval, node);
287 inline int DecisionInner(uint32_t fval,
int node, uint32_t default_bin, uint32_t max_bin)
const {
288 if (GetDecisionType(decision_type_[node], kCategoricalMask)) {
289 return CategoricalDecisionInner(fval, node);
291 return NumericalDecisionInner(fval, node, default_bin, max_bin);
295 inline void Split(
int leaf,
int feature,
int real_feature,
296 double left_value,
double right_value,
int left_cnt,
int right_cnt,
float gain);
302 inline int GetLeaf(
const double* feature_values)
const;
303 inline int GetLeafByMap(
const std::unordered_map<int, double>& feature_values)
const;
306 std::string NodeToJSON(
int index)
const;
309 std::string NodeToIfElse(
int index,
bool predict_leaf_index)
const;
311 std::string NodeToIfElseByMap(
int index,
bool predict_leaf_index)
const;
313 double ExpectedValue()
const;
316 inline void RecomputeLeafDepths(
int node = 0,
int depth = 0);
323 double zero_fraction;
331 PathElement(
int i,
double z,
double o,
double w) : feature_index(i), zero_fraction(z), one_fraction(o), pweight(w) {}
335 void TreeSHAP(
const double *feature_values,
double *phi,
336 int node,
int unique_depth,
337 PathElement *parent_unique_path,
double parent_zero_fraction,
338 double parent_one_fraction,
int parent_feature_index)
const;
341 static void ExtendPath(
PathElement *unique_path,
int unique_depth,
342 double zero_fraction,
double one_fraction,
int feature_index);
345 static void UnwindPath(
PathElement *unique_path,
int unique_depth,
int path_index);
348 static double UnwoundPathSum(
const PathElement *unique_path,
int unique_depth,
int path_index);
/*! \brief Left child of each internal node; a leaf child is stored as the bitwise complement (~) of its leaf index */
356 std::vector<int> left_child_;
/*! \brief Right child of each internal node; a leaf child is stored as the bitwise complement (~) of its leaf index */
358 std::vector<int> right_child_;
/*! \brief Split feature of each internal node, as the inner feature index passed to Split */
360 std::vector<int> split_feature_inner_;
/*! \brief Split feature of each internal node, as the real feature index; used to index raw feature values at prediction time */
362 std::vector<int> split_feature_;
/*! \brief Per-node threshold in bin space; for categorical nodes it is an index into cat_boundaries_inner_ */
364 std::vector<uint32_t> threshold_in_bin_;
/*! \brief Per-node threshold in raw feature space; for categorical nodes it is an index into cat_boundaries_ */
366 std::vector<double> threshold_;
/*! \brief Offsets into cat_threshold_inner_; entries [i, i + 1) delimit the bitset words of categorical split i */
368 std::vector<int> cat_boundaries_inner_;
/*! \brief Concatenated bitset words for categorical splits, in bin space */
369 std::vector<uint32_t> cat_threshold_inner_;
/*! \brief Offsets into cat_threshold_; entries [i, i + 1) delimit the bitset words of categorical split i */
370 std::vector<int> cat_boundaries_;
/*! \brief Concatenated bitset words for categorical splits, in raw category space */
371 std::vector<uint32_t> cat_threshold_;
/*! \brief Per-node flags: kCategoricalMask (1), kDefaultLeftMask (2), and the missing type in bits 2-3 */
373 std::vector<int8_t> decision_type_;
/*! \brief Gain of each split (internal node) */
375 std::vector<float> split_gain_;
/*! \brief Index of the internal node that is the parent of each leaf */
378 std::vector<int> leaf_parent_;
/*! \brief Output value of each leaf */
380 std::vector<double> leaf_value_;
/*! \brief Number of data points in each leaf */
382 std::vector<int> leaf_count_;
/*! \brief Value of each internal node; Split stores the leaf's value from before it was split */
384 std::vector<double> internal_value_;
/*! \brief Number of data points under each internal node (left count + right count at split time) */
386 std::vector<int> internal_count_;
/*! \brief Depth of each leaf */
388 std::vector<int> leaf_depth_;
393inline void Tree::Split(
int leaf,
int feature,
int real_feature,
394 double left_value,
double right_value,
int left_cnt,
int right_cnt,
float gain) {
395 int new_node_idx = num_leaves_ - 1;
397 int parent = leaf_parent_[leaf];
400 if (left_child_[parent] == ~leaf) {
401 left_child_[parent] = new_node_idx;
403 right_child_[parent] = new_node_idx;
407 split_feature_inner_[new_node_idx] = feature;
408 split_feature_[new_node_idx] = real_feature;
410 split_gain_[new_node_idx] = Common::AvoidInf(gain);
412 left_child_[new_node_idx] = ~leaf;
413 right_child_[new_node_idx] = ~num_leaves_;
415 leaf_parent_[leaf] = new_node_idx;
416 leaf_parent_[num_leaves_] = new_node_idx;
418 internal_value_[new_node_idx] = leaf_value_[leaf];
419 internal_count_[new_node_idx] = left_cnt + right_cnt;
420 leaf_value_[leaf] = std::isnan(left_value) ? 0.0f : left_value;
421 leaf_count_[leaf] = left_cnt;
422 leaf_value_[num_leaves_] = std::isnan(right_value) ? 0.0f : right_value;
423 leaf_count_[num_leaves_] = right_cnt;
425 leaf_depth_[num_leaves_] = leaf_depth_[leaf] + 1;
430 if (num_leaves_ > 1) {
431 int leaf = GetLeaf(feature_values);
434 return leaf_value_[0];
438inline double Tree::PredictByMap(
const std::unordered_map<int, double>& feature_values)
const {
439 if (num_leaves_ > 1) {
440 int leaf = GetLeafByMap(feature_values);
443 return leaf_value_[0];
447inline int Tree::PredictLeafIndex(
const double* feature_values)
const {
448 if (num_leaves_ > 1) {
449 int leaf = GetLeaf(feature_values);
456inline int Tree::PredictLeafIndexByMap(
const std::unordered_map<int, double>& feature_values)
const {
457 if (num_leaves_ > 1) {
458 int leaf = GetLeafByMap(feature_values);
465inline void Tree::PredictContrib(
const double* feature_values,
int num_features,
double* output) {
466 output[num_features] += ExpectedValue();
468 if (num_leaves_ > 1) {
469 CHECK(max_depth_ >= 0);
470 const int max_path_len = max_depth_ + 1;
471 std::vector<PathElement> unique_path_data(max_path_len*(max_path_len + 1) / 2);
472 TreeSHAP(feature_values, output, 0, 0, unique_path_data.data(), 1, 1, -1);
476inline void Tree::RecomputeLeafDepths(
int node,
int depth) {
477 if (node == 0) leaf_depth_.resize(
num_leaves());
479 leaf_depth_[~node] = depth;
481 RecomputeLeafDepths(left_child_[node], depth + 1);
482 RecomputeLeafDepths(right_child_[node], depth + 1);
486inline int Tree::GetLeaf(
const double* feature_values)
const {
490 node = Decision(feature_values[split_feature_[node]], node);
494 node = NumericalDecision(feature_values[split_feature_[node]], node);
500inline int Tree::GetLeafByMap(
const std::unordered_map<int, double>& feature_values)
const {
504 node = Decision(feature_values.count(split_feature_[node]) > 0 ? feature_values.at(split_feature_[node]) : 0.0f, node);
508 node = NumericalDecision(feature_values.count(split_feature_[node]) > 0 ? feature_values.at(split_feature_[node]) : 0.0f, node);
The main class of the data set, which is used for training or validation.
Definition dataset.h:278
Tree model.
Definition tree.h:20
int split_feature(int split_idx) const
Get feature of specific split.
Definition tree.h:127
int leaf_depth(int leaf_idx) const
Get depth of specific leaf.
Definition tree.h:124
int num_leaves() const
Get Number of leaves.
Definition tree.h:121
double Predict(const double *feature_values) const
Prediction on one record.
Definition tree.h:429
void AddPredictionToScore(const Dataset *data, data_size_t num_data, double *score) const
Adding prediction value of this tree model to scores.
Definition tree.cpp:113
int SplitCategorical(int leaf, int feature, int real_feature, const uint32_t *threshold_bin, int num_threshold_bin, const uint32_t *threshold, int num_threshold, double left_value, double right_value, int left_cnt, int right_cnt, float gain, MissingType missing_type)
Performing a split on tree leaves, with categorical feature.
Definition tree.cpp:70
void Shrinkage(double rate)
Shrinkage for the tree's output; the shrinkage rate (a.k.a. learning rate) is used to tune the training proc...
Definition tree.h:139
std::string ToIfElse(int index, bool predict_leaf_index) const
Serialize this object to if-else statement.
Definition tree.cpp:353
double LeafOutput(int leaf) const
Get the output of one leaf.
Definition tree.h:78
std::string ToString() const
Serialize this object to string.
Definition tree.cpp:207
int data_count(int node) const
Get the number of data points that fall at or below this node.
Definition tree.h:132
int Split(int leaf, int feature, int real_feature, uint32_t threshold_bin, double threshold_double, double left_value, double right_value, int left_cnt, int right_cnt, float gain, MissingType missing_type, bool default_left)
Performing a split on tree leaves.
Definition tree.cpp:49
void SetLeafOutput(int leaf, double output)
Set the output of one leaf.
Definition tree.h:81
std::string ToJSON() const
Serialize this object to json.
Definition tree.cpp:242
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10
int32_t data_size_t
Type of data size; it is better to use a signed type.
Definition meta.h:14
Definition tree_shap.h:108