1#ifndef LIGHTGBM_UTILS_COMMON_FUN_H_
2#define LIGHTGBM_UTILS_COMMON_FUN_H_
4#include <LightGBM/utils/log.h>
5#include <LightGBM/utils/openmp_wrapper.h>
28inline static char tolower(
char in) {
29 if (in <=
'Z' && in >=
'A')
30 return in - (
'Z' -
'z');
34inline static std::string Trim(std::string str) {
38 str.erase(str.find_last_not_of(
" \f\n\r\t\v") + 1);
39 str.erase(0, str.find_first_not_of(
" \f\n\r\t\v"));
43inline static std::string RemoveQuotationSymbol(std::string str) {
47 str.erase(str.find_last_not_of(
"'\"") + 1);
48 str.erase(0, str.find_first_not_of(
"'\""));
/*!
 * \brief Check whether a string begins with the given prefix.
 * \param str String to inspect.
 * \param prefix Prefix to look for; an empty prefix always matches.
 * \return true if the first prefix.size() characters of str equal prefix.
 */
inline static bool StartsWith(const std::string& str, const std::string& prefix) {
  // compare() inspects the leading characters in place, so no temporary
  // substring is allocated; a str shorter than prefix compares unequal.
  return str.compare(0, prefix.size(), prefix) == 0;
}
60inline static std::vector<std::string> Split(
const char* c_str,
char delimiter) {
61 std::vector<std::string> ret;
62 std::string str(c_str);
65 while (pos < str.length()) {
66 if (str[pos] == delimiter) {
68 ret.push_back(str.substr(i, pos - i));
77 ret.push_back(str.substr(i));
82inline static std::vector<std::string> SplitLines(
const char* c_str) {
83 std::vector<std::string> ret;
84 std::string str(c_str);
87 while (pos < str.length()) {
88 if (str[pos] ==
'\n' || str[pos] ==
'\r') {
90 ret.push_back(str.substr(i, pos - i));
93 while (str[pos] ==
'\n' || str[pos] ==
'\r') ++pos;
101 ret.push_back(str.substr(i));
106inline static std::vector<std::string> Split(
const char* c_str,
const char* delimiters) {
107 std::vector<std::string> ret;
108 std::string str(c_str);
111 while (pos < str.length()) {
112 bool met_delimiters =
false;
113 for (
int j = 0; delimiters[j] !=
'\0'; ++j) {
114 if (str[pos] == delimiters[j]) {
115 met_delimiters =
true;
119 if (met_delimiters) {
121 ret.push_back(str.substr(i, pos - i));
130 ret.push_back(str.substr(i));
136inline static const char* Atoi(
const char* p, T* out) {
146 }
else if (*p ==
'+') {
149 for (value = 0; *p >=
'0' && *p <=
'9'; ++p) {
150 value = value * 10 + (*p -
'0');
152 *out =
static_cast<T
>(sign * value);
160inline static double Pow(T base,
int power) {
162 return 1.0 / Pow(base, -power);
163 }
else if (power == 0) {
165 }
else if (power % 2 == 0) {
166 return Pow(base*base, power / 2);
167 }
else if (power % 3 == 0) {
168 return Pow(base*base*base, power / 3);
170 return base * Pow(base, power - 1);
174inline static const char* Atof(
const char* p,
double* out) {
176 double sign, value, scale;
187 }
else if (*p ==
'+') {
192 if ((*p >=
'0' && *p <=
'9') || *p ==
'.' || *p ==
'e' || *p ==
'E') {
194 for (value = 0.0; *p >=
'0' && *p <=
'9'; ++p) {
195 value = value * 10.0 + (*p -
'0');
203 while (*p >=
'0' && *p <=
'9') {
204 right = (*p -
'0') + right * 10.0;
208 value += right / Pow(10.0, nn);
214 if ((*p ==
'e') || (*p ==
'E')) {
221 }
else if (*p ==
'+') {
225 for (expon = 0; *p >=
'0' && *p <=
'9'; ++p) {
226 expon = expon * 10 + (*p -
'0');
228 if (expon > 308) expon = 308;
230 while (expon >= 50) { scale *= 1E50; expon -= 50; }
231 while (expon >= 8) { scale *= 1E8; expon -= 8; }
232 while (expon > 0) { scale *= 10.0; expon -= 1; }
235 *out = sign * (frac ? (value / scale) : (value * scale));
238 while (*(p + cnt) !=
'\0' && *(p + cnt) !=
' '
239 && *(p + cnt) !=
'\t' && *(p + cnt) !=
','
240 && *(p + cnt) !=
'\n' && *(p + cnt) !=
'\r'
241 && *(p + cnt) !=
':') {
245 std::string tmp_str(p, cnt);
246 std::transform(tmp_str.begin(), tmp_str.end(), tmp_str.begin(), Common::tolower);
247 if (tmp_str == std::string(
"na") || tmp_str == std::string(
"nan") ||
248 tmp_str == std::string(
"null")) {
250 }
else if (tmp_str == std::string(
"inf") || tmp_str == std::string(
"infinity")) {
253 Log::Fatal(
"Unknown token %s in data file", tmp_str.c_str());
266inline static bool AtoiAndCheck(
const char* p,
int* out) {
267 const char* after = Atoi(p, out);
268 if (*after !=
'\0') {
274inline static bool AtofAndCheck(
const char* p,
double* out) {
275 const char* after = Atof(p, out);
276 if (*after !=
'\0') {
282inline static unsigned CountDecimalDigit32(uint32_t n) {
283#if defined(_MSC_VER) || defined(__GNUC__)
284 static const uint32_t powers_of_10[] = {
298 _BitScanReverse(&i, n | 1);
299 uint32_t t = (i + 1) * 1233 >> 12;
301 uint32_t t = (32 - __builtin_clz(n | 1)) * 1233 >> 12;
303 return t - (n < powers_of_10[t]) + 1;
305 if (n < 10)
return 1;
306 if (n < 100)
return 2;
307 if (n < 1000)
return 3;
308 if (n < 10000)
return 4;
309 if (n < 100000)
return 5;
310 if (n < 1000000)
return 6;
311 if (n < 10000000)
return 7;
312 if (n < 100000000)
return 8;
313 if (n < 1000000000)
return 9;
318inline static void Uint32ToStr(uint32_t value,
char* buffer) {
319 const char kDigitsLut[200] = {
320 '0',
'0',
'0',
'1',
'0',
'2',
'0',
'3',
'0',
'4',
'0',
'5',
'0',
'6',
'0',
'7',
'0',
'8',
'0',
'9',
321 '1',
'0',
'1',
'1',
'1',
'2',
'1',
'3',
'1',
'4',
'1',
'5',
'1',
'6',
'1',
'7',
'1',
'8',
'1',
'9',
322 '2',
'0',
'2',
'1',
'2',
'2',
'2',
'3',
'2',
'4',
'2',
'5',
'2',
'6',
'2',
'7',
'2',
'8',
'2',
'9',
323 '3',
'0',
'3',
'1',
'3',
'2',
'3',
'3',
'3',
'4',
'3',
'5',
'3',
'6',
'3',
'7',
'3',
'8',
'3',
'9',
324 '4',
'0',
'4',
'1',
'4',
'2',
'4',
'3',
'4',
'4',
'4',
'5',
'4',
'6',
'4',
'7',
'4',
'8',
'4',
'9',
325 '5',
'0',
'5',
'1',
'5',
'2',
'5',
'3',
'5',
'4',
'5',
'5',
'5',
'6',
'5',
'7',
'5',
'8',
'5',
'9',
326 '6',
'0',
'6',
'1',
'6',
'2',
'6',
'3',
'6',
'4',
'6',
'5',
'6',
'6',
'6',
'7',
'6',
'8',
'6',
'9',
327 '7',
'0',
'7',
'1',
'7',
'2',
'7',
'3',
'7',
'4',
'7',
'5',
'7',
'6',
'7',
'7',
'7',
'8',
'7',
'9',
328 '8',
'0',
'8',
'1',
'8',
'2',
'8',
'3',
'8',
'4',
'8',
'5',
'8',
'6',
'8',
'7',
'8',
'8',
'8',
'9',
329 '9',
'0',
'9',
'1',
'9',
'2',
'9',
'3',
'9',
'4',
'9',
'5',
'9',
'6',
'9',
'7',
'9',
'8',
'9',
'9'
331 unsigned digit = CountDecimalDigit32(value);
335 while (value >= 100) {
336 const unsigned i = (value % 100) << 1;
338 *--buffer = kDigitsLut[i + 1];
339 *--buffer = kDigitsLut[i];
343 *--buffer = char(value) +
'0';
346 const unsigned i = value << 1;
347 *--buffer = kDigitsLut[i + 1];
348 *--buffer = kDigitsLut[i];
352inline static void Int32ToStr(int32_t value,
char* buffer) {
353 uint32_t u =
static_cast<uint32_t
>(value);
358 Uint32ToStr(u, buffer);
361inline static void DoubleToStr(
double value,
char* buffer,
size_t
367 sprintf_s(buffer, buffer_len,
"%.17g", value);
369 sprintf(buffer,
"%.17g", value);
373inline static const char* SkipSpaceAndTab(
const char* p) {
374 while (*p ==
' ' || *p ==
'\t') {
380inline static const char* SkipReturn(
const char* p) {
381 while (*p ==
'\n' || *p ==
'\r' || *p ==
' ') {
387template<
typename T,
typename T2>
388inline static std::vector<T2> ArrayCast(
const std::vector<T>& arr) {
389 std::vector<T2> ret(arr.size());
390 for (
size_t i = 0; i < arr.size(); ++i) {
391 ret[i] =
static_cast<T2
>(arr[i]);
396template<
typename T,
bool is_
float,
bool is_unsign>
398 void operator()(T value,
char* buffer,
size_t)
const {
399 Int32ToStr(value, buffer);
405 void operator()(T value,
char* buffer,
size_t
411 sprintf_s(buffer, buf_len,
"%g", value);
413 sprintf(buffer,
"%g", value);
420 void operator()(T value,
char* buffer,
size_t)
const {
421 Uint32ToStr(value, buffer);
426inline static std::string ArrayToStringFast(
const std::vector<T>& arr,
size_t n) {
427 if (arr.empty() || n == 0) {
428 return std::string(
"");
430 __TToStringHelperFast<T, std::is_floating_point<T>::value, std::is_unsigned<T>::value> helper;
431 const size_t buf_len = 16;
432 std::vector<char> buffer(buf_len);
433 std::stringstream str_buf;
434 helper(arr[0], buffer.data(), buf_len);
435 str_buf << buffer.data();
436 for (
size_t i = 1; i < std::min(n, arr.size()); ++i) {
437 helper(arr[i], buffer.data(), buf_len);
438 str_buf <<
' ' << buffer.data();
440 return str_buf.str();
443inline static std::string ArrayToString(
const std::vector<double>& arr,
size_t n) {
444 if (arr.empty() || n == 0) {
445 return std::string(
"");
447 const size_t buf_len = 32;
448 std::vector<char> buffer(buf_len);
449 std::stringstream str_buf;
450 DoubleToStr(arr[0], buffer.data(), buf_len);
451 str_buf << buffer.data();
452 for (
size_t i = 1; i < std::min(n, arr.size()); ++i) {
453 DoubleToStr(arr[i], buffer.data(), buf_len);
454 str_buf <<
' ' << buffer.data();
456 return str_buf.str();
459template<
typename T,
bool is_
float>
461 T operator()(
const std::string& str)
const {
463 Atoi(str.c_str(), &ret);
470 T operator()(
const std::string& str)
const {
471 return static_cast<T
>(std::stod(str));
476inline static std::vector<T> StringToArray(
const std::string& str,
char delimiter) {
477 std::vector<std::string> strs = Split(str.c_str(), delimiter);
479 ret.reserve(strs.size());
481 for (
const auto& s : strs) {
482 ret.push_back(helper(s));
488inline static std::vector<T> StringToArray(
const std::string& str,
int n) {
490 return std::vector<T>();
492 std::vector<std::string> strs = Split(str.c_str(),
' ');
493 CHECK(strs.size() ==
static_cast<size_t>(n));
495 ret.reserve(strs.size());
496 __StringToTHelper<T, std::is_floating_point<T>::value> helper;
497 for (
const auto& s : strs) {
498 ret.push_back(helper(s));
503template<
typename T,
bool is_
float>
505 const char* operator()(
const char*p, T* out)
const {
512 const char* operator()(
const char*p, T* out)
const {
514 auto ret = Atof(p, &tmp);
515 *out =
static_cast<T
>(tmp);
521inline static std::vector<T> StringToArrayFast(
const std::string& str,
int n) {
523 return std::vector<T>();
525 auto p_str = str.c_str();
526 __StringToTHelperFast<T, std::is_floating_point<T>::value> helper;
527 std::vector<T> ret(n);
528 for (
int i = 0; i < n; ++i) {
529 p_str = helper(p_str, &ret[i]);
535inline static std::string Join(
const std::vector<T>& strs,
const char* delimiter) {
537 return std::string(
"");
539 std::stringstream str_buf;
540 str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
542 for (
size_t i = 1; i < strs.size(); ++i) {
543 str_buf << delimiter;
546 return str_buf.str();
/*!
 * \brief Join the elements strs[start, end) into one string.
 * \param strs Elements to print; each one is streamed with operator<<.
 * \param start Index of the first element to print (clamped to the last valid index).
 * \param end One past the last element to print (clamped to strs.size()).
 * \param delimiter Text inserted between consecutive elements.
 * \return The joined text, or an empty string when strs is empty or end <= start.
 */
template<typename T>
inline static std::string Join(const std::vector<T>& strs, size_t start, size_t end, const char* delimiter) {
  // The indices are unsigned: compare them directly instead of testing
  // `end - start <= 0`, which wraps around when end < start. An empty input
  // is rejected here too, because `strs.size() - 1` below would wrap.
  if (strs.empty() || end <= start) {
    return std::string("");
  }
  start = std::min(start, static_cast<size_t>(strs.size()) - 1);
  end = std::min(end, static_cast<size_t>(strs.size()));
  std::stringstream str_buf;
  // enough significant digits for a double to survive a text round trip
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  str_buf << strs[start];
  for (size_t i = start + 1; i < end; ++i) {
    str_buf << delimiter;
    str_buf << strs[i];
  }
  return str_buf.str();
}
566inline static int64_t Pow2RoundUp(int64_t x) {
568 for (
int i = 0; i < 64; ++i) {
/*!
 * \brief Replace the values of a vector, in place, by their softmax.
 * \param p_rec Vector of scores; overwritten with values that sum to 1.
 *        An empty vector is left untouched.
 */
inline static void Softmax(std::vector<double>* p_rec) {
  std::vector<double>& rec = *p_rec;
  if (rec.empty()) {
    // nothing to normalise; reading rec[0] below would be out of bounds
    return;
  }
  // subtract the maximum before exponentiating so std::exp cannot overflow
  double wmax = rec[0];
  for (size_t i = 1; i < rec.size(); ++i) {
    wmax = std::max(rec[i], wmax);
  }
  double wsum = 0.0;
  for (size_t i = 0; i < rec.size(); ++i) {
    rec[i] = std::exp(rec[i] - wmax);
    wsum += rec[i];
  }
  for (size_t i = 0; i < rec.size(); ++i) {
    rec[i] /= wsum;
  }
}
/*!
 * \brief Compute the softmax of an array of scores.
 * \param input Array holding len scores.
 * \param output Array receiving len values that sum to 1.
 * \param len Number of elements; when len <= 0 neither array is touched.
 */
inline static void Softmax(const double* input, double* output, int len) {
  if (len <= 0) {
    // no elements: input[0] must not be read
    return;
  }
  // subtract the maximum before exponentiating so std::exp cannot overflow
  double wmax = input[0];
  for (int i = 1; i < len; ++i) {
    wmax = std::max(input[i], wmax);
  }
  double wsum = 0.0;
  for (int i = 0; i < len; ++i) {
    output[i] = std::exp(input[i] - wmax);
    wsum += output[i];
  }
  for (int i = 0; i < len; ++i) {
    output[i] /= wsum;
  }
}
/*!
 * \brief Build a vector of non-owning const pointers from a vector of unique_ptr.
 * \param input Owning pointers; ownership is not transferred.
 * \return One raw pointer per element of input, in the same order. The
 *         pointers are only valid while the objects owned by input are alive.
 */
template<typename T>
std::vector<const T*> ConstPtrInVectorWrapper(const std::vector<std::unique_ptr<T>>& input) {
  std::vector<const T*> ret;
  // final size is known up front: allocate once instead of growing repeatedly
  ret.reserve(input.size());
  for (const auto& owner : input) {
    ret.push_back(owner.get());
  }
  return ret;
}
/*!
 * \brief Stable-sort keys[start..] and apply the same permutation to values[start..].
 * \param keys Sort keys; elements before start are left untouched.
 * \param values Payload moved together with the keys; must hold at least keys.size() elements.
 * \param start Index of the first element taking part in the sort.
 * \param is_reverse false sorts the keys ascending, true sorts them descending.
 */
template<typename T1, typename T2>
inline static void SortForPair(std::vector<T1>& keys, std::vector<T2>& values, size_t start, bool is_reverse = false) {
  std::vector<std::pair<T1, T2>> arr;
  if (start < keys.size()) {
    arr.reserve(keys.size() - start);
  }
  for (size_t i = start; i < keys.size(); ++i) {
    arr.emplace_back(keys[i], values[i]);
  }
  if (!is_reverse) {
    std::stable_sort(arr.begin(), arr.end(), [](const std::pair<T1, T2>& a, const std::pair<T1, T2>& b) {
      return a.first < b.first;
    });
  } else {
    std::stable_sort(arr.begin(), arr.end(), [](const std::pair<T1, T2>& a, const std::pair<T1, T2>& b) {
      return a.first > b.first;
    });
  }
  // arr is indexed from 0 while the destination range begins at start, so the
  // write-back has to offset the destination index by start.
  for (size_t i = 0; i < arr.size(); ++i) {
    keys[start + i] = arr[i].first;
    values[start + i] = arr[i].second;
  }
}
643inline static std::vector<T*> Vector2Ptr(std::vector<std::vector<T>>& data) {
644 std::vector<T*> ptr(data.size());
645 for (
size_t i = 0; i < data.size(); ++i) {
646 ptr[i] = data[i].data();
652inline static std::vector<int> VectorSize(
const std::vector<std::vector<T>>& data) {
653 std::vector<int> ret(data.size());
654 for (
size_t i = 0; i < data.size(); ++i) {
655 ret[i] =
static_cast<int>(data[i].size());
660inline static double AvoidInf(
double x) {
663 }
else if (x <= -1e300) {
670inline static float AvoidInf(
float x) {
673 }
else if (x <= -1e38) {
680template<
typename _Iter>
inline
681static typename std::iterator_traits<_Iter>::value_type* IteratorValType(_Iter) {
685template<
typename _RanIt,
typename _Pr,
typename _VTRanIt>
inline
686static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred, _VTRanIt*) {
687 size_t len = _Last - _First;
688 const size_t kMinInnerLen = 1024;
693 num_threads = omp_get_num_threads();
695 if (len <= kMinInnerLen || num_threads <= 1) {
696 std::sort(_First, _Last, _Pred);
699 size_t inner_size = (len + num_threads - 1) / num_threads;
700 inner_size = std::max(inner_size, kMinInnerLen);
701 num_threads =
static_cast<int>((len + inner_size - 1) / inner_size);
702 #pragma omp parallel for schedule(static, 1)
703 for (
int i = 0; i < num_threads; ++i) {
704 size_t left = inner_size*i;
705 size_t right = left + inner_size;
706 right = std::min(right, len);
708 std::sort(_First + left, _First + right, _Pred);
712 std::vector<_VTRanIt> temp_buf(len);
713 _RanIt buf = temp_buf.begin();
714 size_t s = inner_size;
717 int loop_size =
static_cast<int>((len + s * 2 - 1) / (s * 2));
718 #pragma omp parallel for schedule(static, 1)
719 for (
int i = 0; i < loop_size; ++i) {
720 size_t left = i * 2 * s;
721 size_t mid = left + s;
722 size_t right = mid + s;
723 right = std::min(len, right);
724 if (mid >= right) {
continue; }
725 std::copy(_First + left, _First + mid, buf + left);
726 std::merge(buf + left, buf + mid, _First + mid, _First + right, _First + left, _Pred);
732template<
typename _RanIt,
typename _Pr>
inline
733static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred) {
734 return ParallelSort(_First, _Last, _Pred, IteratorValType(_First));
739inline static void CheckElementsIntervalClosed(
const T *y, T ymin, T ymax,
int ny,
const char *callername) {
740 auto fatal_msg = [&y, &ymin, &ymax, &callername](
int i) {
741 std::ostringstream os;
742 os <<
"[%s]: does not tolerate element [#%i = " << y[i] <<
"] outside [" << ymin <<
", " << ymax <<
"]";
743 Log::Fatal(os.str().c_str(), callername, i);
745 for (
int i = 1; i < ny; i += 2) {
746 if (y[i - 1] < y[i]) {
747 if (y[i - 1] < ymin) {
749 }
else if (y[i] > ymax) {
753 if (y[i - 1] > ymax) {
755 }
else if (y[i] < ymin) {
761 if (y[ny - 1] < ymin || y[ny - 1] > ymax) {
769template <
typename T1,
typename T2>
770inline static void ObtainMinMaxSum(
const T1 *w,
int nw, T1 *mi, T1 *ma, T2 *su) {
791 for (; i < nw; i += 2) {
792 if (w[i - 1] < w[i]) {
793 minw = std::min(minw, w[i - 1]);
794 maxw = std::max(maxw, w[i]);
796 minw = std::min(minw, w[i]);
797 maxw = std::max(maxw, w[i - 1]);
799 sumw += w[i - 1] + w[i];
808 *su =
static_cast<T2
>(sumw);
/*!
 * \brief Build a bitset, packed into 32-bit words, with one bit set per input value.
 * \param vals Array of n bit positions.
 *        NOTE(review): positions are assumed to be non-negative - confirm against callers.
 * \param n Number of entries in vals.
 * \return Words of the bitset; bit (v % 32) of word (v / 32) is set for every v in vals.
 *         The vector is just long enough to hold the largest position.
 */
template<typename T>
inline static std::vector<uint32_t> ConstructBitset(const T* vals, int n) {
  std::vector<uint32_t> ret;
  for (int i = 0; i < n; ++i) {
    int i1 = vals[i] / 32;
    int i2 = vals[i] % 32;
    if (static_cast<int>(ret.size()) < i1 + 1) {
      ret.resize(i1 + 1, 0);
    }
    // shift an unsigned 32-bit one: `1 << 31` on a signed int overflows
    ret[i1] |= (static_cast<uint32_t>(1) << i2);
  }
  return ret;
}
827inline static bool FindInBitset(
const uint32_t* bits,
int n, T pos) {
833 return (bits[i1] >> i2) & 1;
836inline static bool CheckDoubleEqualOrdered(
double a,
double b) {
837 double upper = std::nextafter(a, INFINITY);
/*!
 * \brief Get the next representable double after a, in the direction of +infinity.
 * \param a Input value.
 * \return std::nextafter(a, INFINITY).
 */
inline static double GetDoubleUpperBound(double a) {
  return std::nextafter(a, INFINITY);
}
845inline static size_t GetLine(
const char* str) {
847 while (*str !=
'\0' && *str !=
'\n' && *str !=
'\r') {
853inline static const char* SkipNewLine(
const char* str) {
864static int Sign(T x) {
865 return (x > T(0)) - (x < T(0));
869static T SafeLog(T x) {
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10