Medial Code Documentation
Loading...
Searching...
No Matches
array_args.h
1#ifndef LIGHTGBM_UTILS_ARRAY_AGRS_H_
2#define LIGHTGBM_UTILS_ARRAY_AGRS_H_
3
4#include <vector>
5#include <algorithm>
6#include <LightGBM/utils/openmp_wrapper.h>
7
8namespace LightGBM {
9
13template<typename VAL_T>
14class ArrayArgs {
15public:
16 inline static size_t ArgMaxMT(const std::vector<VAL_T>& array) {
17 int num_threads = 1;
18#pragma omp parallel
19#pragma omp master
20 {
21 num_threads = omp_get_num_threads();
22 }
23 int step = std::max(1, (static_cast<int>(array.size()) + num_threads - 1) / num_threads);
24 std::vector<size_t> arg_maxs(num_threads, 0);
25 #pragma omp parallel for schedule(static, 1)
26 for (int i = 0; i < num_threads; ++i) {
27 size_t start = step * i;
28 if (start >= array.size()) { continue; }
29 size_t end = std::min(array.size(), start + step);
30 size_t arg_max = start;
31 for (size_t j = start + 1; j < end; ++j) {
32 if (array[j] > array[arg_max]) {
33 arg_max = j;
34 }
35 }
36 arg_maxs[i] = arg_max;
37 }
38 size_t ret = arg_maxs[0];
39 for (int i = 1; i < num_threads; ++i) {
40 if (array[arg_maxs[i]] > array[ret]) {
41 ret = arg_maxs[i];
42 }
43 }
44 return ret;
45 }
46 inline static size_t ArgMax(const std::vector<VAL_T>& array) {
47 if (array.empty()) {
48 return 0;
49 }
50 if (array.size() > 1024) {
51 return ArgMaxMT(array);
52 } else {
53 size_t arg_max = 0;
54 for (size_t i = 1; i < array.size(); ++i) {
55 if (array[i] > array[arg_max]) {
56 arg_max = i;
57 }
58 }
59 return arg_max;
60 }
61 }
62
63 inline static size_t ArgMin(const std::vector<VAL_T>& array) {
64 if (array.empty()) {
65 return 0;
66 }
67 size_t arg_min = 0;
68 for (size_t i = 1; i < array.size(); ++i) {
69 if (array[i] < array[arg_min]) {
70 arg_min = i;
71 }
72 }
73 return arg_min;
74 }
75
76 inline static size_t ArgMax(const VAL_T* array, size_t n) {
77 if (n <= 0) {
78 return 0;
79 }
80 size_t arg_max = 0;
81 for (size_t i = 1; i < n; ++i) {
82 if (array[i] > array[arg_max]) {
83 arg_max = i;
84 }
85 }
86 return arg_max;
87 }
88
89 inline static size_t ArgMin(const VAL_T* array, size_t n) {
90 if (n <= 0) {
91 return 0;
92 }
93 size_t arg_min = 0;
94 for (size_t i = 1; i < n; ++i) {
95 if (array[i] < array[arg_min]) {
96 arg_min = i;
97 }
98 }
99 return arg_min;
100 }
101
102 inline static void Partition(std::vector<VAL_T>* arr, int start, int end, int* l, int* r) {
103 int i = start - 1;
104 int j = end - 1;
105 int p = i;
106 int q = j;
107 if (start >= end) {
108 return;
109 }
110 std::vector<VAL_T>& ref = *arr;
111 VAL_T v = ref[end - 1];
112 for (;;) {
113 while (ref[++i] > v);
114 while (v > ref[--j]) { if (j == start) { break; } }
115 if (i >= j) { break; }
116 std::swap(ref[i], ref[j]);
117 if (ref[i] == v) { p++; std::swap(ref[p], ref[i]); }
118 if (v == ref[j]) { q--; std::swap(ref[j], ref[q]); }
119 }
120 std::swap(ref[i], ref[end - 1]);
121 j = i - 1;
122 i = i + 1;
123 for (int k = start; k <= p; k++, j--) { std::swap(ref[k], ref[j]); }
124 for (int k = end - 2; k >= q; k--, i++) { std::swap(ref[i], ref[k]); }
125 *l = j;
126 *r = i;
127 }
128
129 // Note: k refer to index here. e.g. k=0 means get the max number.
130 inline static int ArgMaxAtK(std::vector<VAL_T>* arr, int start, int end, int k) {
131 if (start >= end - 1) {
132 return start;
133 }
134 int l = start;
135 int r = end - 1;
136 Partition(arr, start, end, &l, &r);
137 // if find or all elements are the same.
138 if ((k > l && k < r) || (l == start - 1 && r == end - 1)) {
139 return k;
140 } else if (k <= l) {
141 return ArgMaxAtK(arr, start, l + 1, k);
142 } else {
143 return ArgMaxAtK(arr, r, end, k);
144 }
145 }
146
147 // Note: k is 1-based here. e.g. k=3 means get the top-3 numbers.
148 inline static void MaxK(const std::vector<VAL_T>& array, int k, std::vector<VAL_T>* out) {
149 out->clear();
150 if (k <= 0) {
151 return;
152 }
153 for (auto val : array) {
154 out->push_back(val);
155 }
156 if (static_cast<size_t>(k) >= array.size()) {
157 return;
158 }
159 ArgMaxAtK(out, 0, static_cast<int>(out->size()), k - 1);
160 out->erase(out->begin() + k, out->end());
161 }
162
163 inline static void Assign(std::vector<VAL_T>* array, VAL_T t, size_t n) {
164 array->resize(n);
165 for (size_t i = 0; i < array->size(); ++i) {
166 (*array)[i] = t;
167 }
168 }
169
170 inline static bool CheckAllZero(const std::vector<VAL_T>& array) {
171 for (size_t i = 0; i < array.size(); ++i) {
172 if (array[i] != VAL_T(0)) {
173 return false;
174 }
175 }
176 return true;
177 }
178
179 inline static bool CheckAll(const std::vector<VAL_T>& array, VAL_T t) {
180 for (size_t i = 0; i < array.size(); ++i) {
181 if (array[i] != t) {
182 return false;
183 }
184 }
185 return true;
186 }
187};
188
189} // namespace LightGBM
190
191#endif // LightGBM_UTILS_ARRAY_AGRS_H_
192
Contains some operations for an array, e.g. ArgMax, TopK.
Definition array_args.h:14
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10
NLOHMANN_BASIC_JSON_TPL_DECLARATION void swap(nlohmann::NLOHMANN_BASIC_JSON_TPL &j1, nlohmann::NLOHMANN_BASIC_JSON_TPL &j2) noexcept(//NOLINT(readability-inconsistent-declaration-parameter-name, cert-dcl58-cpp) is_nothrow_move_constructible< nlohmann::NLOHMANN_BASIC_JSON_TPL >::value &&//NOLINT(misc-redundant-expression) is_nothrow_move_assignable< nlohmann::NLOHMANN_BASIC_JSON_TPL >::value)
exchanges the values of two JSON objects
Definition json.hpp:24418