6#ifndef DMLC_PARAMETER_H_
7#define DMLC_PARAMETER_H_
40 explicit ParamError(
const std::string &msg)
50template<
typename ValueType>
51inline ValueType GetEnv(
const char *key,
52 ValueType default_value);
59template<
typename ValueType>
60inline void SetEnv(
const char *key,
68class FieldAccessEntry;
70template<
typename DType>
73template<
typename PType>
74struct ParamManagerSingleton;
89struct ParamFieldInfo {
98 std::string type_info_str;
100 std::string description;
127template<
typename PType>
140 template<
typename Container>
141 inline void Init(
const Container &kwargs,
142 parameter::ParamInitOption option = parameter::kAllowHidden) {
143 PType::__MANAGER__()->RunInit(
static_cast<PType*
>(
this),
144 kwargs.begin(), kwargs.end(),
157 template<
typename Container>
158 inline std::vector<std::pair<std::string, std::string> >
159 InitAllowUnknown(
const Container &kwargs) {
160 std::vector<std::pair<std::string, std::string> > unknown;
161 PType::__MANAGER__()->RunInit(
static_cast<PType*
>(
this),
162 kwargs.begin(), kwargs.end(),
163 &unknown, parameter::kAllowUnknown);
178 template <
typename Container>
179 std::vector<std::pair<std::string, std::string> >
180 UpdateAllowUnknown(Container
const& kwargs) {
181 std::vector<std::pair<std::string, std::string> > unknown;
182 PType::__MANAGER__()->RunUpdate(
static_cast<PType *
>(
this), kwargs.begin(),
183 kwargs.end(), parameter::kAllowUnknown,
194 template<
typename Container>
195 inline void UpdateDict(Container *dict)
const {
196 PType::__MANAGER__()->UpdateDict(this->head(), dict);
202 inline std::map<std::string, std::string> __DICT__()
const {
203 std::vector<std::pair<std::string, std::string> > vec
204 = PType::__MANAGER__()->GetDict(this->head());
205 return std::map<std::string, std::string>(vec.begin(), vec.end());
212 writer->
Write(this->__DICT__());
220 std::map<std::string, std::string> kwargs;
221 reader->
Read(&kwargs);
228 inline static std::vector<ParamFieldInfo> __FIELDS__() {
229 return PType::__MANAGER__()->GetFieldInfo();
235 inline static std::string __DOC__() {
236 std::ostringstream os;
237 PType::__MANAGER__()->PrintDocString(os);
248 template<
typename DType>
249 inline parameter::FieldEntry<DType>& DECLARE(
250 parameter::ParamManagerSingleton<PType> *manager,
251 const std::string &key, DType &ref) {
252 parameter::FieldEntry<DType> *e =
253 new parameter::FieldEntry<DType>();
254 e->Init(key, this->head(), ref);
255 manager->manager.AddEntry(key, e);
261 inline PType *head()
const {
262 return static_cast<PType*
>(
const_cast<Parameter<PType>*
>(
this));
286#define DMLC_DECLARE_PARAMETER(PType) \
287 static ::dmlc::parameter::ParamManager *__MANAGER__(); \
288 inline void __DECLARE__(::dmlc::parameter::ParamManagerSingleton<PType> *manager) \
294#define DMLC_DECLARE_FIELD(FieldName) this->DECLARE(manager, #FieldName, FieldName)
301#define DMLC_DECLARE_ALIAS(FieldName, AliasName) manager->manager.AddAlias(#FieldName, #AliasName)
311#define DMLC_REGISTER_PARAMETER(PType) \
312 ::dmlc::parameter::ParamManager *PType::__MANAGER__() { \
313 static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \
314 return &inst.manager; \
316 static DMLC_ATTRIBUTE_UNUSED ::dmlc::parameter::ParamManager& \
317 __make__ ## PType ## ParamManager__ = \
318 (*PType::__MANAGER__()) \
332class FieldAccessEntry {
335 : has_default_(false), index_(0) {}
337 virtual ~FieldAccessEntry() {}
343 virtual void SetDefault(
void *head)
const = 0;
349 virtual void Set(
void *head,
const std::string &value)
const = 0;
351 virtual void Check(
void * )
const {}
356 virtual std::string GetStringValue(
void *head)
const = 0;
361 virtual ParamFieldInfo GetFieldInfo()
const = 0;
373 std::string description_;
377 char* GetRawPtr(
void* head)
const {
378 return reinterpret_cast<char*
>(head) + offset_;
384 virtual void PrintDefaultValueString(std::ostream &os)
const = 0;
386 friend class ParamManager;
397 for (
size_t i = 0; i < entry_.size(); ++i) {
406 inline FieldAccessEntry *Find(
const std::string &key)
const {
407 std::map<std::string, FieldAccessEntry*>::const_iterator it =
408 entry_map_.find(key);
409 if (it == entry_map_.end())
return NULL;
422 template<
typename RandomAccessIterator>
423 inline void RunInit(
void *head,
424 RandomAccessIterator begin,
425 RandomAccessIterator end,
426 std::vector<std::pair<std::string, std::string> > *unknown_args,
427 parameter::ParamInitOption option)
const {
428 std::set<FieldAccessEntry*> selected_args;
429 RunUpdate(head, begin, end, option, unknown_args, &selected_args);
430 for (
auto const& kv : entry_map_) {
431 if (selected_args.find(kv.second) == selected_args.cend()) {
432 kv.second->SetDefault(head);
435 for (std::map<std::string, FieldAccessEntry*>::const_iterator it = entry_map_.begin();
436 it != entry_map_.end(); ++it) {
437 if (selected_args.count(it->second) == 0) {
438 it->second->SetDefault(head);
454 template <
typename RandomAccessIterator>
455 void RunUpdate(
void *head,
456 RandomAccessIterator begin,
457 RandomAccessIterator end,
458 parameter::ParamInitOption option,
459 std::vector<std::pair<std::string, std::string> > *unknown_args,
460 std::set<FieldAccessEntry*>* selected_args =
nullptr)
const {
461 for (RandomAccessIterator it = begin; it != end; ++it) {
462 if (FieldAccessEntry *e = Find(it->first)) {
463 e->Set(head, it->second);
466 selected_args->insert(e);
469 if (unknown_args != NULL) {
470 unknown_args->push_back(*it);
472 if (option != parameter::kAllowUnknown) {
473 if (option == parameter::kAllowHidden &&
474 it->first.length() > 4 &&
475 it->first.find(
"__") == 0 &&
476 it->first.rfind(
"__") == it->first.length()-2) {
479 std::ostringstream os;
480 os <<
"Cannot find argument \'" << it->first <<
"\', Possible Arguments:\n";
481 os <<
"----------------\n";
483 throw dmlc::ParamError(os.str());
495 inline void AddEntry(
const std::string &key, FieldAccessEntry *e) {
496 e->index_ = entry_.size();
498 if (entry_map_.count(key) != 0) {
499 LOG(FATAL) <<
"key " << key <<
" has already been registered in " << name_;
510 inline void AddAlias(
const std::string& field,
const std::string& alias) {
511 if (entry_map_.count(field) == 0) {
512 LOG(FATAL) <<
"key " << field <<
" has not been registered in " << name_;
514 if (entry_map_.count(alias) != 0) {
515 LOG(FATAL) <<
"Alias " << alias <<
" has already been registered in " << name_;
517 entry_map_[alias] = entry_map_[field];
523 inline void set_name(
const std::string &name) {
530 inline std::vector<ParamFieldInfo> GetFieldInfo()
const {
531 std::vector<ParamFieldInfo> ret(entry_.size());
532 for (
size_t i = 0; i < entry_.size(); ++i) {
533 ret[i] = entry_[i]->GetFieldInfo();
541 inline void PrintDocString(std::ostream &os)
const {
542 for (
size_t i = 0; i < entry_.size(); ++i) {
543 ParamFieldInfo info = entry_[i]->GetFieldInfo();
544 os << info.name <<
" : " << info.type_info_str <<
'\n';
545 if (info.description.length() != 0) {
546 os <<
" " << info.description <<
'\n';
556 inline std::vector<std::pair<std::string, std::string> > GetDict(
void * head)
const {
557 std::vector<std::pair<std::string, std::string> > ret;
558 for (std::map<std::string, FieldAccessEntry*>::const_iterator
559 it = entry_map_.begin(); it != entry_map_.end(); ++it) {
560 ret.push_back(std::make_pair(it->first, it->second->GetStringValue(head)));
570 template<
typename Container>
571 inline void UpdateDict(
void * head, Container* dict)
const {
572 for (std::map<std::string, FieldAccessEntry*>::const_iterator
573 it = entry_map_.begin(); it != entry_map_.end(); ++it) {
574 (*dict)[it->first] = it->second->GetStringValue(head);
582 std::vector<FieldAccessEntry*> entry_;
584 std::map<std::string, FieldAccessEntry*> entry_map_;
591template<
typename PType>
592struct ParamManagerSingleton {
593 ParamManager manager;
594 explicit ParamManagerSingleton(
const std::string ¶m_name) {
596 manager.set_name(param_name);
597 param.__DECLARE__(
this);
603template<
typename TEntry,
typename DType>
604class FieldEntryBase :
public FieldAccessEntry {
607 typedef TEntry EntryType;
609 void Set(
void *head,
const std::string &value)
const override {
610 std::istringstream is(value);
611 is >> this->Get(head);
619 is.setstate(std::ios::failbit);
break;
625 std::ostringstream os;
626 os <<
"Invalid Parameter format for " << key_
627 <<
" expect " << type_ <<
" but value=\'" << value<<
'\'';
628 throw dmlc::ParamError(os.str());
632 std::string GetStringValue(
void *head)
const override {
633 std::ostringstream os;
634 PrintValue(os, this->Get(head));
637 ParamFieldInfo GetFieldInfo()
const override {
639 std::ostringstream os;
644 os <<
',' <<
" optional, default=";
645 PrintDefaultValueString(os);
649 info.type_info_str = os.str();
650 info.description = description_;
654 void SetDefault(
void *head)
const override {
656 std::ostringstream os;
657 os <<
"Required parameter " << key_
658 <<
" of " << type_ <<
" is not presented";
659 throw dmlc::ParamError(os.str());
661 this->Get(head) = default_value_;
665 inline TEntry &self() {
666 return *(
static_cast<TEntry*
>(
this));
669 inline TEntry &set_default(
const DType &default_value) {
670 default_value_ = default_value;
676 inline TEntry &describe(
const std::string &description) {
677 description_ = description;
682 inline void Init(
const std::string &key,
683 void *head, DType &ref) {
685 if (this->type_.length() == 0) {
686 this->type_ = dmlc::type_name<DType>();
688 this->offset_ = ((
char*)&ref) - ((
char*)head);
693 virtual void PrintValue(std::ostream &os, DType value)
const {
696 void PrintDefaultValueString(std::ostream &os)
const override {
697 PrintValue(os, default_value_);
702 inline DType &Get(
void *head)
const {
703 return *(DType*)this->GetRawPtr(head);
706 DType default_value_;
710template<
typename TEntry,
typename DType>
711class FieldEntryNumeric
712 :
public FieldEntryBase<TEntry, DType> {
715 : has_begin_(false), has_end_(false) {}
717 virtual TEntry &set_range(DType begin, DType end) {
718 begin_ = begin; end_ = end;
719 has_begin_ =
true; has_end_ =
true;
723 virtual TEntry &set_lower_bound(DType begin) {
724 begin_ = begin; has_begin_ =
true;
728 virtual void Check(
void *head)
const {
729 FieldEntryBase<TEntry, DType>::Check(head);
730 DType v = this->Get(head);
731 if (has_begin_ && has_end_) {
732 if (v < begin_ || v > end_) {
733 std::ostringstream os;
734 os <<
"value " << v <<
" for Parameter " << this->key_
735 <<
" exceed bound [" << begin_ <<
',' << end_ <<
']' <<
'\n';
736 os << this->key_ <<
": " << this->description_;
737 throw dmlc::ParamError(os.str());
739 }
else if (has_begin_ && v < begin_) {
740 std::ostringstream os;
741 os <<
"value " << v <<
" for Parameter " << this->key_
742 <<
" should be greater equal to " << begin_ <<
'\n';
743 os << this->key_ <<
": " << this->description_;
744 throw dmlc::ParamError(os.str());
745 }
else if (has_end_ && v > end_) {
746 std::ostringstream os;
747 os <<
"value " << v <<
" for Parameter " << this->key_
748 <<
" should be smaller equal to " << end_ <<
'\n';
749 os << this->key_ <<
": " << this->description_;
750 throw dmlc::ParamError(os.str());
756 bool has_begin_, has_end_;
766template<
typename DType>
768 public IfThenElseType<dmlc::is_arithmetic<DType>::value,
769 FieldEntryNumeric<FieldEntry<DType>, DType>,
770 FieldEntryBase<FieldEntry<DType>, DType> >::Type {
776 :
public FieldEntryNumeric<FieldEntry<int>, int> {
779 FieldEntry() : is_enum_(false) {}
781 typedef FieldEntryNumeric<FieldEntry<int>,
int> Parent;
783 virtual void Set(
void *head,
const std::string &value)
const {
785 std::map<std::string, int>::const_iterator it = enum_map_.find(value);
786 std::ostringstream os;
787 if (it == enum_map_.end()) {
788 os <<
"Invalid Input: \'" << value;
789 os <<
"\', valid values are: ";
791 throw dmlc::ParamError(os.str());
794 Parent::Set(head, os.str());
797 Parent::Set(head, value);
800 virtual ParamFieldInfo GetFieldInfo()
const {
803 std::ostringstream os;
808 os <<
',' <<
"optional, default=";
809 PrintDefaultValueString(os);
813 info.type_info_str = os.str();
814 info.description = description_;
817 return Parent::GetFieldInfo();
821 inline FieldEntry<int> &add_enum(
const std::string &key,
int value) {
822 if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \
823 enum_back_map_.count(value) != 0) {
824 std::ostringstream os;
825 os <<
"Enum " <<
"(" << key <<
": " << value <<
" exisit!" <<
")\n";
827 for (std::map<std::string, int>::const_iterator it = enum_map_.begin();
828 it != enum_map_.end(); ++it) {
829 os <<
"(" << it->first <<
": " << it->second <<
"), ";
831 throw dmlc::ParamError(os.str());
833 enum_map_[key] = value;
834 enum_back_map_[value] = key;
843 std::map<std::string, int> enum_map_;
845 std::map<int, std::string> enum_back_map_;
847 virtual void PrintDefaultValueString(std::ostream &os)
const {
849 PrintValue(os, default_value_);
853 virtual void PrintValue(std::ostream &os,
int value)
const {
855 CHECK_NE(enum_back_map_.count(value), 0U)
856 <<
"Value not found in enum declared";
857 os << enum_back_map_.at(value);
865 inline void PrintEnums(std::ostream &os)
const {
867 for (std::map<std::string, int>::const_iterator
868 it = enum_map_.begin(); it != enum_map_.end(); ++it) {
869 if (it != enum_map_.begin()) {
872 os <<
"\'" << it->first <<
'\'';
881class FieldEntry<optional<int> >
882 :
public FieldEntryBase<FieldEntry<optional<int> >, optional<int> > {
885 FieldEntry() : is_enum_(false) {}
887 typedef FieldEntryBase<FieldEntry<optional<int> >, optional<int> > Parent;
889 virtual void Set(
void *head,
const std::string &value)
const {
890 if (is_enum_ && value !=
"None") {
891 std::map<std::string, int>::const_iterator it = enum_map_.find(value);
892 std::ostringstream os;
893 if (it == enum_map_.end()) {
894 os <<
"Invalid Input: \'" << value;
895 os <<
"\', valid values are: ";
897 throw dmlc::ParamError(os.str());
900 Parent::Set(head, os.str());
903 Parent::Set(head, value);
906 virtual ParamFieldInfo GetFieldInfo()
const {
909 std::ostringstream os;
914 os <<
',' <<
"optional, default=";
915 PrintDefaultValueString(os);
919 info.type_info_str = os.str();
920 info.description = description_;
923 return Parent::GetFieldInfo();
927 inline FieldEntry<optional<int> > &add_enum(
const std::string &key,
int value) {
928 CHECK_NE(key,
"None") <<
"None is reserved for empty optional<int>";
929 if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \
930 enum_back_map_.count(value) != 0) {
931 std::ostringstream os;
932 os <<
"Enum " <<
"(" << key <<
": " << value <<
" exisit!" <<
")\n";
934 for (std::map<std::string, int>::const_iterator it = enum_map_.begin();
935 it != enum_map_.end(); ++it) {
936 os <<
"(" << it->first <<
": " << it->second <<
"), ";
938 throw dmlc::ParamError(os.str());
940 enum_map_[key] = value;
941 enum_back_map_[value] = key;
950 std::map<std::string, int> enum_map_;
952 std::map<int, std::string> enum_back_map_;
954 virtual void PrintDefaultValueString(std::ostream &os)
const {
956 PrintValue(os, default_value_);
960 virtual void PrintValue(std::ostream &os, optional<int> value)
const {
965 CHECK_NE(enum_back_map_.count(value.value()), 0U)
966 <<
"Value not found in enum declared";
967 os << enum_back_map_.at(value.value());
976 inline void PrintEnums(std::ostream &os)
const {
978 for (std::map<std::string, int>::const_iterator
979 it = enum_map_.begin(); it != enum_map_.end(); ++it) {
981 os <<
"\'" << it->first <<
'\'';
990 :
public FieldEntryBase<FieldEntry<std::string>, std::string> {
993 typedef FieldEntryBase<FieldEntry<std::string>, std::string> Parent;
995 virtual void Set(
void *head,
const std::string &value)
const {
996 this->Get(head) = value;
999 virtual void PrintDefaultValueString(std::ostream &os)
const {
1000 os <<
'\'' << default_value_ <<
'\'';
1006class FieldEntry<bool>
1007 :
public FieldEntryBase<FieldEntry<bool>, bool> {
1010 typedef FieldEntryBase<FieldEntry<bool>,
bool> Parent;
1012 virtual void Set(
void *head,
const std::string &value)
const {
1013 std::string lower_case; lower_case.resize(value.length());
1014 std::transform(value.begin(), value.end(), lower_case.begin(), ::tolower);
1015 bool &ref = this->Get(head);
1016 if (lower_case ==
"true") {
1018 }
else if (lower_case ==
"false") {
1020 }
else if (lower_case ==
"1") {
1022 }
else if (lower_case ==
"0") {
1025 std::ostringstream os;
1026 os <<
"Invalid Parameter format for " << key_
1027 <<
" expect " << type_ <<
" but value=\'" << value<<
'\'';
1028 throw dmlc::ParamError(os.str());
1034 virtual void PrintValue(std::ostream &os,
bool value)
const {
1035 os << static_cast<int>(value);
1044class FieldEntry<float> :
public FieldEntryNumeric<FieldEntry<float>, float> {
1047 typedef FieldEntryNumeric<FieldEntry<float>,
float> Parent;
1049 virtual void Set(
void *head,
const std::string &value)
const {
1053 }
catch (
const std::invalid_argument &) {
1054 std::ostringstream os;
1055 os <<
"Invalid Parameter format for " << key_ <<
" expect " << type_
1056 <<
" but value=\'" << value <<
'\'';
1057 throw dmlc::ParamError(os.str());
1058 }
catch (
const std::out_of_range&) {
1059 std::ostringstream os;
1060 os <<
"Out of range value for " << key_ <<
", value=\'" << value <<
'\'';
1061 throw dmlc::ParamError(os.str());
1063 CHECK_LE(pos, value.length());
1064 if (pos < value.length()) {
1065 std::ostringstream os;
1066 os <<
"Some trailing characters could not be parsed: \'"
1067 << value.substr(pos) <<
"\'";
1068 throw dmlc::ParamError(os.str());
1074 virtual void PrintValue(std::ostream &os,
float value)
const {
1075 os << std::setprecision(std::numeric_limits<float>::max_digits10) << value;
1082class FieldEntry<double>
1083 :
public FieldEntryNumeric<FieldEntry<double>, double> {
1086 typedef FieldEntryNumeric<FieldEntry<double>,
double> Parent;
1088 virtual void Set(
void *head,
const std::string &value)
const {
1092 }
catch (
const std::invalid_argument &) {
1093 std::ostringstream os;
1094 os <<
"Invalid Parameter format for " << key_ <<
" expect " << type_
1095 <<
" but value=\'" << value <<
'\'';
1096 throw dmlc::ParamError(os.str());
1097 }
catch (
const std::out_of_range&) {
1098 std::ostringstream os;
1099 os <<
"Out of range value for " << key_ <<
", value=\'" << value <<
'\'';
1100 throw dmlc::ParamError(os.str());
1102 CHECK_LE(pos, value.length());
1103 if (pos < value.length()) {
1104 std::ostringstream os;
1105 os <<
"Some trailing characters could not be parsed: \'"
1106 << value.substr(pos) <<
"\'";
1107 throw dmlc::ParamError(os.str());
1113 virtual void PrintValue(std::ostream &os,
double value)
const {
1114 os << std::setprecision(std::numeric_limits<double>::max_digits10) << value;
1123template<
typename ValueType>
1124inline ValueType GetEnv(
const char *key,
1125 ValueType default_value) {
1126 const char *val = getenv(key);
1130 if (val ==
nullptr || !*val) {
1131 return default_value;
1134 parameter::FieldEntry<ValueType> e;
1135 e.Init(key, &ret, ret);
1141template<
typename ValueType>
1142inline void SetEnv(
const char *key,
1144 parameter::FieldEntry<ValueType> e;
1145 e.Init(key, &value, value);
1147 _putenv_s(key, e.GetStringValue(&value).c_str());
1149 setenv(key, e.GetStringValue(&value).c_str(), 1);
Lightweight JSON reader to read any STL compositions and structs. The user needs to know the schema of...
Definition json.h:44
void Read(ValueType *out_value)
Read next ValueType.
Lightweight JSON writer to write any STL compositions.
Definition json.h:190
void Write(const ValueType &value)
Write value to json.
Defines console logging options for XGBoost. Used to enforce unified print behavior.
Namespace for dmlc.
Definition array_view.h:12
float stof(const std::string &value, size_t *pos=nullptr)
A faster implementation of stof(). See documentation of std::stof() for more information....
Definition strtonum.h:467
bool isspace(char c)
Inline implementation of isspace(). Tests whether the given character is a whitespace character.
Definition strtonum.h:26
double stod(const std::string &value, size_t *pos=nullptr)
A faster implementation of stod(). See documentation of std::stod() for more information....
Definition strtonum.h:497
bool Init(int argc, char *argv[])
Initializes the engine module.
Definition engine.cc:43
void Error(const char *fmt,...)
Reports an error message; same as Check.
Definition utils.h:103
void Check(bool exp, const char *fmt,...)
Same as assert, but intended to report the failure as a message to users.
Definition utils.h:91
Container to hold optional data.
Macros common to all headers.
A faster implementation of strtof and strtod.
Exception class that will be thrown by the default logger if DMLC_LOG_FATAL_THROW == 1.
Definition logging.h:29
Type traits information header.