#include <unordered_map>
#include <boost/random.hpp>
#include <boost/math/special_functions/gamma.hpp>
#include <boost/math/distributions/chi_squared.hpp>
float variance(const vector<float> &y);
void print_tree(const vector<vector<float>> &feature_sorted_values) const;
void validate_tree(const vector<vector<float>> &feature_sorted_values, const vector<float> &x, int nftrs) const;
enum bart_data_prior_type {
    regression_mean_shift = 0,
    ...
};

bart_data_prior_type data_prior_type;

// in bart_params():
data_prior_type = regression_mean_shift;

// in set_classification(int num_trees):
data_prior_type = bart_data_prior_type::classification;
sigsq_mu = pow(3 / (k * sqrt(num_trees)), 2);

// in set_regression(int num_trees, float sample_var_y_in_data):
data_prior_type = bart_data_prior_type::regression_mean_shift;
sigsq_mu = pow(1 / (2 * k * sqrt(num_trees)), 2);
boost::math::chi_squared_distribution<> chi_dist(nu);
double ten_pctile_chisq_df_hyper_nu = boost::math::cdf(chi_dist, 1 - 0.9);
lambda = sample_var_y_in_data * ten_pctile_chisq_df_hyper_nu / nu;
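The lambda initialization above anchors the sigma prior to the sample variance of y. Below is a minimal, self-contained sketch of that calibration; going by the variable name ten_pctile_chisq_df_hyper_nu, the intended quantity is the 10th percentile of a chi-squared with nu degrees of freedom, which Boost exposes as boost::math::quantile (the cdf call in the listing evaluates the CDF at 0.1 instead). The nu and sample-variance values here are made up for illustration only.

#include <boost/math/distributions/chi_squared.hpp>
#include <iostream>

int main() {
    double nu = 3.0;                    // hypothetical degrees-of-freedom hyperparameter
    double sample_var_y_in_data = 1.7;  // hypothetical sample variance of y
    boost::math::chi_squared_distribution<> chi_dist(nu);
    // 10th percentile of chi-squared(nu): the inverse CDF evaluated at 0.1
    double ten_pctile = boost::math::quantile(chi_dist, 0.1);
    double lambda = sample_var_y_in_data * ten_pctile / nu;
    std::cout << "lambda = " << lambda << "\n";
    return 0;
}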
void next_gen_tree(const vector<float> &x, const vector<float> &y);
// in bart_tree(const bart_tree &other):
_rnd_gen = other._rnd_gen;
action_priors = other.action_priors;
feature_to_sorted_vals = other.feature_to_sorted_vals;
feature_to_val_index = other.feature_to_val_index;

// in bart_tree():
_rnd_gen = mt19937(_rd());
action_priors = { (float)0.25, (float)0.25, (float)0.4, (float)0.1 };
params.data_prior_type = bart_data_prior_type::regression_mean_shift;

vector<vector<float>> feature_to_sorted_vals;
vector<unordered_map<float, int>> feature_to_val_index;
vector<float> action_priors;
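The four weights in action_priors line up with the four proposal helpers declared below (change, grow, prune, swap), although that mapping is an assumption based on the helper names rather than something stated in the header. A minimal sketch of drawing a proposal move from such weights, using the standard library rather than the tree's own _rnd_gen:

#include <random>
#include <vector>

// Hypothetical helper: returns the index of the proposed move given the
// action_priors weights; which index means change/grow/prune/swap is assumed.
int sample_action(const std::vector<float> &action_priors, std::mt19937 &gen) {
    std::discrete_distribution<int> pick(action_priors.begin(), action_priors.end());
    return pick(gen);
}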
float score_leaf(const vector<float> &y, const vector<int> &obs_indexes);
double node_data_likelihood(const vector<bart_node *> &leaf_node, const vector<float> &x, int nftrs, const vector<float> &y);
void calc_likelihood(const vector<float> &x, int nftrs, const vector<float> &y);
void get_avalible_change(const vector<float> &x, int nftrs, vector<bart_node *> &good_idx);
void get_avalible_feats_change(const bart_node *selected_node, const vector<float> &x, int nftrs, vector<int> &good_idx);
void propogate_change_down(bart_node *current_node, const vector<float> &x, int nftrs, vector<bart_node *> &list_nodes_after);
void get_avalible_grow(const vector<float> &x, int nftrs, vector<bart_node *> &good_idx) const;
void get_avalible_feats_grow(const bart_node *selected_node, const vector<float> &x, int nftrs, vector<int> &good_idx);
void get_avalible_prune(const vector<float> &x, int nftrs, vector<bart_node *> &good_idx);
void get_avalible_swap(const vector<float> &x, int nftrs, vector<bart_node *> &good_idx);
bool has_split(const bart_node *current_node, const vector<float> &x, int nftrs, int feature_num) const;
bool select_split(bart_node *current_node, const vector<float> &x, int nftrs, vector<vector<int>> &split_obx_indexes);
void predict(const vector<float> &x, int nSamples, vector<float> &scores) const;
void predict_on_train(const vector<float> &x, int nSamples, vector<float> &scores) const;
void learn(const vector<float> &x, const vector<float> &y);
void predict(const vector<float> &x, int nSamples, vector<float> &scores) const;

// in BART(int ntrees, int iterations, int burn_cnt, int restart_cnt, bart_params &tree_pr):
this->iter_count = iterations;
this->burn_count = burn_cnt;
this->tree_params = tree_pr;
this->restart_count = restart_cnt;
for (size_t i = 0; i < ntrees; ++i)
    _trees[i].params = this->tree_params;

// in operator=(const BART &other):
_trees.resize(other._trees.size());
for (size_t i = 0; i < other._trees.size(); ++i) {
    _trees[i].action_priors = other._trees[i].action_priors;
    _trees[i].root = other._trees[i].root;
    _trees[i].tree_loglikelihood = other._trees[i].tree_loglikelihood;
    _trees[i]._rnd_gen = other._trees[i]._rnd_gen;
    _trees[i].feature_to_sorted_vals = other._trees[i].feature_to_sorted_vals;
    _trees[i].feature_to_val_index = other._trees[i].feature_to_val_index;
    _trees[i].params = other._trees[i].params;
}

// in ~BART():
for (size_t i = 0; i < _trees.size(); ++i)
    _trees[i].clear_tree_mem(_trees[i].root);

vector<bart_tree> _trees;

void transform_y(vector<float> &y);
void untransform_y(vector<float> &y) const;
void init_hyper_parameters(const vector<float> &residuals);
void update_sigma_param(boost::mt19937 &rng, const vector<float> &residuals, double &sigma);
void update_latent_z_params(boost::random::random_number_generator<boost::mt19937> &rng_gen,
                            const vector<float> &x, const vector<float> &y,
                            const vector<bart_tree> &forest_trees, vector<float> &residuals);
Bayesian Additive Regression Trees.
Definition BART.h:300
BART(int ntrees, int iterations, int burn_cnt, int restart_cnt, bart_params &tree_pr)
a simple ctor setting the iteration, burn, and restart counts and the per-tree parameters
Definition BART.h:330
~BART()
a dtor to free all tree memory
Definition BART.h:370
void operator=(const BART &other)
a simple assignment operator that shallow-copies the whole BART model, including all trees.
Definition BART.h:350
void learn(const vector< float > &x, const vector< float > &y)
learning on the x vector, which represents a matrix.
Definition BART.cpp:1033
int iter_count
the number of times next_gen_tree is called for each tree
Definition BART.h:303
bart_params tree_params
additional tree parameters
Definition BART.h:306
int burn_count
the number of burn-in iterations
Definition BART.h:304
void predict(const vector< float > &x, int nSamples, vector< float > &scores) const
prediction on the x vector, which represents a matrix
Definition BART.cpp:1181
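A hypothetical end-to-end usage sketch of the class, based only on the declarations documented here; the parameter values and the exact layout of the flattened x matrix are assumptions, not values taken from BART.cpp:

#include "BART.h"
#include <vector>
using namespace std;

void example(const vector<float> &x_train, const vector<float> &y_train,
             const vector<float> &x_test, int n_test) {
    bart_params params;                                 // defaults from bart_params()
    params.set_regression(/*num_trees=*/50, /*sample_var_y_in_data=*/1.0f);
    BART model(/*ntrees=*/50, /*iterations=*/1000, /*burn_cnt=*/250,
               /*restart_cnt=*/1, params);
    model.learn(x_train, y_train);                      // x_train: flattened feature matrix
    vector<float> scores;
    model.predict(x_test, n_test, scores);              // one score per test sample
}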
int restart_count
number of restarts
Definition BART.h:305
int ntrees
The number of trees/restarts.
Definition BART.h:302
bart tree node
Definition BART.h:15
int feature_number
feature number in node for split
Definition BART.h:17
void validate_tree(const vector< vector< float > > &feature_sorted_values, const vector< float > &x, int nftrs) const
for debugging - validates that the tree structure is consistent with the child/parent pointers and that the observ...
Definition BART.cpp:177
float node_value
the output value for input that reaches this node
Definition BART.h:19
int num_feature_options
number of features to select for likelihood calc
Definition BART.h:25
bart_node * parent
the parent node
Definition BART.h:20
void print_tree(const vector< vector< float > > &feature_sorted_values) const
printing tree from current node
Definition BART.cpp:157
void deep_clone(bart_node *&target)
deep copying of the node and all its descendants by allocating new copies of all nodes
Definition BART.cpp:902
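A hypothetical sketch of what such a recursive deep clone typically looks like, using only the bart_node members documented here; the actual implementation at BART.cpp:902 may differ:

// deep_clone_sketch is an illustrative free function, not the real member.
void deep_clone_sketch(const bart_node *src, bart_node *&target, bart_node *parent) {
    if (src == nullptr) { target = nullptr; return; }
    target = new bart_node(*src);   // shallow copy of the node's own fields
    target->parent = parent;        // re-point the parent to the new copy
    deep_clone_sketch(src->childs[0], target->childs[0], target);
    deep_clone_sketch(src->childs[1], target->childs[1], target);
}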
vector< int > observation_indexes
the indexes of observations in this node
Definition BART.h:22
void list_all_nodes(vector< bart_node * > &all_nodes)
populating all_nodes with a flattened array of all nodes from the current node, including this node.
Definition BART.cpp:366
int depth()
returning node depth
Definition BART.cpp:814
bool mark_change
mark change for calculating node_value
Definition BART.h:23
float variance(const vector< float > &y)
calculating the variance of the observations in the node
Definition BART.cpp:824
bart_node()
a simple default ctor
Definition BART.h:31
bart_node * childs[2]
the left and right child nodes
Definition BART.h:21
int num_split_options
number of split values to select from, after the feature selection, for the likelihood calc
Definition BART.h:26
bart_node(const bart_node &cp)
a copy ctor performing a shallow copy.
Definition BART.h:47
int split_index
the index into the feature's sorted unique values used as the split value in this node
Definition BART.h:18
bart tree parameters
Definition BART.h:111
float nu
the node-data distribution parameter for sigma_i: sigma_i ~ IG(nu, mean_sigma*lambda/2)
Definition BART.h:123
void set_regression(int num_trees, float sample_var_y_in_data)
an initializer for regression problems.
Definition BART.h:156
float lambda
the node-data distribution parameter for sigma_i: sigma_i ~ IG(mean_sigma/2, mean_sigma*lambda/2)
Definition BART.h:124
float k
the k factor controlling the node-mean prior scale: sigsq_mu = (1/(2*k*sqrt(num_trees)))^2 for regression and (3/(k*sqrt(num_trees)))^2 for classification
Definition BART.h:122
float beta
prior for tree structure: alpha * (1 + depth(node)) ^ -beta
Definition BART.h:116
bart_params()
a simple default ctor
Definition BART.h:129
int min_obs_in_node
the minimum number of observations allowed in a node
Definition BART.h:113
void set_classification(int num_trees)
an initializer for classification problems
Definition BART.h:144
float alpha
prior for tree structure: alpha * (1 + depth(node)) ^ -beta
Definition BART.h:115
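A short illustration of this tree-structure prior; the alpha and beta values below are the conventional BART-paper defaults used only for the example, not necessarily this library's defaults:

#include <cmath>
#include <cstdio>

int main() {
    const double alpha = 0.95, beta = 2.0;  // assumed illustration values
    for (int depth = 0; depth <= 3; ++depth) {
        // probability that a node at this depth is split (non-terminal)
        double p_split = alpha * std::pow(1.0 + depth, -beta);
        std::printf("depth %d: split prob %.3f\n", depth, p_split);
    }
    return 0;
}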
bart tree
Definition BART.h:182
void clone_tree(bart_tree &tree)
a function to clone the bart_tree from the root node via deep clone, creating a copy of each node's data
Definition BART.cpp:912
bart_tree()
a simple default ctor
Definition BART.h:226
void next_gen_tree(const vector< float > &x, const vector< float > &y)
creating the next MCMC move from the current tree using the Metropolis-Hastings algorithm
Definition BART.cpp:221
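A minimal sketch of the accept/reject step at the heart of such a Metropolis-Hastings move, assuming the decision is made on tree log-likelihoods; the actual proposal and correction terms live in BART.cpp:221 and are not reproduced here:

#include <cmath>
#include <random>

// Hypothetical helper: accept the proposed tree with probability min(1, ratio).
bool accept_proposal(double proposed_loglik, double current_loglik, std::mt19937 &gen) {
    std::uniform_real_distribution<double> unif(0.0, 1.0);
    double log_ratio = proposed_loglik - current_loglik;
    return std::log(unif(gen)) < log_ratio;
}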
bart_tree(const bart_tree &other)
a copy ctor performing a shallow copy: not allocating memory for all tree nodes again.
Definition BART.h:213
void set_sigma(double sig)
setting the parameter sigma - should happen before generating a new tree using Metropolis-Hastings
Definition BART.h:205
double tree_loglikelihood
the tree log-likelihood based on the tree prior and the tree's fit to the data
Definition BART.h:185
bart_params params
the bart params
Definition BART.h:186
bart_node * root
the tree root
Definition BART.h:184
A class to represent a change in the tree - used for rollback, or to release memory on commit.
Definition BART.h:92
int num_node_selection
the number of nodes to select
Definition BART.h:98
vector< bart_node * > changed_nodes_after
the nodes after the change
Definition BART.h:95
vector< bart_node * > changed_nodes_before
the nodes before the change
Definition BART.h:94
int action
the action
Definition BART.h:97