// Nested lookup indexed first by pid (see get_unknown_codes(), which returns the inner map for one pid).
// NOTE(review): the meaning of the inner string key and string set is not visible here — confirm.
161 unordered_map<int, unordered_map<string, unordered_set<string>>> unknown_codes;
// Threshold table filled by set_threshold_leaflet(): cohort name -> (measure key -> value).
164 map<string, map<string, float>> mbr;
// Default threshold selector; set from the "default_threshold" parameter in set_threshold_leaflet(). Starts empty.
165 string default_threshold =
"";
// Set to true by init_model_for_apply(); returned by model_initiated().
172 bool model_init_done =
false;
// Set to true inside the "not done yet" guard of init_model_for_rep().
173 bool model_rep_done =
false;
182 void set_name(
const char *_name) { name = string(_name); }
183 void set_model_end_stage(
int _model_end_stage) { model_end_stage = _model_end_stage; };
// Initializes the repository from a configuration file.
// Switches the repository to in-memory mode, then calls MedRepository::init() with the given path.
// config_fname: path of the repository configuration file (C string).
// Returns -1 when MedRepository::init() reports a negative value; the success path is outside the lines visible here.
186 int init_rep_config(
const char *config_fname) {
187 rep.switch_to_in_mem_mode();
188 if (rep.MedRepository::init(
string(config_fname)) < 0)
return -1;
// Assigns time_unit to global_default_time_unit.
// The return statement is outside the lines visible here.
194 int set_time_unit_env(
int time_unit) {
195 global_default_time_unit = time_unit;
200 void set_pids(
int *_pids,
int npids) { pids.clear(); pids.assign(_pids, _pids + npids); }
// Reads repository data from the repository file _rep_fname.
// Remembers the path in rep_fname, then calls rep.read_all() with the stored pids and an
// empty signal-name list and returns its result.
// NOTE(review): an empty signal list presumably means "all signals" — confirm against read_all().
203 int init_rep_with_file_data(
const char *_rep_fname) {
205 rep_fname = string(_rep_fname);
206 vector<string> sigs = {};
207 return (rep.read_all(rep_fname, pids, sigs));
211 int init_model_from_file(
const char *_model_fname) { model.clear(); model.verbosity = 0;
return (model.read_from_file(
string(_model_fname))); }
// Checks that every signal the model requires is known to the repository.
// Collects the required names with model.get_required_signal_names() and, for each name that
// has no entry in rep.sigs.Name2Sid, prints an error to stderr.
// The return value is produced outside the lines visible here.
212 int model_check_required_signals() {
214 vector<string> req_sigs;
215 model.get_required_signal_names(req_sigs);
216 for (
const auto& s : req_sigs)
// count() only tests for presence; it does not add the name to the map.
217 if (0 == rep.sigs.Name2Sid.count(s)) {
219 fprintf(stderr,
"ERROR: AM model requires signal '%s' but signal does not exist in AM repository .signals file\n", s.c_str());
// Prepares the model for apply: logs the step and sets model_init_done to true.
// Any further initialization and the return value are outside the lines visible here.
225 int init_model_for_apply() {
226 global_logger.log(
LOG_APP, LOG_DEF_LEVEL,
"Init MedModel for Apply\n");
227 model_init_done =
true;
// Calls model.fit_for_repository() with this object's repository.
231 void fit_model_to_rep() {
232 model.fit_for_repository(rep);
// Runs a guarded section only while model_rep_done is still false, setting the flag to true
// inside the guard. The rest of the guarded work and the return value are outside the lines visible here.
235 int init_model_for_rep() {
237 if (!model_rep_done) {
238 model_rep_done =
true;
// Returns a pointer to the unknown_codes entry for pid.
// unknown_codes is an unordered_map indexed with operator[], so an empty entry is created
// when pid has none yet. The pointer refers to storage owned by unknown_codes.
244 unordered_map<string, unordered_set<string>> *get_unknown_codes(
int pid) {
245 return &unknown_codes[pid];
248 int init_samples(
int *pids,
int *times,
int n_samples) { clear_samples();
int rc = insert_samples(pids, times, n_samples); samples.normalize();
return rc; }
249 int init_samples(
int pid,
int time) {
return init_samples(&pid, &time, 1); }
254 void add_json_dict(json &js) { rep.dict.add_json(js); }
256 bool model_initiated() {
return model_init_done; }
263 int data_load_init() { unknown_codes.clear(); rep.switch_to_in_mem_mode();
return 0; }
266 int data_load_pid_sig(
int pid,
const char *sig_name,
int *times,
float *vals,
int n_elems) {
267 int sid = rep.sigs.Name2Sid[string(sig_name)];
268 if (sid < 0)
return -1;
269 int n_times = n_elems * rep.sigs.Sid2Info[sid].n_time_channels, n_vals = n_elems * rep.sigs.Sid2Info[sid].n_val_channels;
270 if (times == NULL) n_times = 0;
271 if (vals == NULL) n_vals = 0;
272 return rep.in_mem_rep.insertData(pid, sid, times, vals, n_times, n_vals);
// Inserts raw time/value arrays for one signal of one patient into a data buffer.
// n_times and n_vals are passed through as given to insertData_to_buffer().
// Returns -1 when the looked-up signal id is negative, otherwise the result of insertData_to_buffer().
// NOTE(review): the line between the sid check and the `data = ...` assignment is not visible here;
// presumably the assignment only applies when data is NULL — confirm.
276 int data_load_pid_sig(
int pid,
const char *sig_name,
int *times,
int n_times,
float *vals,
int n_vals,
277 map<pair<int, int>, pair<
int, vector<char>>> *data = NULL) {
// NOTE(review): if Name2Sid is a standard map, operator[] inserts id 0 for an unknown name,
// so the sid < 0 check below would not reject it — confirm and consider a find()-based lookup.
278 int sid = rep.sigs.Name2Sid[string(sig_name)];
279 if (sid < 0)
return -1;
281 data = &rep.in_mem_rep.data;
282 return rep.in_mem_rep.insertData_to_buffer(pid, sid, times, vals, n_times, n_vals, rep.sigs, *data);
286 int data_load_pid_sig(
int pid,
const char *sig_name,
int *times,
float *vals) {
return data_load_pid_sig(pid, sig_name, times, vals, 1); }
289 int data_load_end() {
return rep.in_mem_rep.sortData(); }
// Iterates over rep.sigs.signals_names; sigs is the output set.
// NOTE(review): the loop body is outside the lines visible here; presumably each signal name is added to sigs — confirm.
291 void get_rep_signals(unordered_set<string> &sigs)
293 for (
auto &sig : rep.sigs.signals_names)
305 void clear_samples() { samples.clear(); }
// Adds n_samples (pid, time) records to the stored samples, reading pids[i] and times[i]
// in order and calling samples.insertRec() for each pair.
// The return statement is outside the lines visible here.
308 int insert_samples(
int *pids,
int *times,
int n_samples) {
309 for (
int i = 0; i < n_samples; i++)
310 samples.insertRec(pids[i], times[i]);
314 int insert_sample(
int pid,
int time) {
return insert_samples(&pid, &time, 1); }
319 int normalize_samples() { samples.normalize();
return 0; }
321 MedSamples *get_samples_ptr() {
return &samples; }
// Computes predictions for n_samples (pid, time) pairs.
// Rebuilds the stored samples from the arrays with init_samples(), then delegates to get_raw_preds(),
// whose result is returned. In the visible lines the return value of init_samples() is not checked.
// get_raw_preds() writes pids and times back into _pids/times from the stored samples, so the arrays
// are outputs as well as inputs.
327 int get_preds(
int *_pids,
int *times,
float *preds,
int n_samples) {
330 init_samples(_pids, times, n_samples);
332 return get_raw_preds(_pids, times, preds);
// Variant of get_preds() that also forwards requested_fields and an optional repository pointer.
// Rebuilds the stored samples with init_samples(), then returns the result of the five-argument get_raw_preds().
// NOTE(review): _rep defaults to NULL and the visible part of get_raw_preds() dereferences it; any NULL
// handling must be in lines not visible here — confirm.
335 int get_preds(
int *_pids,
int *times,
float *preds,
int n_samples,
336 const vector<Effected_Field> &requested_fields,
MedPidRepository *_rep=NULL) {
339 init_samples(_pids, times, n_samples);
342 return get_raw_preds(_pids, times, preds, requested_fields, _rep);
// Applies the model partially (model.no_init_apply_partial) to the stored samples and copies the
// results out. The rest of the parameter list, the try blocks and the return values are outside the
// lines visible here; callers pass requested_fields and _rep as the fourth and fifth arguments.
345 int get_raw_preds(
int *_pids,
int *times,
float *preds,
// The model is only applied when there is at least one stored sample.
352 if (!samples.idSamples.empty())
353 model.no_init_apply_partial(*_rep, samples, requested_fields);
356 fprintf(stderr,
"Caught an exception in no_init_apply_partial\n");
// Copy id, time and first prediction of every stored sample to the output arrays at index j.
// NOTE(review): the declaration and increment of j are not visible here.
363 for (
auto& idSample : samples.idSamples)
364 for (
auto& sample : idSample.samples) {
365 _pids[j] = sample.id;
366 times[j] = sample.time;
// A sample with no prediction is reported as AM_UNDEFINED_VALUE.
367 preds[j] = sample.prediction.size() > 0 ? sample.prediction[0] : (float)AM_UNDEFINED_VALUE;
374 catch (
int &exception_code) {
375 fprintf(stderr,
"Caught an exception code: %d\n", exception_code);
379 fprintf(stderr,
"Caught Something...\n");
// Applies the model to the stored samples and copies id, time and first prediction of each sample
// into _pids, times and preds. The apply call itself, the try blocks and the return values are
// outside the lines visible here; the error messages refer to no_init_apply.
385 int get_raw_preds(
int *_pids,
int *times,
float *preds) {
// Work is only attempted when there is at least one stored sample.
391 if (!samples.idSamples.empty())
393 fprintf(stderr,
"ERROR: MedAlgoMarkerInternal::get_preds FAILED.");
398 fprintf(stderr,
"Caught an exception in no_init_apply\n");
// NOTE(review): the declaration and increment of j are not visible here.
405 for (
auto& idSample : samples.idSamples)
406 for (
auto& sample : idSample.samples) {
407 _pids[j] = sample.id;
408 times[j] = sample.time;
// A sample with no prediction is reported as AM_UNDEFINED_VALUE.
409 preds[j] = sample.prediction.size() > 0 ? sample.prediction[0] : (float)AM_UNDEFINED_VALUE;
416 catch (
int &exception_code) {
417 fprintf(stderr,
"Caught an exception code: %d\n", exception_code);
421 fprintf(stderr,
"Caught Something...\n");
// Writes the first prediction of every stored sample into preds, in iteration order.
// The model-apply step and the return value are outside the lines visible here.
// NOTE(review): the visible loop walks the member `samples`, not the `_samples` argument; how the
// two relate is not visible here — confirm.
426 int get_preds(
MedSamples &_samples,
float *preds) {
432 fprintf(stderr,
"ERROR: MedAlgoMarkerInternal::get_preds FAILED.");
438 for (
auto& idSample : samples.idSamples)
439 for (
auto& sample : idSample.samples) {
// NOTE(review): unlike get_raw_preds(), prediction[0] is read without checking that prediction is non-empty.
440 preds[j++] = sample.prediction[0];
445 int get_pred(
int *pid,
int *time,
float *pred) {
return get_preds(pid, time, pred, 1); }
451 void clear() { unknown_codes.clear(); pids.clear(); model.clear(); samples.clear(); rep.in_mem_rep.clear(); rep.clear(); }
455 samples.clear(); rep.in_mem_rep.clear(); unknown_codes.clear();
462 const char *get_name() {
return name.c_str(); }
464 void write_features_mat(
const string &feat_mat) { model.write_feature_matrix(feat_mat); }
465 void add_features_mat(
const string &feat_mat) { model.write_feature_matrix(feat_mat,
false,
true); }
// Reports the channel layout of the signal named sig through the output parameters.
// n_time_channels / n_val_channels receive the counts stored in rep.sigs.Sid2Info for the signal.
// is_categ is set to the address of the first element of is_categorical_per_val_channel, i.e. it
// points into storage held by the repository's signal info.
// NOTE(review): handling of an unknown signal (lines between the lookup and the assignments) is not visible here.
467 void get_signal_structure(
string &sig,
int &n_time_channels,
int &n_val_channels,
int* &is_categ)
469 int sid = this->rep.sigs.sid(sig);
475 n_time_channels = this->rep.sigs.Sid2Info[sid].n_time_channels;
476 n_val_channels = this->rep.sigs.Sid2Info[sid].n_val_channels;
477 is_categ = &(this->rep.sigs.Sid2Info[sid].is_categorical_per_val_channel[0]);
// Turns verbose model output on or off.
// Acts only when the requested state differs from the current one (verbosity > 0 XOR flag);
// then sets model.verbosity to 0 or 1 from flag.
// The visible lines also install a timestamp/level/section log format for the LOG_MED_MODEL and
// LOG_MEDALGO sections; the conditions around those calls are outside the lines visible here.
481 void model_apply_verbose(
bool flag) {
482 if ((model.verbosity > 0) ^ flag) {
483 model.verbosity = int(flag);
485 string full_log_format =
"$timestamp\t$level\t$section\t%s";
488 global_logger.
init_format(LOG_MED_MODEL, full_log_format);
489 global_logger.
init_format(LOG_MEDALGO, full_log_format);
// Returns a copy of the model's version_info string.
494 string model_version_info()
const {
495 return model.version_info;
// Fills sigs via model.get_required_signal_names() and res_categ via
// model.get_required_signal_categories().
// NOTE(review): res_categ presumably maps a signal name to its category values — confirm.
498 void get_model_signals_info(vector<string> &sigs,
499 unordered_map<
string, vector<string>> &res_categ)
const {
500 model.get_required_signal_names(sigs);
501 model.get_required_signal_categories(res_categ);
505 out = explainer_params;
// Produces explainer output options into opts.
// The visible lines gather PostProcessor pointers into a flat list and, when an explainer
// (explainer_m) was found, iterate over its processing.groupNames. How explainer_m is located
// and what is written to opts are outside the lines visible here.
508 void get_explainer_output_options(vector<string> &opts) {
509 vector<const PostProcessor *> flat;
515 flat.push_back(m_pp);
524 if (explainer_m != NULL) {
525 for (
const string &grp : explainer_m->
processing.groupNames)
// Stores base_dir in explainer_params.base_dir. Handling of the first parameter is outside the
// lines visible here.
// NOTE(review): the first parameter reads "¶ms", which looks like a mis-decoded "&params"
// (the "&para" prefix rendered as a pilcrow); restore it to `const string &params` if the file really contains this text.
532 void set_explainer_params(
const string ¶ms,
const string &base_dir) {
533 explainer_params.base_dir = base_dir;
// Configures threshold lookup from a bootstrap results file.
// init_string: key/value settings parsed by MedSerialize::init_map_from_string(). Keys handled in
//              the visible lines: "bootstrap_file_path", "rename_cohorts" ('#'-separated entries,
//              each "source|target") and "default_threshold". An "unknown param" error is raised for
//              other keys (the branch line itself is not visible here).
// base_dir:    prefixed to bootstrap_file_path when that path does not start with '/' or a backslash
//              and base_dir is not empty.
// Errors are raised with MTHROW_AND_ERR when parsing fails, bootstrap_file_path or default_threshold
// is missing, a rename entry does not have exactly two parts, or default_threshold cannot be resolved.
537 void set_threshold_leaflet(
const string &init_string,
const string &base_dir) {
538 map<string, string> params;
539 if (MedSerialize::init_map_from_string(init_string, params) < 0)
540 MTHROW_AND_ERR(
"Error Init from String %s\n", init_string.c_str());
541 string bt_file_path =
"";
542 map<string, string> rename_cohorts;
543 for (
const auto &it : params)
545 if (it.first ==
"bootstrap_file_path")
546 bt_file_path = it.second;
547 else if (it.first ==
"rename_cohorts") {
548 vector<string> tokens;
549 boost::split(tokens, it.second, boost::is_any_of(
"#"));
550 for (
const string &tk : tokens)
552 vector<string> src_target;
553 boost::split(src_target, tk, boost::is_any_of(
"|"));
554 if (src_target.size() != 2)
555 MTHROW_AND_ERR(
"Error expecting 2 tokens, recieved \"%s\"\n", tk.c_str());
// NOTE(review): both calls trim src_target[1]; src_target[0] (the map key below) is never
// trimmed. The first call was probably meant to be mes_trim(src_target[0]) — confirm and fix.
556 mes_trim(src_target[1]);
557 mes_trim(src_target[1]);
558 rename_cohorts[src_target[0]] = src_target[1];
561 else if (it.first ==
"default_threshold") {
562 default_threshold = it.second;
563 mes_trim(default_threshold);
566 MTHROW_AND_ERR(
"Error unknown param %s\n", it.first.c_str());
568 if (bt_file_path.empty())
569 MTHROW_AND_ERR(
"Error must provide bootstrap_file_path in THRESHOLD_LEAFLET\n");
// A path that does not begin with a path separator is treated as relative to base_dir.
571 if (bt_file_path !=
"" && bt_file_path[0] !=
'/' && bt_file_path[0] !=
'\\' && !base_dir.empty())
572 bt_file_path = base_dir + path_sep() + bt_file_path;
574 if (default_threshold.empty())
575 MTHROW_AND_ERR(
"Error - must have default_threshold\n");
577 map<string, map<string, float>> mbr_before;
578 read_pivot_bootstrap_results(bt_file_path, mbr_before);
// Copy the loaded results into mbr, applying cohort renames and keeping only score-cutoff means.
581 for (
auto &it : mbr_before)
583 string cohort = it.first;
584 if (rename_cohorts.find(cohort) != rename_cohorts.end())
585 cohort = rename_cohorts[cohort];
587 map<string, float> &filt = mbr[cohort];
588 for (
const auto &jt : it.second)
589 if (boost::starts_with(jt.first,
"SCORE@") && boost::ends_with(jt.first,
"_Mean") && jt.second != MED_MAT_MISSING_VALUE)
// Strip the 6-character "SCORE@" prefix and the 5-character "_Mean" suffix (11 in total) from the key.
590 filt[jt.first.substr(6, jt.first.length() - 11)] = jt.second;
// Validate default_threshold: a non-empty error string from fetch_threshold() means it could not be resolved.
595 fetch_threshold(default_threshold, err_c);
596 if (!err_c.empty()) {
// Log every available option before raising the error.
598 fetch_all_thresholds(opts);
599 for (
const string & s : opts)
600 MLOG(
"Option: \"%s\"\n", s.c_str());
601 MTHROW_AND_ERR(
"Error default_threshold is invalid - please select one in format as COHORT$MEASURE_NUMERIC\n");
// NOTE(review): the body is outside the lines visible here; presumably reports whether threshold
// settings were loaded by set_threshold_leaflet() — confirm.
605 bool has_threshold_settings()
const {
609 string get_default_threshold()
const {
return default_threshold; }
// Enumerates every stored threshold as "<cohort>$<measure key>" by walking mbr and its inner maps.
// NOTE(review): what is done with each composed string is outside the lines visible here;
// presumably it is appended to opts — confirm.
611 void fetch_all_thresholds(vector<string> &opts)
const {
612 for (
const auto &it : mbr)
614 for (
const auto &jt : it.second)
616 string res = it.first +
"$" + jt.first;
622 float fetch_threshold(
const string &threshold,
string &err_msg)
const {
623 vector<string> tokens;
625 boost::split(tokens, threshold, boost::is_any_of(
"$"));
626 if (tokens.size() != 2) {
627 err_msg =
"(" + to_string(AM_THRESHOLD_ERROR_NON_FATAL) +
")Error flag_threshold should contain $";
628 return MED_MAT_MISSING_VALUE;
632 if (mbr.find(tokens[0]) == mbr.end()) {
633 err_msg =
"(" + to_string(AM_THRESHOLD_ERROR_NON_FATAL) +
")Error flag_threshold doesn't contain threshold settings for " + tokens[0];
634 return MED_MAT_MISSING_VALUE;
636 const map<string, float> &fnd = mbr.at(tokens[0]);
638 vector<string> meas_tokens;
639 boost::split(meas_tokens, tokens[1], boost::is_any_of(
"_"));
640 if (meas_tokens.size() != 2) {
641 err_msg =
"(" + to_string(AM_THRESHOLD_ERROR_NON_FATAL) +
")Error flag_threshold doesn't should contain _ in the cutoff setting part";
642 return MED_MAT_MISSING_VALUE;
646 num_val = stof(meas_tokens[1]);
649 err_msg =
"(" + to_string(AM_THRESHOLD_ERROR_NON_FATAL) +
")Error flag_threshold search cutoff isn't numeric";
650 return MED_MAT_MISSING_VALUE;
653 float res = MED_MAT_MISSING_VALUE;
654 for (
const auto &jt : fnd)
656 string cand = jt.first;
657 vector<string> cand_tokens;
658 boost::split(cand_tokens, cand, boost::is_any_of(
"_"));
659 if (cand_tokens.size() != 2)
661 if (cand_tokens[0] != meas_tokens[0])
666 num_val_cmp = stof(cand_tokens[1]);
671 if (abs(num_val_cmp - num_val) <= 1e-6) {
677 if (res == MED_MAT_MISSING_VALUE)
678 err_msg =
"(" + to_string(AM_THRESHOLD_ERROR_NON_FATAL) +
")Error flag_threshold doesn't contain threshold for " + tokens[1];