Medial Code Documentation
Loading...
Searching...
No Matches
tree_shap.h
1
7#ifndef __TREE_SHAP_H__
8#define __TREE_SHAP_H__
9
10#include <algorithm>
11#include <iostream>
12#include <fstream>
13#include <stdio.h>
14#include <cmath>
15#include <ctime>
16#if defined(_WIN32) || defined(WIN32)
17#include <malloc.h>
18#else
19#include <alloca.h>
20#endif
21
22#include <vector>
23#include <random>
24#include <unordered_set>
25#include <MedProcessTools/MedProcessTools/MedFeatures.h>
27#include "SamplesGenerator.h"
28
29using namespace std;
30
/// Floating-point type used for the explained data, tree thresholds/values and SHAP outputs.
using tfloat = double;

// Presumably the values stored in Node::from_flag, recording whether a path was reached
// via the explained sample (X), the reference sample (R), or neither — inferred from the
// names; TODO confirm against tree_shap.cpp.
#define FROM_NEITHER 0
#define FROM_X_NOT_R 1
#define FROM_R_NOT_X 2
36
/// Feature-dependence modes. Presumably the values accepted by the `feature_dependence`
/// argument of dense_tree_shap / iterative_tree_shap (inferred from the parameter name;
/// TODO confirm in tree_shap.cpp).
namespace FEATURE_DEPENDENCE {
	constexpr unsigned independent = 0;           // cf. dense_independent
	constexpr unsigned tree_path_dependent = 1;   // cf. dense_tree_path_dependent
	constexpr unsigned global_path_dependent = 2; // cf. dense_global_path_dependent
}
42
/// Output-transform selectors. Presumably the values accepted by the `model_transform`
/// argument of dense_tree_shap / iterative_tree_shap (inferred from the parameter name;
/// TODO confirm in tree_shap.cpp).
namespace MODEL_TRANSFORM {
	constexpr unsigned identity = 0;
	constexpr unsigned logistic = 1;
	constexpr unsigned logistic_nlogloss = 2;
	constexpr unsigned squared_loss = 3;
}
49
// Data to explain (ExplanationDataset members).
51 tfloat *X; // all samples: each row is one sample holding all its features; columns (2nd dim) are features
52 bool *X_missing; // true where the corresponding value in X is missing; same structure as X
53 tfloat *y; // the labels
54 tfloat *R; // presumably a reference/background sample set, paired with R_missing as X is with X_missing — confirm
55 bool *R_missing; // presumably the missing-value mask for R, same structure as R — confirm
56 unsigned num_X; // number of samples
57 unsigned M; // features count
58 unsigned num_Exp; // NOTE(review): undocumented; presumably the number of explanation features/groups — confirm
59 unsigned num_R; // NOTE(review): generated docs call this "number of explanation features (allowing for grouping)"; since R is a separate data set this is more likely R's row count and that text probably belongs to num_Exp — confirm
60
// Constructors. The second overload takes no num_Exp; how num_Exp is set in that case is not visible here.
62 ExplanationDataset(tfloat *X, bool *X_missing, tfloat *y, tfloat *R, bool *R_missing, unsigned num_X,
63 unsigned M, unsigned num_R, unsigned num_Exp);
64 ExplanationDataset(tfloat *X, bool *X_missing, tfloat *y, tfloat *R, bool *R_missing, unsigned num_X,
65 unsigned M, unsigned num_R);
66
// Presumably fills `instance` with sample i (a single-row view of X / X_missing) — confirm in tree_shap.cpp.
67 void get_x_instance(ExplanationDataset &instance, const unsigned i) const;
68};
69
// Tree ensemble (TreeEnsemble members). Per-node arrays; their exact indexing is not visible
// here — presumably tree t occupies entries [t * max_nodes, (t + 1) * max_nodes).
71 int *children_left; // presumably left-child index per node
72 int *children_right; // presumably right-child index per node
73 int *children_default; // presumably child taken when the split feature is missing
74 int *features; // presumably split feature index per node
75 tfloat *thresholds; // presumably split threshold per node
76 tfloat *values; // presumably node output values (num_outputs per node) — confirm
77 tfloat *node_sample_weights; // presumably training sample weight (cover) per node
78 unsigned max_depth;
79 unsigned tree_limit; // presumably number of trees
80 tfloat base_offset; // presumably constant added to the model output
81 unsigned max_nodes; // presumably maximum node count per tree
82 unsigned num_outputs;
83 bool is_allocate; // NOTE(review): presumably true when the arrays were created by allocate() and must be released with free(); no destructor is visible here — confirm ownership before copying
84
86 TreeEnsemble(int *children_left, int *children_right, int *children_default, int *features,
87 tfloat *thresholds, tfloat *values, tfloat *node_sample_weights,
88 unsigned max_depth, unsigned tree_limit, tfloat base_offset,
89 unsigned max_nodes, unsigned num_outputs);
90
// Presumably points `tree` at the i-th tree of this ensemble — confirm.
91 void get_tree(TreeEnsemble &tree, const unsigned i) const;
92
// Presumably allocates the per-node arrays for the given sizes (see is_allocate).
93 void allocate(unsigned tree_limit_in, unsigned max_nodes_in, unsigned num_outputs_in);
94
// Presumably releases arrays obtained via allocate().
95 void free();
96
// Build an `adjusted` ensemble from this one given an instance, a mask and feature_sets;
// the adjustment rule is not visible in this header.
97 void fill_adjusted_tree(int node_index, ExplanationDataset& instance, const int *mask, unsigned *feature_sets, TreeEnsemble& adjusted);
98 void create_adjusted_tree(ExplanationDataset& instance, const int *mask, unsigned *feature_sets, TreeEnsemble& adjusted);
// NOTE(review): this header includes <unordered_set> but not <unordered_map>, so the declaration below
// relies on a transitive include; the map is also taken by value (copied on every call).
99 void calc_feature_contribs_conditional(MedMat<float> &mat_x_in, unordered_map<string, float> contiditional_variables, MedMat<float> &mat_x_out, MedMat<float> &mat_contribs);
// Presumably evaluates the tree starting at node_index for `instance` — confirm.
100 tfloat predict(ExplanationDataset& instance, int node_index);
101};
102
// Per-tree Tree SHAP routine (its definition is not in this header).
// condition / condition_feature: the Python reference pseudo-code in this header calls
// tree_shap(..., phi_on, 1, j) and tree_shap(..., phi_off, -1, j); presumably 1 conditions
// feature `condition_feature` "on", -1 "off", and 0 means no conditioning — confirm.
// NOTE(review): declared `inline` without a body. An inline function must be defined in every
// translation unit that odr-uses it, so this is only callable where the definition is also
// visible (presumably tree_shap.cpp) — confirm before calling it from other files.
103inline void tree_shap(const TreeEnsemble& tree, const ExplanationDataset &data, tfloat *out_contribs, int condition, unsigned condition_feature, unsigned *feature_sets);
104
105// data we keep about our decision path
106// note that pweight is included for convenience and is not tied with the other attributes
107// the pweight of the i'th path element is the permutation weight of paths with i-1 ones in them
109 int feature_index; // presumably the feature split on at this path element
110 tfloat zero_fraction; // presumably fraction of "zero" paths (feature excluded) flowing through — confirm
111 tfloat one_fraction; // presumably fraction of "one" paths (feature included) flowing through — confirm
112 tfloat pweight; // permutation weight, see the note above
113 PathElement();
114 PathElement(int i, tfloat z, tfloat o, tfloat w); // presumably (feature_index, zero_fraction, one_fraction, pweight) — confirm
115};
116
117// Independent Tree SHAP functions below here
118// ------------------------------------------
/// Compact tree node for the independent Tree SHAP code path.
/// Index and feature fields are stored as `short` (signed, 16-bit on common platforms),
/// not unsigned, so node/feature indices presumably must fit that range — confirm.
struct Node {
	short cl;        // presumably left child index
	short cr;        // presumably right child index
	short cd;        // presumably default (missing-value) child index
	short pnode;     // presumably parent node index
	short feat;      // presumably split feature index
	short pfeat;     // presumably parent's split feature index
	float thres;     // presumably split threshold
	float value;     // presumably node output value
	char from_flag;  // presumably one of FROM_NEITHER / FROM_X_NOT_R / FROM_R_NOT_X
};
124
128void dense_tree_saabas(tfloat *out_contribs, const TreeEnsemble& trees, const ExplanationDataset &data);
129
133void dense_independent(const TreeEnsemble& trees, const ExplanationDataset &data,
134 tfloat *out_contribs, tfloat transform(const tfloat, const tfloat));
135
136
140void dense_tree_path_dependent(const TreeEnsemble& trees, const ExplanationDataset &data,
141 tfloat *out_contribs, unsigned *feature_sets, tfloat transform(const tfloat, const tfloat));
142
143// phi = np.zeros((self._current_X.shape[1] + 1, self._current_X.shape[1] + 1, self.n_outputs))
144// phi_diag = np.zeros((self._current_X.shape[1] + 1, self.n_outputs))
145// for t in range(self.tree_limit):
146// self.tree_shap(self.trees[t], self._current_X[i,:], self._current_x_missing, phi_diag)
147// for j in self.trees[t].unique_features:
148// phi_on = np.zeros((self._current_X.shape[1] + 1, self.n_outputs))
149// phi_off = np.zeros((self._current_X.shape[1] + 1, self.n_outputs))
150// self.tree_shap(self.trees[t], self._current_X[i,:], self._current_x_missing, phi_on, 1, j)
151// self.tree_shap(self.trees[t], self._current_X[i,:], self._current_x_missing, phi_off, -1, j)
152// phi[j] += np.true_divide(np.subtract(phi_on,phi_off),2.0)
153// phi_diag[j] -= np.sum(np.true_divide(np.subtract(phi_on,phi_off),2.0))
154// for j in range(self._current_X.shape[1]+1):
155// phi[j][j] = phi_diag[j]
156// phi /= self.tree_limit
157// return phi
158
159void dense_tree_interactions_path_dependent(const TreeEnsemble& trees, const ExplanationDataset &data,
160 tfloat *out_contribs,
161 tfloat transform(const tfloat, const tfloat));
162
170void dense_global_path_dependent(const TreeEnsemble& trees, const ExplanationDataset &data,
171 tfloat *out_contribs, tfloat transform(const tfloat, const tfloat));
172
173
177void dense_tree_shap(const TreeEnsemble& trees, const ExplanationDataset &data, tfloat *out_contribs,
178 const int feature_dependence, unsigned model_transform, bool interactions);
179void dense_tree_shap(const TreeEnsemble& trees, const ExplanationDataset &data, tfloat *out_contribs,
180 const int feature_dependence, unsigned model_transform, bool interactions, unsigned *feature_sets);
181
185void iterative_tree_shap(const TreeEnsemble& trees, const ExplanationDataset &data, tfloat *out_contribs,
186 const int feature_dependence, unsigned model_transform, bool interactions, unsigned *feature_sets, bool verbose,
187 vector<string>& names, const MedMat<float>& abs_cov_mat, int iteration_cnt, bool max_in_groups);
188
189namespace medial {
190 namespace shapley {
191
193 double nchoosek(long n, long k);
195 void list_all_options_binary(int nfeats, vector<vector<bool>> &all_opts);
197 void generate_mask(vector<bool> &mask, int nfeat, mt19937 &gen, bool uniform_rand = false, bool use_shuffle = true);
199 void generate_mask_(vector<bool> &mask, int nfeat, mt19937 &gen, bool uniform_rand = false, float uniform_rand_p = 0.5,
200 bool use_shuffle = true, int limit_zero_cnt = 0);
202 void sample_options_SHAP(int nfeats, vector<vector<bool>> &all_opts, int opt_count, mt19937 &gen, bool with_repeats
203 , bool uniform_rand = false, bool use_shuffle = true);
205 double get_c(int p1, int p2, int end_l);
207 void explain_shapley(const MedFeatures &matrix, int selected_sample, int max_tests,
208 MedPredictor *predictor, float missing_value, const vector<vector<int>>& group2index, const vector<string> &groupNames,
209 vector<float> &features_coeff, mt19937 &gen, bool sample_masks_with_repeats,
210 float select_from_all, bool uniform_rand, bool use_shuffle, bool verbose);
212 template<typename T> void explain_shapley(const MedFeatures &matrix, int selected_sample, int max_tests,
213 MedPredictor *predictor, const vector<vector<int>>& group2index, const vector<string> &groupNames,
214 const SamplesGenerator<T> &sampler_gen, mt19937 &rnd_gen, int sample_per_row, void *sampling_params,
215 vector<float> &features_coeff, bool use_random_sample, bool verbose = false);
216
218 void explain_minimal_set(const MedFeatures &matrix, int selected_sample, int max_tests,
219 MedPredictor *predictor, float missing_value, const vector<vector<int>>& group2index
220 , vector<float> &features_coeff, vector<float> &scores_history, int max_set_size,
221 float baseline_score, float param_all_alpha, float param_all_beta,
222 float param_all_k1, float param_all_k2, bool verbose);
223
225 void explain_minimal_set(const MedFeatures &matrix, int selected_sample, int max_tests,
226 MedPredictor *predictor, float missing_value, const vector<vector<int>>& group2index,
227 const SamplesGenerator<float> &sampler_gen, mt19937 &rnd_gen, void *sampling_params
228 , vector<float> &features_coeff, vector<float> &scores_history, int max_set_size,
229 float baseline_score, float param_all_alpha, float param_all_beta,
230 float param_all_k1, float param_all_k2, bool verbose);
231
233 typedef enum {
234 LimeWeightLime = 0,
235 LimeWeightUniform = 1,
236 LimeWeightShap = 2,
237 LimeWeightSum = 3,
238 LimeWeightLast
239 } LimeWeightMethod;
240
242 void get_shapley_lime_params(const MedFeatures& data, const MedPredictor *model,
243 SamplesGenerator<float> *generator, float p, int n, LimeWeightMethod weighting, float missing,
244 void *params, const vector<vector<int>>& group2index, const vector<string>& group_names, vector<vector<float>>& alphas);
245
247 void get_shapley_lime_params(const MedFeatures& data, const MedPredictor *model,
248 SamplesGenerator<float> *generator, float p, int n, LimeWeightMethod weighting, float missing,
249 void *params, const vector<vector<int>>& group2index, const vector<string>& group_names, vector<vector<int>>& forced, vector<vector<float>>& alphas);
250
252 void get_iterative_shapley_lime_params(const MedFeatures& data, const MedPredictor *model,
253 SamplesGenerator<float> *generator, float p, int n, LimeWeightMethod weighting, float missing,
254 void *params, const vector<vector<int>>& group2index, const vector<string>& group_names, const MedMat<float>& abs_cov_mat, int iteration_cnt, vector<vector<float>>& alphas, bool max_in_groups);
255 }
256}
257
258#endif
MedAlgo - APIs to different algorithms: Linear Models, RF, GBM, KNN, and more.
A class for holding features data as a virtual matrix
Definition MedFeatures.h:47
Definition MedMat.h:63
Base Interface for predictor.
Definition MedAlgo.h:78
Abstract Random Samples generator.
Definition SamplesGenerator.h:34
medial namespace for functions
Definition InfraMed.h:667
Definition StdDeque.h:58
Definition tree_shap.py:1
Definition tree_shap.h:50
tfloat * X
Vector of all data: each row is one sample holding all of its features; columns (the 2nd dimension) are features.
Definition tree_shap.h:51
unsigned M
Features count.
Definition tree_shap.h:57
unsigned num_X
number of samples
Definition tree_shap.h:56
unsigned num_R
number of explanation features (allowing for grouping)
Definition tree_shap.h:59
bool * X_missing
Boolean mask that is true where a value in the matrix is missing; same structure as X.
Definition tree_shap.h:52
tfloat * y
the labels
Definition tree_shap.h:53
Definition tree_shap.h:119
Definition tree_shap.h:108
Definition tree_shap.h:70