Medial Code Documentation
Loading...
Searching...
No Matches
BART.h
1#ifndef BART_H
2#define BART_H
3#include <vector>
4#include <random>
5#include <unordered_map>
6#include <boost/random.hpp>
7#include <boost/math/special_functions/gamma.hpp>
8#include <boost/math/distributions/chi_squared.hpp>
9
10using namespace std;
11
15class bart_node {
16public:
19 float node_value;
22 vector<int> observation_indexes;
24
27
32 feature_number = -1;
33 split_index = -1;
34 node_value = 0;
35 parent = NULL; //root
36 //Childrens
37 childs[0] = NULL; //smaller than or equal
38 childs[1] = NULL; //biger than
39 mark_change = true;
42 }
43
47 bart_node(const bart_node &cp) {
48 feature_number = cp.feature_number;
49 split_index = cp.split_index;
50 node_value = cp.node_value;
51 parent = cp.parent;
52 childs[0] = cp.childs[0];
53 childs[1] = cp.childs[1];
54 observation_indexes = cp.observation_indexes;
55 mark_change = cp.mark_change;
56 num_feature_options = cp.num_feature_options;
57 num_split_options = cp.num_split_options;
58 }
59
63 void list_all_nodes(vector<bart_node *> &all_nodes);
67 int depth();
71 float variance(const vector<float> &y);
75 void deep_clone(bart_node *&target);
79 void print_tree(const vector<vector<float>> &feature_sorted_values) const; //for debug
84 void validate_tree(const vector<vector<float>> &feature_sorted_values, const vector<float> &x, int nftrs) const; //for debug
85private:
86};
87
93public:
94 vector<bart_node *>changed_nodes_before;
95 vector<bart_node *>changed_nodes_after;
96
97 int action = -1;
99};
100
/// Selects which kind of problem the tree parameters are configured for.
enum bart_data_prior_type {
    /// regression problems (selected by bart_params::set_regression)
    regression_mean_shift = 0,
    /// classification problems (selected by bart_params::set_classification)
    classification = 1
};
105
106class bart_tree;
107
112public:
114
115 float alpha;
116 float beta;
117
118 //general:
119 bart_data_prior_type data_prior_type;
120
121 //params for prior:
122 float k;
123 float nu;
124 float lambda;
125
130 min_obs_in_node = 0;
131 alpha = 1;
132 beta = 1;
133 sigsq_mu = 0;
134 mean_mu = 0;
135 nu = 3;
136 k = 2;
137 lambda = 1;
138 data_prior_type = regression_mean_shift;
139 }
140
144 void set_classification(int num_trees) {
145 mean_mu = 0;
146 //k = 2;
147 //nu = 3.0;
148 data_prior_type = bart_data_prior_type::classification;
149
150 sigsq_mu = pow(3 / (k * sqrt(num_trees)), 2);
151 }
156 void set_regression(int num_trees, float sample_var_y_in_data) {
157 mean_mu = 0;
158 //nu = 3.0;
159 //k = 2;
160 data_prior_type = bart_data_prior_type::regression_mean_shift;
161
162 sigsq_mu = pow(1 / (2 * k * sqrt(num_trees)), 2);
163
164 boost::math::chi_squared_distribution<> chi_dist(nu);
165 double ten_pctile_chisq_df_hyper_nu = boost::math::cdf(chi_dist, 1 - 0.9);
166 lambda = sample_var_y_in_data * ten_pctile_chisq_df_hyper_nu / nu;
167 //lambda = 2.7e-3;
168 //printf("ten_pctile_chisq_df_hyper_nu=%2.5f, var=%2.5f\n", ten_pctile_chisq_df_hyper_nu, sample_var_y_in_data);
169 }
170private:
171 float mean_mu;
172 float sigsq_mu;
173
174 friend class bart_tree;
175};
176
177class BART;
178
183public:
187
193 void next_gen_tree(const vector<float> &x, const vector<float> &y);
194
199 void clone_tree(bart_tree &tree);
200
205 void set_sigma(double sig) {
206 sigma = sig;
207 }
208
// Copy constructor performing a shallow copy: root is copied as a pointer, so
// both trees refer to the same nodes and no tree node is allocated here.
// The random generator state, action priors, feature lookup tables and params
// are copied by value.
// NOTE(review): sigma is not assigned in the lines visible here - confirm
// whether callers are expected to call set_sigma() on the copy.
213 bart_tree(const bart_tree &other) {
214 root = other.root;
216 _rnd_gen = other._rnd_gen;
217 action_priors = other.action_priors;
218 feature_to_sorted_vals = other.feature_to_sorted_vals;
219 feature_to_val_index = other.feature_to_val_index;
220 params = other.params;
221 }
222
227 root = NULL;
229 _rnd_gen = mt19937(_rd());
230
231 //Default Values:
232 action_priors = { (float)0.25, (float)0.25, (float)0.4, (float)0.1 };
234 params.alpha = (float)0.95;
235 params.beta = 1.0;
236
237 params.data_prior_type = bart_data_prior_type::regression_mean_shift;
238 params.mean_mu = 0;
239 params.nu = 3;
240 params.sigsq_mu = 0;
241 params.lambda = 1.0;
242 }
243private:
244 //general variables
245 random_device _rd;
246 mt19937 _rnd_gen;
247 vector<vector<float>> feature_to_sorted_vals;
248 vector<unordered_map<float, int>> feature_to_val_index;
249 double sigma;
250
251 vector<float> action_priors;
252 //actions:
253 tree_change_details do_grow(const vector<float> &x, const vector<float> &y);
254 tree_change_details do_prune(const vector<float> &x, const vector<float> &y);
255 tree_change_details do_change(const vector<float> &x, const vector<float> &y);
256 tree_change_details do_swap(const vector<float> &x, const vector<float> &y);
257
258 float score_leaf(const vector<float> &y, const vector<int> &obs_indexes);
259
260 double node_data_likelihood(const vector<bart_node *> &leaf_node, const vector<float> &x, int nftrs, const vector<float> &y);
261 void calc_likelihood(const vector<float> &x, int nftrs, const vector<float> &y); //will calc before mean in all leaves
262
263 //helper for change:
264 void get_avalible_change(const vector<float> &x, int nftrs, vector<bart_node *> &good_idx);
265 void get_avalible_feats_change(const bart_node *selected_node, const vector<float> &x, int nftrs, vector<int> &good_idx);
266 void propogate_change_down(bart_node *current_node, const vector<float> &x, int nftrs, vector<bart_node *> &list_nodes_after);
267
268 //helper for grow:
269 void get_avalible_grow(const vector<float> &x, int nftrs, vector<bart_node *> &good_idx) const;
270 void get_avalible_feats_grow(const bart_node *selected_node, const vector<float> &x, int nftrs, vector<int> &good_idx);
271
272 //helper for prune;
273 void get_avalible_prune(const vector<float> &x, int nftrs, vector<bart_node *> &good_idx);
274
275 //helper for swap:
276 void get_avalible_swap(const vector<float> &x, int nftrs, vector<bart_node *> &good_idx);
277
278 bool has_split(const bart_node *current_node, const vector<float> &x, int nftrs, int feature_num) const;
279 bool select_split(bart_node *current_node, const vector<float> &x, int nftrs,
280 vector<vector<int>> &split_obx_indexes);
281
282 void commit_change(const tree_change_details &change);
283 void rollback_change(const tree_change_details &change);
284
285 int clear_tree_mem(bart_node *node); //erase count
286protected:
287 void predict(const vector<float> &x, int nSamples, vector<float> &scores) const;
288 void predict_on_train(const vector<float> &x, int nSamples, vector<float> &scores) const; //faster
289
290 friend class BART;
291};
292
300class BART {
301public:
302 int ntrees;
307
314 void learn(const vector<float> &x, const vector<float> &y);
315
325 void predict(const vector<float> &x, int nSamples, vector<float> &scores) const;
326
330 BART(int ntrees, int iterations, int burn_cnt, int restart_cnt, bart_params &tree_pr) {
331 nftrs = 0;
332 //default:
333 this->ntrees = ntrees;
334 this->iter_count = iterations;
335 this->burn_count = burn_cnt;
336 this->tree_params = tree_pr;
337 this->restart_count = restart_cnt;
338
339 _trees.resize(ntrees);
340 for (size_t i = 0; i < ntrees; ++i)
341 _trees[i].params = this->tree_params;
342 trans_y_b = 0;
343 trans_y_max = 0;
344 }
345
// Assignment operator performing a shallow copy of the model: the settings and,
// for every tree, the action priors, root pointer, log-likelihood, random
// generator, feature lookup tables and params.
// Returns void, so assignments cannot be chained.
// NOTE(review): each root is copied as a pointer rather than cloned, so after
// the assignment both models refer to the same nodes - confirm which of them
// is expected to free the tree memory.
// NOTE(review): trans_y_b and trans_y_max are not assigned in the lines
// visible here - confirm whether that is intended.
350 void operator=(const BART &other) {
351 ntrees = other.ntrees; iter_count = other.iter_count;
353 tree_params = other.tree_params;
354 nftrs = other.nftrs;
// make room for one tree per source tree, then copy the fields one by one
355 _trees.resize(other._trees.size());
356 for (size_t i = 0; i < other._trees.size(); ++i) {
357 _trees[i].action_priors = other._trees[i].action_priors;
358 _trees[i].root = other._trees[i].root;
359 _trees[i].tree_loglikelihood = other._trees[i].tree_loglikelihood;
360 _trees[i]._rnd_gen = other._trees[i]._rnd_gen;
361 _trees[i].feature_to_sorted_vals = other._trees[i].feature_to_sorted_vals;
362 _trees[i].feature_to_val_index = other._trees[i].feature_to_val_index;
363 _trees[i].params = other._trees[i].params;
364 }
365 }
366
371 for (size_t i = 0; i < _trees.size(); ++i)
372 _trees[i].clear_tree_mem(_trees[i].root);
373 }
374
375private:
376 int nftrs;
377 vector<bart_tree> _trees;
378 float trans_y_b; //movement of y values
379 float trans_y_max;
380
381 void transform_y(vector<float> &y);
382 void untransform_y(vector<float> &y) const;
383
384 void init_hyper_parameters(const vector<float> &residuals);
385 void update_sigma_param(boost::mt19937 &rng, const vector<float> &residuals, double &sigma); //for regression
386 void update_latent_z_params(boost::random::random_number_generator<boost::mt19937> &rng_gen,
387 const vector<float> &x, const vector<float> &y, const vector<bart_tree> &forest_trees,
388 vector<float> &residuals); //for classification
389};
390
391#endif // !BART_H
Bayesian Additive Regression Trees.
Definition BART.h:300
BART(int ntrees, int iterations, int burn_cnt, int restart_cnt, bart_params &tree_pr)
a simple ctor: stores the run settings and creates ntrees trees, each with a copy of the given tree parameters
Definition BART.h:330
~BART()
a destructor (dtor) that frees the memory of all trees
Definition BART.h:370
void operator=(const BART &other)
a simple assignment operator that shallow-copies the whole BART model, including all of its trees.
Definition BART.h:350
void learn(const vector< float > &x, const vector< float > &y)
learning on x vector which represents matrix.
Definition BART.cpp:1033
int iter_count
the number of steps to call next_gen_tree for each tree
Definition BART.h:303
bart_params tree_params
additional tree parameters
Definition BART.h:306
int burn_count
the burn count
Definition BART.h:304
void predict(const vector< float > &x, int nSamples, vector< float > &scores) const
prediction on x vector which represents matrix
Definition BART.cpp:1181
int restart_count
number of restarts
Definition BART.h:305
int ntrees
The number of trees/restarts.
Definition BART.h:302
bart tree node
Definition BART.h:15
int feature_number
feature number in node for split
Definition BART.h:17
void validate_tree(const vector< vector< float > > &feature_sorted_values, const vector< float > &x, int nftrs) const
for debug - validating the tree structure is correct with childs\parents pointers and that the observ...
Definition BART.cpp:177
float node_value
the output value for input that reaches this node
Definition BART.h:19
int num_feature_options
number of features to select for likelihood calc
Definition BART.h:25
bart_node * parent
the parent node
Definition BART.h:20
void print_tree(const vector< vector< float > > &feature_sorted_values) const
printing tree from current node
Definition BART.cpp:157
void deep_clone(bart_node *&target)
deep copying of the node and all its descendants by allocating new copies of all nodes
Definition BART.cpp:902
vector< int > observation_indexes
the indexes of observations in this node
Definition BART.h:22
void list_all_nodes(vector< bart_node * > &all_nodes)
populating all_nodes with a flattened array of all nodes reachable from the current node, including this node.
Definition BART.cpp:366
int depth()
returning node depth
Definition BART.cpp:814
bool mark_change
mark change for calculating node_value
Definition BART.h:23
float variance(const vector< float > &y)
calculating node variance in the observations in the node
Definition BART.cpp:824
bart_node()
a simple default ctor
Definition BART.h:31
bart_node * childs[2]
the left,right childs
Definition BART.h:21
int num_split_options
number of split values to select from after the feature is selected, for likelihood calc
Definition BART.h:26
bart_node(const bart_node &cp)
a copy ctor shallow copy.
Definition BART.h:47
int split_index
the split index of the feature sorted unique value as split value in node
Definition BART.h:18
bart tree parameters
Definition BART.h:111
float nu
the node-data dict params for sigma_i: sigma_i ~ IG(nu, mean_sigma*lambda/2)
Definition BART.h:123
void set_regression(int num_trees, float sample_var_y_in_data)
an initializer for regression problems.
Definition BART.h:156
float lambda
the node-data dict params for sigma_i: sigma_i ~ IG(mean_sigma/2, mean_sigma*lambda/2)
Definition BART.h:124
float k
the range for bandwidth interval
Definition BART.h:122
float beta
prior for tree structure: alpha * (1 + depth(node)) ^ -beta
Definition BART.h:116
bart_params()
a simple default ctor
Definition BART.h:129
int min_obs_in_node
minimal allowed observations in node
Definition BART.h:113
void set_classification(int num_trees)
an initializer for classification problems
Definition BART.h:144
float alpha
prior for tree structure: alpha * (1 + depth(node)) ^ -beta
Definition BART.h:115
bart tree
Definition BART.h:182
void clone_tree(bart_tree &tree)
a function to deep-clone a bart_tree starting from the root node - creating a copy of each node's data
Definition BART.cpp:912
bart_tree()
a simple default ctor
Definition BART.h:226
void next_gen_tree(const vector< float > &x, const vector< float > &y)
creating the next move in the MCMC from the current tree using the Metropolis-Hastings algorithm
Definition BART.cpp:221
bart_tree(const bart_tree &other)
a copy ctor performing a shallow copy: it does not allocate memory for the tree nodes again.
Definition BART.h:213
void set_sigma(double sig)
setting the parameter sigma - should happen before generating a new tree using Metropolis-Hastings
Definition BART.h:205
double tree_loglikelihood
the tree likelihood based on tree prior and tree match to data
Definition BART.h:185
bart_params params
the BART params
Definition BART.h:186
bart_node * root
the tree root
Definition BART.h:184
A class to represent a change in the tree - used for rollback, or for releasing memory on commit.
Definition BART.h:92
int num_node_selection
the number of nodes to select
Definition BART.h:98
vector< bart_node * > changed_nodes_after
the node after changes
Definition BART.h:95
vector< bart_node * > changed_nodes_before
the node before changes
Definition BART.h:94
int action
the action
Definition BART.h:97
Definition StdDeque.h:58