// Per-pid table (indexed by pid in get_unknown_codes(pid)) of string -> set<string>;
// cleared by data_load_init() and clear(). Presumably signal name -> unrecognized
// code values encountered while loading — TODO confirm against the writer.
162 unordered_map<int, unordered_map<string, unordered_set<string>>> unknown_codes;
// Threshold table filled by set_threshold_leaflet(): cohort -> measure -> value,
// taken from bootstrap entries named "SCORE@<measure>_Mean".
165 map<string, map<string, float>> mbr;
// Default threshold spec in "COHORT$MEASURE_NUMERIC" form, set from the
// THRESHOLD_LEAFLET "default_threshold" parameter.
166 string default_threshold =
"";
// Set to true by init_model_for_apply(); reported by model_initiated().
173 bool model_init_done =
false;
// Guards init_model_for_rep() so that its one-time setup runs only once.
174 bool model_rep_done =
false;
183 void set_name(
const char *_name) { name = string(_name); }
184 void set_model_end_stage(
int _model_end_stage) { model_end_stage = _model_end_stage; };
187 int init_rep_config(
const char *config_fname) {
188 rep.switch_to_in_mem_mode();
189 if (rep.MedRepository::init(
string(config_fname)) < 0)
return -1;
195 int set_time_unit_env(
int time_unit) {
196 global_default_time_unit = time_unit;
201 void set_pids(
int *_pids,
int npids) { pids.clear(); pids.assign(_pids, _pids + npids); }
204 int init_rep_with_file_data(
const char *_rep_fname) {
206 rep_fname = string(_rep_fname);
207 vector<string> sigs = {};
208 return (rep.read_all(rep_fname, pids, sigs));
212 int init_model_from_file(
const char *_model_fname) { model.clear(); model.verbosity = 0;
return (model.read_from_file(
string(_model_fname))); }
213 int model_check_required_signals() {
215 vector<string> req_sigs;
216 model.get_required_signal_names(req_sigs);
217 for (
const auto& s : req_sigs)
218 if (0 == rep.sigs.Name2Sid.count(s)) {
220 fprintf(stderr,
"ERROR: AM model requires signal '%s' but signal does not exist in AM repository .signals file\n", s.c_str());
226 int init_model_for_apply() {
227 global_logger.log(
LOG_APP, LOG_DEF_LEVEL,
"Init MedModel for Apply\n");
228 model_init_done =
true;
232 void fit_model_to_rep() {
233 model.fit_for_repository(rep);
236 int init_model_for_rep() {
238 if (!model_rep_done) {
239 model_rep_done =
true;
245 unordered_map<string, unordered_set<string>> *get_unknown_codes(
int pid) {
246 return &unknown_codes[pid];
249 int init_samples(
int *pids,
int *times,
int n_samples) { clear_samples();
int rc = insert_samples(pids, times, n_samples); samples.normalize();
return rc; }
250 int init_samples(
int pid,
int time) {
return init_samples(&pid, &time, 1); }
255 void add_json_dict(json &js) { rep.dict.add_json(js); }
257 bool model_initiated() {
return model_init_done; }
264 int data_load_init() { unknown_codes.clear(); rep.switch_to_in_mem_mode();
return 0; }
// Loads n_elems elements of signal `sig_name` for patient `pid` into the in-memory
// repository. The time/value array lengths are n_elems times the signal's time/value
// channel counts. A null `times` or `vals` pointer contributes no entries of that kind.
// Returns -1 when the looked-up signal id is negative, otherwise the status of
// rep.in_mem_rep.insertData().
267 int data_load_pid_sig(
int pid,
const char *sig_name,
int *times,
float *vals,
int n_elems) {
// NOTE(review): if Name2Sid is a std::map/unordered_map (it supports count()), operator[]
// default-inserts 0 for an unknown name. In that case `sid < 0` never fires and the map
// is mutated. Consider a non-inserting lookup (find(), or rep.sigs.sid() as used by
// get_signal_structure) — TODO confirm what id 0 means.
268 int sid = rep.sigs.Name2Sid[string(sig_name)];
269 if (sid < 0)
return -1;
270 int n_times = n_elems * rep.sigs.Sid2Info[sid].n_time_channels, n_vals = n_elems * rep.sigs.Sid2Info[sid].n_val_channels;
271 if (times == NULL) n_times = 0;
272 if (vals == NULL) n_vals = 0;
273 return rep.in_mem_rep.insertData(pid, sid, times, vals, n_times, n_vals);
// Buffered variant: passes explicit n_times / n_vals counts for signal `sig_name` of
// patient `pid` to rep.in_mem_rep.insertData_to_buffer(), writing into `*data`.
// Returns -1 when the looked-up signal id is negative.
277 int data_load_pid_sig(
int pid,
const char *sig_name,
int *times,
int n_times,
float *vals,
int n_vals,
278 map<pair<int, int>, pair<
int, vector<char>>> *data = NULL) {
// NOTE(review): Name2Sid operator[] default-inserts unknown names (typically as 0) when it
// is a std map type, so `sid < 0` may never catch a missing signal — consider find().
279 int sid = rep.sigs.Name2Sid[string(sig_name)];
280 if (sid < 0)
return -1;
// NOTE(review): if this assignment is unconditional, a caller-supplied `data` buffer is
// always ignored and the parameter is pointless — confirm it is guarded by `data == NULL`.
282 data = &rep.in_mem_rep.data;
283 return rep.in_mem_rep.insertData_to_buffer(pid, sid, times, vals, n_times, n_vals, rep.sigs, *data);
287 int data_load_pid_sig(
int pid,
const char *sig_name,
int *times,
float *vals) {
return data_load_pid_sig(pid, sig_name, times, vals, 1); }
290 int data_load_end() {
return rep.in_mem_rep.sortData(); }
292 void get_rep_signals(unordered_set<string> &sigs)
294 for (
auto &sig : rep.sigs.signals_names)
306 void clear_samples() { samples.clear(); }
309 int insert_samples(
int *pids,
int *times,
int n_samples) {
310 for (
int i = 0; i < n_samples; i++)
311 samples.insertRec(pids[i], times[i]);
315 int insert_sample(
int pid,
int time) {
return insert_samples(&pid, &time, 1); }
320 int normalize_samples() { samples.normalize();
return 0; }
322 MedSamples *get_samples_ptr() {
return &samples; }
328 int get_preds(
int *_pids,
int *times,
float *preds,
int n_samples) {
331 init_samples(_pids, times, n_samples);
333 return get_raw_preds(_pids, times, preds);
336 int get_preds(
int *_pids,
int *times,
float *preds,
int n_samples,
337 const vector<Effected_Field> &requested_fields,
MedPidRepository *_rep=NULL) {
340 init_samples(_pids, times, n_samples);
343 return get_raw_preds(_pids, times, preds, requested_fields, _rep);
346 int get_raw_preds(
int *_pids,
int *times,
float *preds,
353 if (!samples.idSamples.empty())
354 model.no_init_apply_partial(*_rep, samples, requested_fields);
357 fprintf(stderr,
"Caught an exception in no_init_apply_partial\n");
364 for (
auto& idSample : samples.idSamples)
365 for (
auto& sample : idSample.samples) {
366 _pids[j] = sample.id;
367 times[j] = sample.time;
368 preds[j] = sample.prediction.size() > 0 ? sample.prediction[0] : (float)AM_UNDEFINED_VALUE;
375 catch (
int &exception_code) {
376 fprintf(stderr,
"Caught an exception code: %d\n", exception_code);
380 fprintf(stderr,
"Caught Something...\n");
386 int get_raw_preds(
int *_pids,
int *times,
float *preds) {
392 if (!samples.idSamples.empty())
394 fprintf(stderr,
"ERROR: MedAlgoMarkerInternal::get_preds FAILED.");
399 fprintf(stderr,
"Caught an exception in no_init_apply\n");
406 for (
auto& idSample : samples.idSamples)
407 for (
auto& sample : idSample.samples) {
408 _pids[j] = sample.id;
409 times[j] = sample.time;
410 preds[j] = sample.prediction.size() > 0 ? sample.prediction[0] : (float)AM_UNDEFINED_VALUE;
417 catch (
int &exception_code) {
418 fprintf(stderr,
"Caught an exception code: %d\n", exception_code);
422 fprintf(stderr,
"Caught Something...\n");
427 int get_preds(
MedSamples &_samples,
float *preds) {
433 fprintf(stderr,
"ERROR: MedAlgoMarkerInternal::get_preds FAILED.");
439 for (
auto& idSample : samples.idSamples)
440 for (
auto& sample : idSample.samples) {
441 preds[j++] = sample.prediction[0];
446 int get_pred(
int *pid,
int *time,
float *pred) {
return get_preds(pid, time, pred, 1); }
452 void clear() { unknown_codes.clear(); pids.clear(); model.clear(); samples.clear(); rep.in_mem_rep.clear(); rep.clear(); }
456 samples.clear(); rep.in_mem_rep.clear(); unknown_codes.clear();
463 const char *get_name() {
return name.c_str(); }
465 void write_features_mat(
const string &feat_mat) { model.write_feature_matrix(feat_mat); }
466 void add_features_mat(
const string &feat_mat) { model.write_feature_matrix(feat_mat,
false,
true); }
468 void get_signal_structure(
string &sig,
int &n_time_channels,
int &n_val_channels,
int* &is_categ)
470 int sid = this->rep.sigs.sid(sig);
476 n_time_channels = this->rep.sigs.Sid2Info[sid].n_time_channels;
477 n_val_channels = this->rep.sigs.Sid2Info[sid].n_val_channels;
478 is_categ = &(this->rep.sigs.Sid2Info[sid].is_categorical_per_val_channel[0]);
482 void model_apply_verbose(
bool flag) {
483 if ((model.verbosity > 0) ^ flag) {
484 model.verbosity = int(flag);
486 string full_log_format =
"$timestamp\t$level\t$section\t%s";
489 global_logger.
init_format(LOG_MED_MODEL, full_log_format);
490 global_logger.
init_format(LOG_MEDALGO, full_log_format);
495 string model_version_info()
const {
496 return model.version_info;
499 void get_model_signals_info(vector<string> &sigs,
500 unordered_map<
string, vector<string>> &res_categ)
const {
501 model.get_required_signal_names(sigs);
502 model.get_required_signal_categories(res_categ);
506 out = explainer_params;
509 void get_explainer_output_options(vector<string> &opts) {
510 vector<const PostProcessor *> flat;
516 flat.push_back(m_pp);
525 if (explainer_m != NULL) {
526 for (
const string &grp : explainer_m->
processing.groupNames)
533 void set_explainer_params(
const string ¶ms,
const string &base_dir) {
534 explainer_params.base_dir = base_dir;
// Parses a THRESHOLD_LEAFLET init string. Recognized keys are bootstrap_file_path,
// rename_cohorts ("src|target#src|target...") and default_threshold; an unknown key throws.
// A bootstrap path not starting with '/' or '\\' is prefixed with base_dir.
// Reads the pivot bootstrap results and fills `mbr` with cohort -> measure -> value,
// taken from "SCORE@<measure>_Mean" entries (cohorts optionally renamed).
// Throws if default_threshold is missing or fetch_threshold() rejects it
// (after logging every available option).
538 void set_threshold_leaflet(
const string &init_string,
const string &base_dir) {
539 map<string, string> params;
540 if (MedSerialize::init_map_from_string(init_string, params) < 0)
541 MTHROW_AND_ERR(
"Error Init from String %s\n", init_string.c_str());
542 string bt_file_path =
"";
543 map<string, string> rename_cohorts;
544 for (
const auto &it : params)
546 if (it.first ==
"bootstrap_file_path")
547 bt_file_path = it.second;
548 else if (it.first ==
"rename_cohorts") {
549 vector<string> tokens;
550 boost::split(tokens, it.second, boost::is_any_of(
"#"));
551 for (
const string &tk : tokens)
553 vector<string> src_target;
554 boost::split(src_target, tk, boost::is_any_of(
"|"));
555 if (src_target.size() != 2)
556 MTHROW_AND_ERR(
"Error expecting 2 tokens, recieved \"%s\"\n", tk.c_str());
// NOTE(review): BUG — src_target[1] is trimmed twice and src_target[0] never is. If
// mes_trim strips surrounding whitespace, a pair like "old |new" keeps "old " as the key
// and will not match the cohort name below. The first call was most likely meant to be
// mes_trim(src_target[0]).
557 mes_trim(src_target[1]);
558 mes_trim(src_target[1]);
559 rename_cohorts[src_target[0]] = src_target[1];
562 else if (it.first ==
"default_threshold") {
563 default_threshold = it.second;
564 mes_trim(default_threshold);
567 MTHROW_AND_ERR(
"Error unknown param %s\n", it.first.c_str());
569 if (bt_file_path.empty())
570 MTHROW_AND_ERR(
"Error must provide bootstrap_file_path in THRESHOLD_LEAFLET\n");
// NOTE(review): if MTHROW_AND_ERR throws, `bt_file_path != ""` is always true here (redundant).
572 if (bt_file_path !=
"" && bt_file_path[0] !=
'/' && bt_file_path[0] !=
'\\' && !base_dir.empty())
573 bt_file_path = base_dir + path_sep() + bt_file_path;
575 if (default_threshold.empty())
576 MTHROW_AND_ERR(
"Error - must have default_threshold\n");
578 map<string, map<string, float>> mbr_before;
579 read_pivot_bootstrap_results(bt_file_path, mbr_before);
582 for (
auto &it : mbr_before)
584 string cohort = it.first;
585 if (rename_cohorts.find(cohort) != rename_cohorts.end())
586 cohort = rename_cohorts[cohort];
588 map<string, float> &filt = mbr[cohort];
589 for (
const auto &jt : it.second)
// Keep only "SCORE@<measure>_Mean" entries with a real value. substr strips the
// 6-char "SCORE@" prefix and the 5-char "_Mean" suffix (6 + 5 = 11).
590 if (boost::starts_with(jt.first,
"SCORE@") && boost::ends_with(jt.first,
"_Mean") && jt.second != MED_MAT_MISSING_VALUE)
591 filt[jt.first.substr(6, jt.first.length() - 11)] = jt.second;
// Validate default_threshold against the table just built; list valid options on failure.
596 fetch_threshold(default_threshold, err_c);
597 if (!err_c.empty()) {
599 fetch_all_thresholds(opts);
600 for (
const string & s : opts)
601 MLOG(
"Option: \"%s\"\n", s.c_str());
602 MTHROW_AND_ERR(
"Error default_threshold is invalid - please select one in format as COHORT$MEASURE_NUMERIC\n");
606 bool has_threshold_settings()
const {
610 string get_default_threshold()
const {
return default_threshold; }
612 void fetch_all_thresholds(vector<string> &opts)
const {
613 for (
const auto &it : mbr)
615 for (
const auto &jt : it.second)
617 string res = it.first +
"$" + jt.first;
623 float fetch_threshold(
const string &threshold,
string &err_msg)
const {
624 vector<string> tokens;
626 boost::split(tokens, threshold, boost::is_any_of(
"$"));
627 if (tokens.size() != 2) {
628 err_msg =
"(" + to_string(AM_THRESHOLD_ERROR_NON_FATAL) +
")Error flag_threshold should contain $";
629 return MED_MAT_MISSING_VALUE;
633 if (mbr.find(tokens[0]) == mbr.end()) {
634 err_msg =
"(" + to_string(AM_THRESHOLD_ERROR_NON_FATAL) +
")Error flag_threshold doesn't contain threshold settings for " + tokens[0];
635 return MED_MAT_MISSING_VALUE;
637 const map<string, float> &fnd = mbr.at(tokens[0]);
639 vector<string> meas_tokens;
640 boost::split(meas_tokens, tokens[1], boost::is_any_of(
"_"));
641 if (meas_tokens.size() != 2) {
642 err_msg =
"(" + to_string(AM_THRESHOLD_ERROR_NON_FATAL) +
")Error flag_threshold doesn't should contain _ in the cutoff setting part";
643 return MED_MAT_MISSING_VALUE;
647 num_val = stof(meas_tokens[1]);
650 err_msg =
"(" + to_string(AM_THRESHOLD_ERROR_NON_FATAL) +
")Error flag_threshold search cutoff isn't numeric";
651 return MED_MAT_MISSING_VALUE;
654 float res = MED_MAT_MISSING_VALUE;
655 for (
const auto &jt : fnd)
657 string cand = jt.first;
658 vector<string> cand_tokens;
659 boost::split(cand_tokens, cand, boost::is_any_of(
"_"));
660 if (cand_tokens.size() != 2)
662 if (cand_tokens[0] != meas_tokens[0])
667 num_val_cmp = stof(cand_tokens[1]);
672 if (abs(num_val_cmp - num_val) <= 1e-6) {
678 if (res == MED_MAT_MISSING_VALUE)
679 err_msg =
"(" + to_string(AM_THRESHOLD_ERROR_NON_FATAL) +
")Error flag_threshold doesn't contain threshold for " + tokens[1];