Medial Code Documentation
Loading...
Searching...
No Matches
ExplainWrapper.h
Go to the documentation of this file.
1
2#ifndef __EXPLAIN_WRAPPER_H__
3#define __EXPLAIN_WRAPPER_H__
4
5#include <vector>
6#include <string>
9#include <MedStat/MedStat/GibbsSampler.h>
10#include <MedAlgo/MedAlgo/tree_shap.h>
11#include <MedAlgo/MedAlgo/SamplesGenerator.h>
12
13using namespace std;
14
19public:
22 float sum_ratio;
23
25
26 int init(map<string, string> &map);
27
29 void filter(map<string, float> &explain_list) const;
30
31 ADD_CLASS_NAME(ExplainFilters)
33
34};
35
40private:
41 bool postprocessing_cov = false;
42public:
43 bool group_by_sum = false;
44 bool learn_cov_matrix = false;
46 int zero_missing = 0;
47 bool keep_b0 = false;
48 bool iterative = false;
49 int iteration_cnt = 0;
50 bool use_max_cov = false;
51
54
56
57 string grouping;
58 vector<vector<int>> group2Inds;
59 vector<string> groupNames;
60 map<string, vector<int>> groupName2Inds;
61
63
64 int init(map<string, string> &map);
65
67 void learn(const MedFeatures &train_mat);
68
70 void process(map<string, float> &explain_list) const;
71 // same as process but zeroing all contributions of missing-value features and of groups whose participants are all missing
72 void process(map<string, float> &explain_list, unsigned char *missing_value_mask) const;
73
75 float get_group_normalized_contrib(const vector<int> &group_inds, vector<float> &contribs, float total_normalization_factor) const;
76
77 void post_deserialization();
78
80 static void read_feature_grouping(const string &file_name, const vector<string>& features, vector<vector<int>>& group2index,
81 vector<string>& group_names, bool verbose = true);
82
83 ADD_CLASS_NAME(ExplainProcessings)
86};
87
92public:
93 string attr_name = "";
94 bool store_as_json = false;
95 bool denorm_features = true;
96
97 //No init - will be initialized directly in ModelExplainer::init
98
99 ADD_CLASS_NAME(GlobalExplainerParams)
101};
102
107private:
109 virtual void _init(map<string, string> &mapper) = 0;
111 unordered_map<string, const FeatureNormalizer *> feats_to_norm;
112public:
116 GlobalExplainerParams global_explain_params;
117
119 virtual int init(map<string, string> &mapper);
120
121 virtual int update(map<string, string>& mapper);
123 virtual void _learn(const MedFeatures &train_mat) = 0;
124
126 virtual void Learn(const MedFeatures &train_mat);
127 void Apply(MedFeatures &matrix) { explain(matrix); }
128
129 void get_input_fields(vector<Effected_Field> &fields) const;
130 void get_output_fields(vector<Effected_Field> &fields) const;
131
133 void init_post_processor(MedModel& model);
134
136 virtual void explain(const MedFeatures &matrix, vector<map<string, float>> &sample_explain_reasons) const = 0;
137
139 virtual void explain(MedFeatures &matrix) const; //stores _explain results in MedFeatures
140
141 static void print_explain(MedSample &smp, int sort_mode = 0);
142
143 void dprint(const string &pref) const;
144
145 virtual ~ModelExplainer() {};
146};
147
151enum TreeExplainerMode {
152 ORIGINAL_IMPL = 0,
153 CONVERTED_TREES_IMPL = 1,
154 PROXY_IMPL = 2
155};
156
164private:
165 MedPredictor * proxy_predictor = NULL; //uses this if model has no tree implementation
166 //Tree structure of generic ensamble trees
167private:
168 bool convert_qrf_trees();
169 bool convert_lightgbm_trees();
170 bool convert_xgb_trees();
171 void _init(map<string, string> &mapper);
172public:
173 bool try_convert_trees();
174 TreeEnsemble generic_tree_model;
175 string proxy_model_type = "";
176 string proxy_model_init = "";
177 bool interaction_shap = false;
178 int approximate = false;
179 float missing_value = MED_MAT_MISSING_VALUE;
180 bool verbose = false;
181
182 TreeExplainer() { processor_type = FTR_POSTPROCESS_TREE_SHAP; }
183
184 void init_post_processor(MedModel& model);
185
186 TreeExplainerMode get_mode() const;
187
188 void _learn(const MedFeatures &train_mat);
189
190 void explain(const MedFeatures &matrix, vector<map<string, float>> &sample_explain_reasons) const;
191
192 void post_deserialization();
193
195
196 ADD_CLASS_NAME(TreeExplainer)
197 ADD_SERIALIZATION_FUNCS(proxy_predictor, interaction_shap, filters, processing, global_explain_params, verbose)
198};
199
208private:
209 MedPredictor * retrain_predictor = NULL; //the retrain model
210
211 void _init(map<string, string> &mapper);
212
213 float avg_bias_score;
214public:
217
225 string predictor_type;
231
232 // parameters for minimal_set usage if use_minimal_set is true will do different thing
241
242
244
245 void _learn(const MedFeatures &train_mat);
246
247 void explain(const MedFeatures &matrix, vector<map<string, float>> &sample_explain_reasons) const;
248
250
251 ADD_CLASS_NAME(MissingShapExplainer)
253 select_from_all, uniform_rand, use_shuffle, no_relearn, avg_bias_score, filters, processing, global_explain_params,
257};
258
263private:
264 unique_ptr<SamplesGenerator<float>> _sampler = NULL;
265 void *sampler_sampling_args = NULL;
266
267 GibbsSampler<float> _gibbs;
268 GibbsSamplingParams _gibbs_sample_params;
269
270 float avg_bias_score;
271
272 void init_sampler(bool with_sampler = true);
273
274 void _init(map<string, string> &mapper);
275public:
276 GeneratorType gen_type = GeneratorType::GIBBS;
277 string generator_args = "";
278 string sampling_args = "";
279 int n_masks = 100;
281 float missing_value = MED_MAT_MISSING_VALUE;
282
283 ShapleyExplainer() { processor_type = FTR_POSTPROCESS_SHAPLEY; avg_bias_score = 0; }
284
285 void _learn(const MedFeatures &train_mat);
286
287 void explain(const MedFeatures &matrix, vector<map<string, float>> &sample_explain_reasons) const;
288
289 void post_deserialization();
290
291 void load_GIBBS(MedPredictor *original_pred, const GibbsSampler<float> &gibbs, const GibbsSamplingParams &sampling_args);
292 void load_GAN(MedPredictor *original_pred, const string &gan_path);
293 void load_MISSING(MedPredictor *original_pred);
294 void load_sampler(MedPredictor *original_pred, unique_ptr<SamplesGenerator<float>> &&generator);
295
296 void dprint(const string &pref) const;
297
298 ADD_CLASS_NAME(ShapleyExplainer)
300 use_random_sampling, avg_bias_score, filters, processing, global_explain_params)
301};
302
307private:
308 unique_ptr<SamplesGenerator<float>> _sampler = NULL;
309 void *sampler_sampling_args = NULL;
310
311 //just for gibbs memory hold when init & learn
312 GibbsSampler<float> _gibbs;
313 GibbsSamplingParams _gibbs_sample_params;
314
315 void init_sampler(bool with_sampler = true);
316 void _init(map<string, string> &mapper);
317 medial::shapley::LimeWeightMethod get_weight_method(string method_s);
318public:
319 GeneratorType gen_type = GeneratorType::GIBBS;
320 string generator_args = "";
321 string sampling_args = "";
322 float missing_value = MED_MAT_MISSING_VALUE;
323 float p_mask = 0;
324 medial::shapley::LimeWeightMethod weighting = medial::shapley::LimeWeightSum;
325 int n_masks = 1250;
326
327 LimeExplainer() { processor_type = FTR_POSTPROCESS_LIME_SHAP; }
328
329 void _learn(const MedFeatures &train_mat);
330
331 void explain(const MedFeatures &matrix, vector<map<string, float>> &sample_explain_reasons) const;
332
333 void load_GIBBS(MedPredictor *original_pred, const GibbsSampler<float> &gibbs, const GibbsSamplingParams &sampling_args);
334 void load_GAN(MedPredictor *original_pred, const string &gan_path);
335 void load_MISSING(MedPredictor *original_pred);
336 void load_sampler(MedPredictor *original_pred, unique_ptr<SamplesGenerator<float>> &&generator);
337
338 void post_deserialization();
339
340 void dprint(const string &pref) const;
341
342 ADD_CLASS_NAME(LimeExplainer)
344 filters, processing, global_explain_params)
345};
346
351private:
352 MedFeatures trainingMap;
353 vector<float> average, std;
354
355 // do the calculation for a single sample after normalization
356 void computeExplanation(vector<float> thisRow, map<string, float> &sample_explain_reasons, vector <vector<int>> knnGroups, vector<string> knnGroupNames)const;
357
358 void _init(map<string, string> &mapper);
359public:
360
361 int numClusters = -1;
362 float fraction = (float)0.02;
363 float chosenThreshold = MED_MAT_MISSING_VALUE;
364 float thresholdQ = MED_MAT_MISSING_VALUE;
365
366 KNN_Explainer() { processor_type = FTR_POSTPROCESS_KNN_EXPLAIN; }
367
368 void _learn(const MedFeatures &train_mat);
369
370 void explain(const MedFeatures &matrix, vector<map<string, float>> &sample_explain_reasons) const;
371
372 ADD_CLASS_NAME(KNN_Explainer)
373 ADD_SERIALIZATION_FUNCS(numClusters, trainingMap, average, std, fraction, chosenThreshold, filters, processing, global_explain_params)
374};
375
380private:
381 void _init(map<string, string> &mapper);
382
383 float avg_bias_score;
384public:
385 LinearExplainer() { processor_type = FTR_POSTPROCESS_LINEAR; avg_bias_score = 0; }
386
387 void _learn(const MedFeatures &train_mat);
388
389 void explain(const MedFeatures &matrix, vector<map<string, float>> &sample_explain_reasons) const;
390
391 ADD_CLASS_NAME(LinearExplainer)
392 ADD_SERIALIZATION_FUNCS(avg_bias_score, filters, processing, global_explain_params)
393};
394
400private:
401 unique_ptr<SamplesGenerator<float>> _sampler = NULL;
402 void *sampler_sampling_args = NULL;
403
404 GibbsSampler<float> _gibbs;
405 GibbsSamplingParams _gibbs_sample_params;
406
407 float avg_bias_score;
408
409 void init_sampler(bool with_sampler = true);
410
411 void _init(map<string, string> &mapper);
412public:
413 GeneratorType gen_type = GeneratorType::GIBBS;
414 string generator_args = "";
415 string sampling_args = "";
416 int n_masks = 100;
418 float missing_value = MED_MAT_MISSING_VALUE;
419
425
426 IterativeSetExplainer() { processor_type = FTR_POSTPROCESS_ITERATIVE_SET; avg_bias_score = 0; }
427
428 void _learn(const MedFeatures &train_mat);
429
430 void explain(const MedFeatures &matrix, vector<map<string, float>> &sample_explain_reasons) const;
431
432 void post_deserialization();
433
434 void load_GIBBS(MedPredictor *original_pred, const GibbsSampler<float> &gibbs, const GibbsSamplingParams &sampling_args);
435 void load_GAN(MedPredictor *original_pred, const string &gan_path);
436 void load_MISSING(MedPredictor *original_pred);
437 void load_sampler(MedPredictor *original_pred, unique_ptr<SamplesGenerator<float>> &&generator);
438
439 void dprint(const string &pref) const;
440
441 ADD_CLASS_NAME(IterativeSetExplainer)
443 use_random_sampling, avg_bias_score, filters, processing, global_explain_params, max_set_size,
445};
446
457
458
459
460#endif
MedAlgo - APIs to different algorithms: Linear Models, RF, GBM, KNN, and more.
@ FTR_POSTPROCESS_ITERATIVE_SET
"iterative_set" to create IterativeSetExplainer - model agnostic iterative explainer for model....
Definition PostProcessor.h:23
@ FTR_POSTPROCESS_LINEAR
"linear" to create LinearExplainer to explain linear model - importance is score change when putting ...
Definition PostProcessor.h:22
@ FTR_POSTPROCESS_LIME_SHAP
"lime_shap" to create LimeExplainer - model agnostic shapley algorithm with lime on shap values sampl...
Definition PostProcessor.h:20
@ FTR_POSTPROCESS_SHAPLEY
"shapley" to create ShapleyExplainer - model agnostic shapley explainer for model....
Definition PostProcessor.h:18
@ FTR_POSTPROCESS_KNN_EXPLAIN
"knn" Explainer built on knn principles KNN_Explainer
Definition PostProcessor.h:21
@ FTR_POSTPROCESS_TREE_SHAP
"tree_shap" to create TreeExplainer to explain tree models or mimic a generic model with a tree model
Definition PostProcessor.h:17
#define ADD_SERIALIZATION_FUNCS(...)
Definition SerializableObject.h:122
#define MEDSERIALIZE_SUPPORT(Type)
Definition SerializableObject.h:108
Specific settings for binning a feature.
Definition BinSplitOptimizer.h:37
Parameters for filtering explanations.
Definition ExplainWrapper.h:18
int max_count
Maximal number of explanation features to keep. 0 - no limit
Definition ExplainWrapper.h:21
void filter(map< string, float > &explain_list) const
Applies the filters.
Definition ExplainWrapper.cpp:44
float sum_ratio
Fraction of the sum of explanation values to keep, according to sort_mode. [0 - 1]
Definition ExplainWrapper.h:22
int init(map< string, string > &map)
Virtual method to initialize the object from parsed fields.
Definition ExplainWrapper.cpp:26
int sort_mode
0 - both positive and negative (sorted by absolute value), -1 - only negatives, +1 - only positives
Definition ExplainWrapper.h:20
Processing of explanations - grouping, using a covariance matrix to take feature correlations into ...
Definition ExplainWrapper.h:39
bool learn_cov_matrix
If true, learns cov_matrix.
Definition ExplainWrapper.h:44
int zero_missing
If != 0, drops bias terms and zeroes all contributions of missing values and groups of missing valu...
Definition ExplainWrapper.h:46
static void read_feature_grouping(const string &file_name, const vector< string > &features, vector< vector< int > > &group2index, vector< string > &group_names, bool verbose=true)
Creates the feature groups from the file file_name and the existing features.
Definition ExplainWrapper.cpp:733
void process(map< string, float > &explain_list) const
Applies the processing steps.
Definition ExplainWrapper.cpp:327
bool iterative
If true, adds explainers iteratively, conditioned on those already selected.
Definition ExplainWrapper.h:48
int init(map< string, string > &map)
Virtual method to initialize the object from parsed fields.
Definition ExplainWrapper.cpp:92
string grouping
grouping file or "BY_SIGNAL" keyword to group by signal or "BY_SIGNAL_CATEG" - for category signal to...
Definition ExplainWrapper.h:57
int normalize_vals
If != 0, normalizes contributions. 1: normalize by the sum of absolute values of all (non-b0) contributions 2: sam...
Definition ExplainWrapper.h:45
float get_group_normalized_contrib(const vector< int > &group_inds, vector< float > &contribs, float total_normalization_factor) const
Helper function: returns the normalized contribution of a specific group, given the original contributions.
Definition ExplainWrapper.cpp:295
BinSettings mutual_inf_bin_setting
The bin settings for mutual information.
Definition ExplainWrapper.h:53
bool group_by_sum
If true, groups by summing each feature; otherwise uses a special internal implementation.
Definition ExplainWrapper.h:43
MedMat< float > abs_cov_features
Absolute values of the feature covariance matrix. Either read from file (and then apply absolute valu...
Definition ExplainWrapper.h:55
int iteration_cnt
If > 0, the maximal number of iterations.
Definition ExplainWrapper.h:49
bool use_mutual_information
If true, uses mutual information instead of covariance.
Definition ExplainWrapper.h:52
void learn(const MedFeatures &train_mat)
Learns the processing - for example, the covariance matrix.
Definition ExplainWrapper.cpp:183
bool keep_b0
If true, keeps the b0 prior.
Definition ExplainWrapper.h:47
bool use_max_cov
If true, uses the max-covariance logic.
Definition ExplainWrapper.h:50
A Gibbs sampler - learns and creates samples based on a mask.
Definition GibbsSampler.h:89
A class that contains all sampling arguments.
Definition GibbsSampler.h:71
A wrapper class to hold all global arguments needed for ModelExplainer.
Definition ExplainWrapper.h:91
bool store_as_json
If true, stores the ButWhy output as JSON in string attributes.
Definition ExplainWrapper.h:94
bool denorm_features
If true, saves feature values denormalized.
Definition ExplainWrapper.h:95
string attr_name
Attribute name for the explainer.
Definition ExplainWrapper.h:93
Iterative set explainer with a (Gibbs, GAN or other) samples generator or a proxy predictor algorithm to ...
Definition ExplainWrapper.h:399
int max_set_size
The set size to search for when explaining.
Definition ExplainWrapper.h:424
void _learn(const MedFeatures &train_mat)
overload function for ModelExplainer - easier API
Definition ExplainWrapper.cpp:2900
float sort_params_a
weight for minimal distance from original score importance
Definition ExplainWrapper.h:420
int n_masks
How many tests to conduct for Shapley.
Definition ExplainWrapper.h:416
float sort_params_k2
weight for variance in prediction using imputation. the rest is change from prev
Definition ExplainWrapper.h:423
float sort_params_k1
weight for minimal distance from original score importance
Definition ExplainWrapper.h:422
float missing_value
missing value
Definition ExplainWrapper.h:418
void explain(const MedFeatures &matrix, vector< map< string, float > > &sample_explain_reasons) const
Virtual - returns explanation results in sample_explain_reasons.
Definition ExplainWrapper.cpp:2905
string sampling_args
args for sampling
Definition ExplainWrapper.h:415
float sort_params_b
weight for variance in prediction using imputation. the rest is change from prev
Definition ExplainWrapper.h:421
GeneratorType gen_type
generator type
Definition ExplainWrapper.h:413
string generator_args
for learn
Definition ExplainWrapper.h:414
bool use_random_sampling
If true, uses random sampling - otherwise samples the mask size and then creates it.
Definition ExplainWrapper.h:417
KNN explainer.
Definition ExplainWrapper.h:350
float chosenThreshold
Threshold to use on scores. If missing, thresholdQ is used to define the threshold.
Definition ExplainWrapper.h:363
float thresholdQ
Defines the threshold by the positive ratio on the training set (when chosenThreshold is missing)....
Definition ExplainWrapper.h:364
float fraction
Fraction of points that is considered the neighborhood of a point.
Definition ExplainWrapper.h:362
int numClusters
How many (randomly chosen) samples represent the training space; -1: all. If larger than the size of the matrix...
Definition ExplainWrapper.h:361
void explain(const MedFeatures &matrix, vector< map< string, float > > &sample_explain_reasons) const
Virtual - returns explanation results in sample_explain_reasons.
Definition ExplainWrapper.cpp:2729
void _learn(const MedFeatures &train_mat)
overload function for ModelExplainer - easier API
Definition ExplainWrapper.cpp:2659
Shapley-LIME explainer with a Gibbs, GAN or other samples generator.
Definition ExplainWrapper.h:306
int n_masks
number of masks
Definition ExplainWrapper.h:325
void explain(const MedFeatures &matrix, vector< map< string, float > > &sample_explain_reasons) const
Virtual - returns explanation results in sample_explain_reasons.
Definition ExplainWrapper.cpp:2493
string generator_args
for learn
Definition ExplainWrapper.h:320
GeneratorType gen_type
generator type
Definition ExplainWrapper.h:319
float p_mask
Probability of 1 in the mask; if 0, mask generation is done by first selecting the number of 1's in the mask (uniformly) and t...
Definition ExplainWrapper.h:323
float missing_value
missing value
Definition ExplainWrapper.h:322
string sampling_args
args for sampling
Definition ExplainWrapper.h:321
void _learn(const MedFeatures &train_mat)
overload function for ModelExplainer - easier API
Definition ExplainWrapper.cpp:2489
Simple Linear Explainer - puts zeros for each feature and measures change in score.
Definition ExplainWrapper.h:379
void explain(const MedFeatures &matrix, vector< map< string, float > > &sample_explain_reasons) const
Virtual - returns explanation results in sample_explain_reasons.
Definition ExplainWrapper.cpp:2561
void _learn(const MedFeatures &train_mat)
overload function for ModelExplainer - easier API
Definition ExplainWrapper.cpp:2557
A class for holding features data as a virtual matrix
Definition MedFeatures.h:47
Definition MedMat.h:63
A model = repCleaner + featureGenerator + featureProcessor + MedPredictor.
Definition MedModel.h:56
Base Interface for predictor.
Definition MedAlgo.h:78
MedSample represents a single sample: id + time (date). Additional (optional) entries: outcome,...
Definition MedSamples.h:20
Shapley Explainer - based on learning from training data to handle missing values as "correct" input.
Definition ExplainWrapper.h:207
string verbose_apply
If set - output file.
Definition ExplainWrapper.h:227
float override_score_bias
When given, it is used as the score bias if train is very different from test.
Definition ExplainWrapper.h:239
float missing_value
missing value
Definition ExplainWrapper.h:219
float sort_params_k1
weight for minimal distance from original score importance
Definition ExplainWrapper.h:236
float split_to_test
Test ratio (> 0 and < 1) on which to report RMSE.
Definition ExplainWrapper.h:240
float sort_params_b
weight for variance in prediction using imputation. the rest is change from prev
Definition ExplainWrapper.h:235
int max_set_size
The set size to search for when explaining.
Definition ExplainWrapper.h:238
int subsample_train
If not zero, subsamples the original train samples to this number.
Definition ExplainWrapper.h:229
string predictor_args
arguments to change in predictor - for example to change it into regression
Definition ExplainWrapper.h:224
void _learn(const MedFeatures &train_mat)
overload function for ModelExplainer - easier API
Definition ExplainWrapper.cpp:1723
bool no_relearn
If true, uses the original model without relearning. Assumes the original model is good enough for missing val...
Definition ExplainWrapper.h:216
bool verbose_learn
If true, prints more during learn.
Definition ExplainWrapper.h:226
bool use_minimal_set
If true, uses a different method to find the minimal set.
Definition ExplainWrapper.h:233
float sort_params_k2
weight for variance in prediction using imputation. the rest is change from prev
Definition ExplainWrapper.h:237
float max_weight
The maximal weight. If < 0, no limit.
Definition ExplainWrapper.h:228
void explain(const MedFeatures &matrix, vector< map< string, float > > &sample_explain_reasons) const
Virtual - returns explanation results in sample_explain_reasons.
Definition ExplainWrapper.cpp:2012
int limit_mask_size
If set, limits the mask size in training - useful for minimal_set.
Definition ExplainWrapper.h:230
int add_new_data
How many new data points to add for training, according to sample masks.
Definition ExplainWrapper.h:215
int max_test
max number of samples in SHAP
Definition ExplainWrapper.h:218
float select_from_all
If max_test is beyond this percentage of all options, then sample from all options (to speed up runtim...
Definition ExplainWrapper.h:221
bool sample_masks_with_repeats
Whether or not to sample masks with repeats.
Definition ExplainWrapper.h:220
bool use_shuffle
If not sampling uniformly and this is true, uses shuffle (to speed up runtime).
Definition ExplainWrapper.h:223
bool uniform_rand
If true, samples masks uniformly.
Definition ExplainWrapper.h:222
float sort_params_a
weight for minimal distance from original score importance
Definition ExplainWrapper.h:234
An abstract class API for explainers.
Definition ExplainWrapper.h:106
void Apply(MedFeatures &matrix)
alias for explain
Definition ExplainWrapper.h:127
virtual int init(map< string, string > &mapper)
Global init for general args in all explainers. Directly initializes all args in GlobalExplainerParams...
Definition ExplainWrapper.cpp:530
void get_input_fields(vector< Effected_Field > &fields) const
List of fields that are used by this post_processor.
Definition ExplainWrapper.cpp:595
ExplainFilters filters
general filters of results
Definition ExplainWrapper.h:114
void get_output_fields(vector< Effected_Field > &fields) const
List of fields that are affected by this post_processor.
Definition ExplainWrapper.cpp:598
virtual void explain(const MedFeatures &matrix, vector< map< string, float > > &sample_explain_reasons) const =0
Virtual - returns explanation results in sample_explain_reasons.
void init_post_processor(MedModel &model)
Initializes ModelExplainer from a MedModel - copies the predictor pointer and may save normalizer pointers.
Definition ExplainWrapper.cpp:614
virtual void _learn(const MedFeatures &train_mat)=0
overload function for ModelExplainer - easier API
ExplainProcessings processing
processing of results, like groupings, COV
Definition ExplainWrapper.h:115
virtual int update(map< string, string > &mapper)
Virtual method to update the object from parsed fields.
Definition ExplainWrapper.cpp:558
virtual void Learn(const MedFeatures &train_mat)
Learns from predictor and train_matrix (PostProcessor API)
Definition ExplainWrapper.cpp:897
MedPredictor * original_predictor
predictor we're trying to explain
Definition ExplainWrapper.h:113
An Abstract PostProcessor class.
Definition PostProcessor.h:39
Abstract Random Samples generator.
Definition SamplesGenerator.h:34
Definition SerializableObject.h:32
Shapley explainer with a Gibbs, GAN or other samples generator.
Definition ExplainWrapper.h:262
void explain(const MedFeatures &matrix, vector< map< string, float > > &sample_explain_reasons) const
Virtual - returns explanation results in sample_explain_reasons.
Definition ExplainWrapper.cpp:2259
string sampling_args
args for sampling
Definition ExplainWrapper.h:278
float missing_value
missing value
Definition ExplainWrapper.h:281
bool use_random_sampling
If true, uses random sampling - otherwise samples the mask size and then creates it.
Definition ExplainWrapper.h:280
void _learn(const MedFeatures &train_mat)
overload function for ModelExplainer - easier API
Definition ExplainWrapper.cpp:2254
int n_masks
How many tests to conduct for Shapley.
Definition ExplainWrapper.h:279
string generator_args
for learn
Definition ExplainWrapper.h:277
GeneratorType gen_type
generator type
Definition ExplainWrapper.h:276
A generic tree explainer:
Definition ExplainWrapper.h:163
string proxy_model_init
proxy predictor arguments
Definition ExplainWrapper.h:176
void _learn(const MedFeatures &train_mat)
overload function for ModelExplainer - easier API
Definition ExplainWrapper.cpp:1407
float missing_value
missing value
Definition ExplainWrapper.h:179
int approximate
If true, runs the SAABAS algorithm, which is faster.
Definition ExplainWrapper.h:178
bool interaction_shap
If true, calculates interaction_shap values (slower).
Definition ExplainWrapper.h:177
void explain(const MedFeatures &matrix, vector< map< string, float > > &sample_explain_reasons) const
Virtual - returns explanation results in sample_explain_reasons.
Definition ExplainWrapper.cpp:1460
string proxy_model_type
proxy predictor type to relearn original predictor output with tree models
Definition ExplainWrapper.h:175
void init_post_processor(MedModel &model)
Initializes ModelExplainer from a MedModel - copies the predictor pointer and may save normalizer pointers.
Definition ExplainWrapper.cpp:1402
Definition StdDeque.h:58
Definition tree_shap.h:70