Medial Code Documentation
Loading...
Searching...
No Matches
common.h
1#ifndef LIGHTGBM_UTILS_COMMON_FUN_H_
2#define LIGHTGBM_UTILS_COMMON_FUN_H_
3
4#include <LightGBM/utils/log.h>
5#include <LightGBM/utils/openmp_wrapper.h>
6
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <iomanip>
#include <iterator>
#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
19
20#ifdef _MSC_VER
21#include "intrin.h"
22#endif
23
24namespace LightGBM {
25
26namespace Common {
27
// ASCII-only lower-casing: 'A'..'Z' are mapped to 'a'..'z'; every other
// character is returned unchanged.
inline static char tolower(char in) {
  const bool is_upper = (in >= 'A' && in <= 'Z');
  return is_upper ? static_cast<char>(in + ('z' - 'Z')) : in;
}
33
// Returns a copy of `str` with leading and trailing whitespace
// (space, \f, \n, \r, \t, \v) removed.
inline static std::string Trim(std::string str) {
  const char* const kBlank = " \f\n\r\t\v";
  if (!str.empty()) {
    // Drop the tail first; npos + 1 wraps to 0, which clears an all-blank string.
    const size_t last_kept = str.find_last_not_of(kBlank);
    str.erase(last_kept + 1);
    const size_t first_kept = str.find_first_not_of(kBlank);
    str.erase(0, first_kept);
  }
  return str;
}
42
// Returns a copy of `str` with any run of leading and trailing quote
// characters (' or ") removed. Quotes inside the string are kept.
inline static std::string RemoveQuotationSymbol(std::string str) {
  const char* const kQuotes = "'\"";
  if (!str.empty()) {
    // npos + 1 wraps to 0, so a string made only of quotes becomes empty.
    const size_t last_kept = str.find_last_not_of(kQuotes);
    str.erase(last_kept + 1);
    const size_t first_kept = str.find_first_not_of(kQuotes);
    str.erase(0, first_kept);
  }
  return str;
}
51
// Returns true when `str` begins with `prefix` (an empty prefix always matches).
// `prefix` is taken by const reference and compared in place, so no temporary
// strings are created.
inline static bool StartsWith(const std::string& str, const std::string& prefix) {
  return str.compare(0, prefix.size(), prefix) == 0;
}
59
// Splits the NUL-terminated string `c_str` on `delimiter`.
// Empty tokens (consecutive, leading or trailing delimiters) are skipped.
inline static std::vector<std::string> Split(const char* c_str, char delimiter) {
  const std::string text(c_str);
  std::vector<std::string> tokens;
  size_t token_begin = 0;
  for (size_t cursor = 0; cursor < text.length(); ++cursor) {
    if (text[cursor] != delimiter) {
      continue;
    }
    if (token_begin < cursor) {
      tokens.push_back(text.substr(token_begin, cursor - token_begin));
    }
    token_begin = cursor + 1;
  }
  // Trailing token after the last delimiter, if any.
  if (token_begin < text.length()) {
    tokens.push_back(text.substr(token_begin));
  }
  return tokens;
}
81
// Splits the NUL-terminated string `c_str` into lines. Any run of '\n' / '\r'
// characters acts as one separator, so empty lines are not returned.
inline static std::vector<std::string> SplitLines(const char* c_str) {
  const std::string text(c_str);
  const size_t total = text.length();
  auto is_eol = [](char c) { return c == '\n' || c == '\r'; };

  std::vector<std::string> lines;
  size_t line_begin = 0;
  size_t cursor = 0;
  while (cursor < total) {
    if (!is_eol(text[cursor])) {
      ++cursor;
      continue;
    }
    if (line_begin < cursor) {
      lines.push_back(text.substr(line_begin, cursor - line_begin));
    }
    // Skip the whole run of line endings; text[total] is '\0', which stops the scan.
    while (is_eol(text[cursor])) {
      ++cursor;
    }
    line_begin = cursor;
  }
  if (line_begin < cursor) {
    lines.push_back(text.substr(line_begin));
  }
  return lines;
}
105
// Splits the NUL-terminated string `c_str` on any character contained in the
// NUL-terminated set `delimiters`. Empty tokens are skipped.
inline static std::vector<std::string> Split(const char* c_str, const char* delimiters) {
  const std::string text(c_str);
  auto is_delimiter = [delimiters](char c) {
    for (const char* d = delimiters; *d != '\0'; ++d) {
      if (*d == c) {
        return true;
      }
    }
    return false;
  };

  std::vector<std::string> tokens;
  size_t token_begin = 0;
  for (size_t cursor = 0; cursor < text.length(); ++cursor) {
    if (!is_delimiter(text[cursor])) {
      continue;
    }
    if (token_begin < cursor) {
      tokens.push_back(text.substr(token_begin, cursor - token_begin));
    }
    token_begin = cursor + 1;
  }
  if (token_begin < text.length()) {
    tokens.push_back(text.substr(token_begin));
  }
  return tokens;
}
134
// Parses a decimal integer at `p`: optional leading spaces, an optional '+'/'-'
// sign, then a run of digits. The result is stored in `*out` (0 when there are
// no digits). Returns the position after the number and any trailing spaces.
// No overflow checking is performed.
template<typename T>
inline static const char* Atoi(const char* p, T* out) {
  while (*p == ' ') {
    ++p;
  }
  int sign = 1;
  if (*p == '-') {
    sign = -1;
    ++p;
  } else if (*p == '+') {
    ++p;
  }
  T value = 0;
  while (*p >= '0' && *p <= '9') {
    value = value * 10 + (*p - '0');
    ++p;
  }
  *out = static_cast<T>(sign * value);
  while (*p == ' ') {
    ++p;
  }
  return p;
}
158
// Integer-exponent power computed by repeated squaring/cubing.
// Negative exponents yield the reciprocal; power == 0 yields 1.
template<typename T>
inline static double Pow(T base, int power) {
  if (power < 0) {
    return 1.0 / Pow(base, -power);
  }
  if (power == 0) {
    return 1;
  }
  if (power % 2 == 0) {
    return Pow(base * base, power / 2);
  }
  if (power % 3 == 0) {
    return Pow(base * base * base, power / 3);
  }
  return base * Pow(base, power - 1);
}
173
174inline static const char* Atof(const char* p, double* out) {
175 int frac;
176 double sign, value, scale;
177 *out = NAN;
178 // Skip leading white space, if any.
179 while (*p == ' ') {
180 ++p;
181 }
182 // Get sign, if any.
183 sign = 1.0;
184 if (*p == '-') {
185 sign = -1.0;
186 ++p;
187 } else if (*p == '+') {
188 ++p;
189 }
190
191 // is a number
192 if ((*p >= '0' && *p <= '9') || *p == '.' || *p == 'e' || *p == 'E') {
193 // Get digits before decimal point or exponent, if any.
194 for (value = 0.0; *p >= '0' && *p <= '9'; ++p) {
195 value = value * 10.0 + (*p - '0');
196 }
197
198 // Get digits after decimal point, if any.
199 if (*p == '.') {
200 double right = 0.0;
201 int nn = 0;
202 ++p;
203 while (*p >= '0' && *p <= '9') {
204 right = (*p - '0') + right * 10.0;
205 ++nn;
206 ++p;
207 }
208 value += right / Pow(10.0, nn);
209 }
210
211 // Handle exponent, if any.
212 frac = 0;
213 scale = 1.0;
214 if ((*p == 'e') || (*p == 'E')) {
215 uint32_t expon;
216 // Get sign of exponent, if any.
217 ++p;
218 if (*p == '-') {
219 frac = 1;
220 ++p;
221 } else if (*p == '+') {
222 ++p;
223 }
224 // Get digits of exponent, if any.
225 for (expon = 0; *p >= '0' && *p <= '9'; ++p) {
226 expon = expon * 10 + (*p - '0');
227 }
228 if (expon > 308) expon = 308;
229 // Calculate scaling factor.
230 while (expon >= 50) { scale *= 1E50; expon -= 50; }
231 while (expon >= 8) { scale *= 1E8; expon -= 8; }
232 while (expon > 0) { scale *= 10.0; expon -= 1; }
233 }
234 // Return signed and scaled floating point result.
235 *out = sign * (frac ? (value / scale) : (value * scale));
236 } else {
237 size_t cnt = 0;
238 while (*(p + cnt) != '\0' && *(p + cnt) != ' '
239 && *(p + cnt) != '\t' && *(p + cnt) != ','
240 && *(p + cnt) != '\n' && *(p + cnt) != '\r'
241 && *(p + cnt) != ':') {
242 ++cnt;
243 }
244 if (cnt > 0) {
245 std::string tmp_str(p, cnt);
246 std::transform(tmp_str.begin(), tmp_str.end(), tmp_str.begin(), Common::tolower);
247 if (tmp_str == std::string("na") || tmp_str == std::string("nan") ||
248 tmp_str == std::string("null")) {
249 *out = NAN;
250 } else if (tmp_str == std::string("inf") || tmp_str == std::string("infinity")) {
251 *out = sign * 1e308;
252 } else {
253 Log::Fatal("Unknown token %s in data file", tmp_str.c_str());
254 }
255 p += cnt;
256 }
257 }
258
259 while (*p == ' ') {
260 ++p;
261 }
262
263 return p;
264}
265
266inline static bool AtoiAndCheck(const char* p, int* out) {
267 const char* after = Atoi(p, out);
268 if (*after != '\0') {
269 return false;
270 }
271 return true;
272}
273
274inline static bool AtofAndCheck(const char* p, double* out) {
275 const char* after = Atof(p, out);
276 if (*after != '\0') {
277 return false;
278 }
279 return true;
280}
281
// Returns the number of decimal digits needed to print `n` (1 for n == 0).
// With MSVC or GCC-compatible compilers the count is estimated from the bit
// width (bits * 1233 >> 12 approximates bits * log10(2)) and corrected with a
// power-of-ten table; otherwise a simple threshold scan is used.
inline static unsigned CountDecimalDigit32(uint32_t n) {
#if defined(_MSC_VER) || defined(__GNUC__)
  static const uint32_t kPow10[] = {
    0, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000, 1000000000
  };
#ifdef _MSC_VER
  unsigned long msb = 0;
  _BitScanReverse(&msb, n | 1);
  const uint32_t bit_width = static_cast<uint32_t>(msb) + 1;
#elif __GNUC__
  const uint32_t bit_width = 32 - __builtin_clz(n | 1);
#endif
  const uint32_t t = bit_width * 1233 >> 12;
  return t - (n < kPow10[t]) + 1;
#else
  unsigned digits = 1;
  for (uint32_t bound = 10; digits < 10 && n >= bound; bound *= 10) {
    ++digits;
  }
  return digits;
#endif
}
317
318inline static void Uint32ToStr(uint32_t value, char* buffer) {
319 const char kDigitsLut[200] = {
320 '0', '0', '0', '1', '0', '2', '0', '3', '0', '4', '0', '5', '0', '6', '0', '7', '0', '8', '0', '9',
321 '1', '0', '1', '1', '1', '2', '1', '3', '1', '4', '1', '5', '1', '6', '1', '7', '1', '8', '1', '9',
322 '2', '0', '2', '1', '2', '2', '2', '3', '2', '4', '2', '5', '2', '6', '2', '7', '2', '8', '2', '9',
323 '3', '0', '3', '1', '3', '2', '3', '3', '3', '4', '3', '5', '3', '6', '3', '7', '3', '8', '3', '9',
324 '4', '0', '4', '1', '4', '2', '4', '3', '4', '4', '4', '5', '4', '6', '4', '7', '4', '8', '4', '9',
325 '5', '0', '5', '1', '5', '2', '5', '3', '5', '4', '5', '5', '5', '6', '5', '7', '5', '8', '5', '9',
326 '6', '0', '6', '1', '6', '2', '6', '3', '6', '4', '6', '5', '6', '6', '6', '7', '6', '8', '6', '9',
327 '7', '0', '7', '1', '7', '2', '7', '3', '7', '4', '7', '5', '7', '6', '7', '7', '7', '8', '7', '9',
328 '8', '0', '8', '1', '8', '2', '8', '3', '8', '4', '8', '5', '8', '6', '8', '7', '8', '8', '8', '9',
329 '9', '0', '9', '1', '9', '2', '9', '3', '9', '4', '9', '5', '9', '6', '9', '7', '9', '8', '9', '9'
330 };
331 unsigned digit = CountDecimalDigit32(value);
332 buffer += digit;
333 *buffer = '\0';
334
335 while (value >= 100) {
336 const unsigned i = (value % 100) << 1;
337 value /= 100;
338 *--buffer = kDigitsLut[i + 1];
339 *--buffer = kDigitsLut[i];
340 }
341
342 if (value < 10) {
343 *--buffer = char(value) + '0';
344 }
345 else {
346 const unsigned i = value << 1;
347 *--buffer = kDigitsLut[i + 1];
348 *--buffer = kDigitsLut[i];
349 }
350}
351
352inline static void Int32ToStr(int32_t value, char* buffer) {
353 uint32_t u = static_cast<uint32_t>(value);
354 if (value < 0) {
355 *buffer++ = '-';
356 u = ~u + 1;
357 }
358 Uint32ToStr(u, buffer);
359}
360
// Formats `value` with "%.17g" into `buffer`, writing at most `buffer_len`
// bytes including the terminating NUL. Output that does not fit is truncated
// instead of overrunning the buffer.
inline static void DoubleToStr(double value, char* buffer, size_t buffer_len) {
  #ifdef _MSC_VER
  sprintf_s(buffer, buffer_len, "%.17g", value);
  #else
  std::snprintf(buffer, buffer_len, "%.17g", value);
  #endif
}
372
// Advances `p` past any run of spaces and tabs and returns the new position.
inline static const char* SkipSpaceAndTab(const char* p) {
  for (; *p == ' ' || *p == '\t'; ++p) {
  }
  return p;
}
379
// Advances `p` past any run of '\n', '\r' and space characters and returns
// the new position.
inline static const char* SkipReturn(const char* p) {
  for (; *p == '\n' || *p == '\r' || *p == ' '; ++p) {
  }
  return p;
}
386
// Returns a new vector whose elements are `arr`'s elements converted to T2
// with static_cast.
template<typename T, typename T2>
inline static std::vector<T2> ArrayCast(const std::vector<T>& arr) {
  std::vector<T2> converted(arr.size());
  std::transform(arr.begin(), arr.end(), converted.begin(),
                 [](const T& item) { return static_cast<T2>(item); });
  return converted;
}
395
// Primary template: formats a signed integral value via Int32ToStr.
// The floating-point and unsigned cases are handled by partial
// specializations selected through the two bool parameters.
// The size_t argument (buffer length) is unused here.
template<typename T, bool is_float, bool is_unsign>
struct __TToStringHelperFast {
  void operator()(T value, char* buffer, size_t) const {
    Int32ToStr(value, buffer);
  }
};
402
403template<typename T>
404struct __TToStringHelperFast<T, true, false> {
405 void operator()(T value, char* buffer, size_t
406 #ifdef _MSC_VER
407 buf_len
408 #endif
409 ) const {
410 #ifdef _MSC_VER
411 sprintf_s(buffer, buf_len, "%g", value);
412 #else
413 sprintf(buffer, "%g", value);
414 #endif
415 }
416};
417
418template<typename T>
419struct __TToStringHelperFast<T, false, true> {
420 void operator()(T value, char* buffer, size_t) const {
421 Uint32ToStr(value, buffer);
422 }
423};
424
425template<typename T>
426inline static std::string ArrayToStringFast(const std::vector<T>& arr, size_t n) {
427 if (arr.empty() || n == 0) {
428 return std::string("");
429 }
430 __TToStringHelperFast<T, std::is_floating_point<T>::value, std::is_unsigned<T>::value> helper;
431 const size_t buf_len = 16;
432 std::vector<char> buffer(buf_len);
433 std::stringstream str_buf;
434 helper(arr[0], buffer.data(), buf_len);
435 str_buf << buffer.data();
436 for (size_t i = 1; i < std::min(n, arr.size()); ++i) {
437 helper(arr[i], buffer.data(), buf_len);
438 str_buf << ' ' << buffer.data();
439 }
440 return str_buf.str();
441}
442
443inline static std::string ArrayToString(const std::vector<double>& arr, size_t n) {
444 if (arr.empty() || n == 0) {
445 return std::string("");
446 }
447 const size_t buf_len = 32;
448 std::vector<char> buffer(buf_len);
449 std::stringstream str_buf;
450 DoubleToStr(arr[0], buffer.data(), buf_len);
451 str_buf << buffer.data();
452 for (size_t i = 1; i < std::min(n, arr.size()); ++i) {
453 DoubleToStr(arr[i], buffer.data(), buf_len);
454 str_buf << ' ' << buffer.data();
455 }
456 return str_buf.str();
457}
458
// Primary template: parses an integral T from `str` with Atoi.
// Floating-point types are handled by the <T, true> specialization.
// Text after the parsed number is ignored; the result is 0 when no digits
// are present.
template<typename T, bool is_float>
struct __StringToTHelper {
  T operator()(const std::string& str) const {
    T ret = 0;
    Atoi(str.c_str(), &ret);
    return ret;
  }
};
467
468template<typename T>
469struct __StringToTHelper<T, true> {
470 T operator()(const std::string& str) const {
471 return static_cast<T>(std::stod(str));
472 }
473};
474
475template<typename T>
476inline static std::vector<T> StringToArray(const std::string& str, char delimiter) {
477 std::vector<std::string> strs = Split(str.c_str(), delimiter);
478 std::vector<T> ret;
479 ret.reserve(strs.size());
481 for (const auto& s : strs) {
482 ret.push_back(helper(s));
483 }
484 return ret;
485}
486
487template<typename T>
488inline static std::vector<T> StringToArray(const std::string& str, int n) {
489 if (n == 0) {
490 return std::vector<T>();
491 }
492 std::vector<std::string> strs = Split(str.c_str(), ' ');
493 CHECK(strs.size() == static_cast<size_t>(n));
494 std::vector<T> ret;
495 ret.reserve(strs.size());
496 __StringToTHelper<T, std::is_floating_point<T>::value> helper;
497 for (const auto& s : strs) {
498 ret.push_back(helper(s));
499 }
500 return ret;
501}
502
// Primary template: parses an integral T at `p` with Atoi and returns the
// position after the parsed text. Floating-point types are handled by the
// <T, true> specialization.
template<typename T, bool is_float>
struct __StringToTHelperFast {
  const char* operator()(const char*p, T* out) const {
    return Atoi(p, out);
  }
};
509
510template<typename T>
511struct __StringToTHelperFast<T, true> {
512 const char* operator()(const char*p, T* out) const {
513 double tmp = 0.0f;
514 auto ret = Atof(p, &tmp);
515 *out = static_cast<T>(tmp);
516 return ret;
517 }
518};
519
520template<typename T>
521inline static std::vector<T> StringToArrayFast(const std::string& str, int n) {
522 if (n == 0) {
523 return std::vector<T>();
524 }
525 auto p_str = str.c_str();
526 __StringToTHelperFast<T, std::is_floating_point<T>::value> helper;
527 std::vector<T> ret(n);
528 for (int i = 0; i < n; ++i) {
529 p_str = helper(p_str, &ret[i]);
530 }
531 return ret;
532}
533
// Concatenates all elements of `strs`, streamed with operator<< and separated
// by `delimiter`. Floating-point elements are printed with
// digits10 + 2 significant digits. Returns "" for an empty vector.
template<typename T>
inline static std::string Join(const std::vector<T>& strs, const char* delimiter) {
  if (strs.empty()) {
    return std::string("");
  }
  std::stringstream out;
  out << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  auto it = strs.begin();
  out << *it;
  for (++it; it != strs.end(); ++it) {
    out << delimiter << *it;
  }
  return out.str();
}
548
// Concatenates the elements of `strs` in the half-open range [start, end),
// separated by `delimiter`. Floating-point elements are printed with
// digits10 + 2 significant digits.
// Returns "" when `strs` is empty or the range is empty/inverted
// (end <= start). `end` is clamped to strs.size(); a `start` beyond the last
// element is clamped to the last element.
template<typename T>
inline static std::string Join(const std::vector<T>& strs, size_t start, size_t end, const char* delimiter) {
  // Both operands are unsigned, so compare directly instead of subtracting;
  // the empty check keeps size() - 1 below from wrapping around.
  if (strs.empty() || end <= start) {
    return std::string("");
  }
  start = std::min(start, static_cast<size_t>(strs.size()) - 1);
  end = std::min(end, static_cast<size_t>(strs.size()));
  std::stringstream str_buf;
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  str_buf << strs[start];
  for (size_t i = start + 1; i < end; ++i) {
    str_buf << delimiter;
    str_buf << strs[i];
  }
  return str_buf.str();
}
565
// Returns the smallest power of two that is >= x (1 for x <= 1).
// Returns 0 when no power of two representable in int64_t is large enough,
// i.e. for x > 2^62.
inline static int64_t Pow2RoundUp(int64_t x) {
  const int64_t kMaxPow2 = static_cast<int64_t>(1) << 62;
  int64_t t = 1;
  while (t < x) {
    // Stop before shifting into the sign bit (signed overflow).
    if (t == kMaxPow2) {
      return 0;
    }
    t <<= 1;
  }
  return t;
}
576
// In-place softmax over `*p_rec`: each element becomes
// exp(x_i - max) / sum_j exp(x_j - max). Subtracting the maximum keeps the
// exponentials from overflowing. An empty vector is left untouched.
inline static void Softmax(std::vector<double>* p_rec) {
  std::vector<double> &rec = *p_rec;
  if (rec.empty()) {
    return;  // nothing to normalize; avoids reading rec[0]
  }
  double wmax = rec[0];
  for (size_t i = 1; i < rec.size(); ++i) {
    wmax = std::max(rec[i], wmax);
  }
  double wsum = 0.0;
  for (size_t i = 0; i < rec.size(); ++i) {
    rec[i] = std::exp(rec[i] - wmax);
    wsum += rec[i];
  }
  for (size_t i = 0; i < rec.size(); ++i) {
    rec[i] /= wsum;
  }
}
596
// Softmax of input[0..len) written to output[0..len):
// output_i = exp(input_i - max) / sum_j exp(input_j - max). Subtracting the
// maximum keeps the exponentials from overflowing.
// Does nothing when len <= 0.
inline static void Softmax(const double* input, double* output, int len) {
  if (len <= 0) {
    return;  // avoids reading input[0] of an empty range
  }
  double wmax = input[0];
  for (int i = 1; i < len; ++i) {
    wmax = std::max(input[i], wmax);
  }
  double wsum = 0.0;
  for (int i = 0; i < len; ++i) {
    output[i] = std::exp(input[i] - wmax);
    wsum += output[i];
  }
  for (int i = 0; i < len; ++i) {
    output[i] /= wsum;
  }
}
611
// Builds a vector of non-owning const pointers, one per unique_ptr in
// `input`, in the same order. Ownership stays with `input`.
template<typename T>
std::vector<const T*> ConstPtrInVectorWrapper(const std::vector<std::unique_ptr<T>>& input) {
  std::vector<const T*> views;
  for (const auto& owner : input) {
    views.push_back(owner.get());
  }
  return views;
}
620
// Stable-sorts keys[start..] (ascending, or descending when `is_reverse`) and
// applies the same permutation to values[start..]. Elements before `start`
// are left untouched. `values` must have at least keys.size() elements.
template<typename T1, typename T2>
inline static void SortForPair(std::vector<T1>& keys, std::vector<T2>& values, size_t start, bool is_reverse = false) {
  std::vector<std::pair<T1, T2>> arr;
  for (size_t i = start; i < keys.size(); ++i) {
    arr.emplace_back(keys[i], values[i]);
  }
  if (!is_reverse) {
    std::stable_sort(arr.begin(), arr.end(), [](const std::pair<T1, T2>& a, const std::pair<T1, T2>& b) {
      return a.first < b.first;
    });
  } else {
    std::stable_sort(arr.begin(), arr.end(), [](const std::pair<T1, T2>& a, const std::pair<T1, T2>& b) {
      return a.first > b.first;
    });
  }
  // arr is indexed from 0 while the destination starts at `start`, so the two
  // indices must advance separately.
  for (size_t i = start, k = 0; k < arr.size(); ++i, ++k) {
    keys[i] = arr[k].first;
    values[i] = arr[k].second;
  }
}
641
// Returns, for each inner vector of `data`, a pointer to its element storage
// (data[i].data()). The pointers are non-owning views into `data`.
template <typename T>
inline static std::vector<T*> Vector2Ptr(std::vector<std::vector<T>>& data) {
  std::vector<T*> row_ptrs;
  row_ptrs.reserve(data.size());
  for (auto& row : data) {
    row_ptrs.push_back(row.data());
  }
  return row_ptrs;
}
650
// Returns the size of each inner vector of `data`, converted to int.
template <typename T>
inline static std::vector<int> VectorSize(const std::vector<std::vector<T>>& data) {
  std::vector<int> sizes;
  sizes.reserve(data.size());
  for (const auto& row : data) {
    sizes.push_back(static_cast<int>(row.size()));
  }
  return sizes;
}
659
// Clamps `x` to the range [-1e300, 1e300]. Values inside the range (and NaN,
// which fails both comparisons) are returned unchanged.
inline static double AvoidInf(double x) {
  if (x >= 1e300) {
    return 1e300;
  }
  if (x <= -1e300) {
    return -1e300;
  }
  return x;
}
669
// Clamps `x` to roughly [-1e38, 1e38]: values whose double promotion reaches
// the double literal +/-1e38 are replaced with +/-1e38f. Other values (and
// NaN, which fails both comparisons) are returned unchanged.
inline static float AvoidInf(float x) {
  if (x >= 1e38) {
    return 1e38f;
  }
  if (x <= -1e38) {
    return -1e38f;
  }
  return x;
}
679
// Type-deduction helper: returns a null pointer whose static type is
// "pointer to the iterator's value_type". Only the type is meaningful; it is
// used to pass the element type to another template as a tag argument.
template<typename _Iter> inline
static typename std::iterator_traits<_Iter>::value_type* IteratorValType(_Iter) {
  return nullptr;
}
684
685template<typename _RanIt, typename _Pr, typename _VTRanIt> inline
686static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred, _VTRanIt*) {
687 size_t len = _Last - _First;
688 const size_t kMinInnerLen = 1024;
689 int num_threads = 1;
690 #pragma omp parallel
691 #pragma omp master
692 {
693 num_threads = omp_get_num_threads();
694 }
695 if (len <= kMinInnerLen || num_threads <= 1) {
696 std::sort(_First, _Last, _Pred);
697 return;
698 }
699 size_t inner_size = (len + num_threads - 1) / num_threads;
700 inner_size = std::max(inner_size, kMinInnerLen);
701 num_threads = static_cast<int>((len + inner_size - 1) / inner_size);
702 #pragma omp parallel for schedule(static, 1)
703 for (int i = 0; i < num_threads; ++i) {
704 size_t left = inner_size*i;
705 size_t right = left + inner_size;
706 right = std::min(right, len);
707 if (right > left) {
708 std::sort(_First + left, _First + right, _Pred);
709 }
710 }
711 // Buffer for merge.
712 std::vector<_VTRanIt> temp_buf(len);
713 _RanIt buf = temp_buf.begin();
714 size_t s = inner_size;
715 // Recursive merge
716 while (s < len) {
717 int loop_size = static_cast<int>((len + s * 2 - 1) / (s * 2));
718 #pragma omp parallel for schedule(static, 1)
719 for (int i = 0; i < loop_size; ++i) {
720 size_t left = i * 2 * s;
721 size_t mid = left + s;
722 size_t right = mid + s;
723 right = std::min(len, right);
724 if (mid >= right) { continue; }
725 std::copy(_First + left, _First + mid, buf + left);
726 std::merge(buf + left, buf + mid, _First + mid, _First + right, _First + left, _Pred);
727 }
728 s *= 2;
729 }
730}
731
// Convenience overload: deduces the element type from the iterator and
// forwards to the four-argument ParallelSort.
template<typename _RanIt, typename _Pr> inline
static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred) {
  auto* value_type_tag = IteratorValType(_First);
  ParallelSort(_First, _Last, _Pred, value_type_tag);
}
736
737// Check that all y[] are in interval [ymin, ymax] (end points included); throws error if not
738template <typename T>
739inline static void CheckElementsIntervalClosed(const T *y, T ymin, T ymax, int ny, const char *callername) {
740 auto fatal_msg = [&y, &ymin, &ymax, &callername](int i) {
741 std::ostringstream os;
742 os << "[%s]: does not tolerate element [#%i = " << y[i] << "] outside [" << ymin << ", " << ymax << "]";
743 Log::Fatal(os.str().c_str(), callername, i);
744 };
745 for (int i = 1; i < ny; i += 2) {
746 if (y[i - 1] < y[i]) {
747 if (y[i - 1] < ymin) {
748 fatal_msg(i - 1);
749 } else if (y[i] > ymax) {
750 fatal_msg(i);
751 }
752 } else {
753 if (y[i - 1] > ymax) {
754 fatal_msg(i - 1);
755 } else if (y[i] < ymin) {
756 fatal_msg(i);
757 }
758 }
759 }
760 if (ny & 1) { // odd
761 if (y[ny - 1] < ymin || y[ny - 1] > ymax) {
762 fatal_msg(ny - 1);
763 }
764 }
765}
766
// One-pass scan over w[0..nw): computes the minimum, maximum and sum of the
// elements, processing them in pairs. Each of `mi`, `ma`, `su` may be nullptr
// to skip that output. The sum is accumulated in T1 and converted to T2 at
// the end. Requires nw >= 1 (w[0] is always read, and w[1] when nw is even).
template <typename T1, typename T2>
inline static void ObtainMinMaxSum(const T1 *w, int nw, T1 *mi, T1 *ma, T2 *su) {
  T1 minw;
  T1 maxw;
  T1 sumw;
  int i;
  if (nw & 1) {
    // Odd count: seed from the first element, pairs start at (1, 2).
    minw = maxw = sumw = w[0];
    i = 2;
  } else {
    // Even count: seed from the first pair, pairs continue at (2, 3).
    const bool first_is_smaller = w[0] < w[1];
    minw = first_is_smaller ? w[0] : w[1];
    maxw = first_is_smaller ? w[1] : w[0];
    sumw = w[0] + w[1];
    i = 3;
  }
  for (; i < nw; i += 2) {
    const T1& a = w[i - 1];
    const T1& b = w[i];
    if (a < b) {
      minw = std::min(minw, a);
      maxw = std::max(maxw, b);
    } else {
      minw = std::min(minw, b);
      maxw = std::max(maxw, a);
    }
    sumw += a + b;
  }
  if (mi != nullptr) {
    *mi = minw;
  }
  if (ma != nullptr) {
    *ma = maxw;
  }
  if (su != nullptr) {
    *su = static_cast<T2>(sumw);
  }
}
811
// Builds a bitset stored as 32-bit words: for every vals[i], bit (vals[i] % 32)
// of word (vals[i] / 32) is set. The vector grows on demand to hold the
// largest value seen. Values are expected to be non-negative.
template<typename T>
inline static std::vector<uint32_t> ConstructBitset(const T* vals, int n) {
  std::vector<uint32_t> ret;
  for (int i = 0; i < n; ++i) {
    int i1 = vals[i] / 32;
    int i2 = vals[i] % 32;
    if (static_cast<int>(ret.size()) < i1 + 1) {
      ret.resize(i1 + 1, 0);
    }
    // Unsigned shift: shifting a signed 1 into bit 31 would overflow int.
    ret[i1] |= (static_cast<uint32_t>(1) << i2);
  }
  return ret;
}
825
// Tests whether bit `pos` is set in a bitset of `n` 32-bit words.
// Positions whose word index is >= n are reported as not set.
template<typename T>
inline static bool FindInBitset(const uint32_t* bits, int n, T pos) {
  const int word_index = pos / 32;
  if (word_index >= n) {
    return false;
  }
  const int bit_index = pos % 32;
  return (bits[word_index] >> bit_index) & 1;
}
835
// Returns true when `b` does not exceed the next representable double above
// `a`, i.e. b <= nextafter(a, +inf).
inline static bool CheckDoubleEqualOrdered(double a, double b) {
  return b <= std::nextafter(a, INFINITY);
}
840
// Returns the next representable double greater than `a`
// (nextafter toward +infinity).
inline static double GetDoubleUpperBound(double a) {
  const double next_up = std::nextafter(a, INFINITY);
  return next_up;
}
844
// Returns the length of the first line of `str`: the number of characters
// before the first '\n', '\r' or the terminating NUL.
inline static size_t GetLine(const char* str) {
  const char* cursor = str;
  for (; *cursor != '\0' && *cursor != '\n' && *cursor != '\r'; ++cursor) {
  }
  return cursor - str;
}
852
// Skips a single line ending at `str`: an optional '\r' followed by an
// optional '\n'. Returns the position after it.
inline static const char* SkipNewLine(const char* str) {
  const char* cursor = str;
  if (*cursor == '\r') {
    ++cursor;
  }
  if (*cursor == '\n') {
    ++cursor;
  }
  return cursor;
}
862
// Returns +1 when x > 0, -1 when x < 0 and 0 otherwise.
template <typename T>
static int Sign(T x) {
  const int is_positive = x > T(0);
  const int is_negative = x < T(0);
  return is_positive - is_negative;
}
867
// Natural logarithm that never receives a non-positive argument:
// returns std::log(x) for x > 0 and -INFINITY otherwise.
template <typename T>
static T SafeLog(T x) {
  if (!(x > 0)) {
    return -INFINITY;
  }
  return std::log(x);
}
876
877} // namespace Common
878
879} // namespace LightGBM
880
881#endif // LIGHTGBM_UTILS_COMMON_FUN_H_
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10