Medial Code Documentation
Loading...
Searching...
No Matches
MedBART.h
1#ifndef _MED_BART_H_
2#define _MED_BART_H_
3#include "MedAlgo.h"
4#include "BART.h"
5
11class MedBART : public MedPredictor { // Wraps a BART model (member _model) behind the MedPredictor interface.
12public:
13
18 void init_defaults() {}; // No-op: the default hyperparameters are assigned in the MedBART() constructor.
19 int set_params(map<string, string>& mapper); // Sets parameters from a key->value string map (implemented in MedBART.cpp; accepted keys not visible here).
20
27 int Learn(float *x, float *y, const float *w, int nsamples, int nftrs); // Trains on x (nsamples x nftrs feature values), labels y and weights w. NOTE(review): confirm whether w may be null and what the int return code means.
38 int Predict(float *x, float *&preds, int nsamples, int nftrs) const; // Predicts for x (nsamples x nftrs); results are returned through the preds reference. NOTE(review): confirm who allocates/frees preds.
39
43 MedBART() : _model(0, 0, 0, 0, tree_params) { // NOTE(review): tree_params is handed to _model in the init list, before this body assigns its fields; unless BART keeps a reference to it, _model is built without the defaults below - confirm against BART.h / Learn.
44 normalize_for_learn = false; // no normalization of the features before Learn
45 transpose_for_learn = false; // no transposition of the input before Learn
50
51 //default params:
52 ntrees = 50; // number of trees
53 iter_count = 1000; // presumably the number of sampling iterations - confirm in BART.h
54 burn_count = 250; // presumably burn-in iterations that are discarded - confirm in BART.h
55 restart_count = 5; // meaning not visible from this header - confirm in BART.h
56 tree_params.alpha = (float)0.95; // tree-structure prior: alpha * (1 + depth(node))^(-beta)
57 tree_params.beta = 1; // exponent of the same tree-structure prior
58 tree_params.min_obs_in_node = 5; // minimum number of observations allowed in a node
59
60 tree_params.k = 2; // see bart_params::k in BART.h
61 tree_params.nu = 3; // prior parameter for the node noise sigma_i (see bart_params::nu)
62 tree_params.lambda = 1; // prior parameter for the node noise sigma_i (see bart_params::lambda)
63 }
64
65 ADD_CLASS_NAME(MedBART)
66 ADD_SERIALIZATION_FUNCS(classifier_type, ntrees, iter_count, burn_count, restart_count) // NOTE(review): tree_params and _model are not listed; if this macro serializes only the listed members, a deserialized MedBART loses the trained trees and the tree-prior settings - confirm.
67private:
68 int ntrees;
69 int iter_count;
70 int burn_count;
71 int restart_count;
72 bart_params tree_params; // tree-prior / node parameters copied into the BART model
73
74 BART _model; // the wrapped BART model
75};
76
78
79#endif // !_MED_BART_H_
80
MedAlgo - APIs to different algorithms: Linear Models, RF, GBM, KNN, and more.
@ MODEL_BART
to_use:"bart" MedBART model using BART
Definition MedAlgo.h:63
#define ADD_SERIALIZATION_FUNCS(...)
Definition SerializableObject.h:122
#define MEDSERIALIZE_SUPPORT(Type)
Definition SerializableObject.h:108
Bayesian Additive Regression Trees.
Definition BART.h:300
A wrapper that exposes the BART model through the MedPredictor interface.
Definition MedBART.h:11
void init_defaults()
Initializes model defaults; currently a no-op, since the defaults are set in the constructor.
Definition MedBART.h:18
int set_params(map< string, string > &mapper)
Definition MedBART.cpp:27
MedBART()
Default constructor; sets the default hyperparameters.
Definition MedBART.h:43
int Predict(float *x, float *&preds, int nsamples, int nftrs) const
Predicts on x, a flat array holding an nsamples x nftrs feature matrix.
Definition MedBART.cpp:16
int Learn(float *x, float *y, const float *w, int nsamples, int nftrs)
Learns from x, a flat array holding an nsamples x nftrs feature matrix.
Definition MedBART.cpp:6
Base Interface for predictor.
Definition MedAlgo.h:78
bool normalize_for_learn
True if the features must be normalized before learning.
Definition MedAlgo.h:87
bool transpose_for_predict
True if the input must be transposed before prediction.
Definition MedAlgo.h:90
bool normalize_for_predict
True if the features must be normalized before prediction.
Definition MedAlgo.h:91
bool normalize_y_for_learn
True if the labels must be normalized before learning.
Definition MedAlgo.h:88
MedPredictorTypes classifier_type
The predictor enum type.
Definition MedAlgo.h:80
bool transpose_for_learn
True if the input must be transposed before learning.
Definition MedAlgo.h:86
BART tree parameters
Definition BART.h:111
float nu
the node-data dict params for sigma_i: sigma_i ~ IG(nu, mean_sigma*lambda/2)
Definition BART.h:123
float lambda
the node-data dict params for sigma_i: sigma_i ~ IG(mean_sigma/2, mean_sigma*lambda/2)
Definition BART.h:124
float k
the range for bandwidth interval
Definition BART.h:122
float beta
Tree-structure prior: alpha * (1 + depth(node))^(-beta)
Definition BART.h:116
int min_obs_in_node
Minimum number of observations allowed in a node.
Definition BART.h:113
float alpha
Tree-structure prior: alpha * (1 + depth(node))^(-beta)
Definition BART.h:115