Medial Code Documentation
quantile.h
1
7#ifndef XGBOOST_COMMON_QUANTILE_H_
8#define XGBOOST_COMMON_QUANTILE_H_
9
10#include <xgboost/data.h>
11#include <xgboost/logging.h>
12
13#include <algorithm>
14#include <cmath>
15#include <cstring>
16#include <iostream>
17#include <set>
18#include <vector>
19
20#include "categorical.h"
21#include "common.h"
22#include "error_msg.h" // GroupWeight
23#include "optional_weight.h" // OptionalWeights
24#include "threading_utils.h"
25#include "timer.h"
26
27namespace xgboost::common {
33template<typename DType, typename RType>
34struct WQSummary {
36 struct Entry {
38 RType rmin;
40 RType rmax;
42 RType wmin;
44 DType value;
45 // constructor
46 XGBOOST_DEVICE Entry() {} // NOLINT
47 // constructor
48 XGBOOST_DEVICE Entry(RType rmin, RType rmax, RType wmin, DType value)
49 : rmin(rmin), rmax(rmax), wmin(wmin), value(value) {}
54 inline void CheckValid(RType eps = 0) const {
55 CHECK(rmin >= 0 && rmax >= 0 && wmin >= 0) << "nonneg constraint";
56 CHECK(rmax - rmin - wmin > -eps) << "relation constraint: min/max";
57 }
59 XGBOOST_DEVICE inline RType RMinNext() const {
60 return rmin + wmin;
61 }
63 XGBOOST_DEVICE inline RType RMaxPrev() const {
64 return rmax - wmin;
65 }
66
67 friend std::ostream& operator<<(std::ostream& os, Entry const& e) {
68 os << "rmin: " << e.rmin << ", "
69 << "rmax: " << e.rmax << ", "
70 << "wmin: " << e.wmin << ", "
71 << "value: " << e.value;
72 return os;
73 }
74 };
76 struct Queue {
77 // entry in the queue
78 struct QEntry {
79 // value of the instance
80 DType value;
81 // weight of instance
82 RType weight;
83 // default constructor
84 QEntry() = default;
85 // constructor
86 QEntry(DType value, RType weight)
87 : value(value), weight(weight) {}
88 // comparator on value
89 inline bool operator<(const QEntry &b) const {
90 return value < b.value;
91 }
92 };
93 // the input queue
94 std::vector<QEntry> queue;
95 // end of the queue
96 size_t qtail;
97 // push data to the queue
98 inline void Push(DType x, RType w) {
99 if (qtail == 0 || queue[qtail - 1].value != x) {
100 queue[qtail++] = QEntry(x, w);
101 } else {
102 queue[qtail - 1].weight += w;
103 }
104 }
105 inline void MakeSummary(WQSummary *out) {
106 std::sort(queue.begin(), queue.begin() + qtail);
107 out->size = 0;
108 // start update sketch
109 RType wsum = 0;
110 // construct entries with unique values, accumulating their weights
111 for (size_t i = 0; i < qtail;) {
112 size_t j = i + 1;
113 RType w = queue[i].weight;
114 while (j < qtail && queue[j].value == queue[i].value) {
115 w += queue[j].weight; ++j;
116 }
117 out->data[out->size++] = Entry(wsum, wsum + w, w, queue[i].value);
118 wsum += w; i = j;
119 }
120 }
121 };
123 Entry *data;
125 size_t size;
126 // constructor
127 WQSummary(Entry *data, size_t size)
128 : data(data), size(size) {}
132 inline RType MaxError() const {
133 RType res = data[0].rmax - data[0].rmin - data[0].wmin;
134 for (size_t i = 1; i < size; ++i) {
135 res = std::max(data[i].RMaxPrev() - data[i - 1].RMinNext(), res);
136 res = std::max(data[i].rmax - data[i].rmin - data[i].wmin, res);
137 }
138 return res;
139 }
145 inline Entry Query(DType qvalue, size_t &istart) const { // NOLINT(*)
146 while (istart < size && qvalue > data[istart].value) {
147 ++istart;
148 }
149 if (istart == size) {
150 RType rmax = data[size - 1].rmax;
151 return Entry(rmax, rmax, 0.0f, qvalue);
152 }
153 if (qvalue == data[istart].value) {
154 return data[istart];
155 } else {
156 if (istart == 0) {
157 return Entry(0.0f, 0.0f, 0.0f, qvalue);
158 } else {
159 return Entry(data[istart - 1].RMinNext(),
160 data[istart].RMaxPrev(),
161 0.0f, qvalue);
162 }
163 }
164 }
166 inline RType MaxRank() const {
167 return data[size - 1].rmax;
168 }
173 inline void CopyFrom(const WQSummary &src) {
174 if (!src.data) {
175 CHECK_EQ(src.size, 0);
176 size = 0;
177 return;
178 }
179 if (!data) {
180 CHECK_EQ(this->size, 0);
181 CHECK_EQ(src.size, 0);
182 return;
183 }
184 size = src.size;
185 std::memcpy(data, src.data, sizeof(Entry) * size);
186 }
187 inline void MakeFromSorted(const Entry* entries, size_t n) {
188 size = 0;
189 for (size_t i = 0; i < n;) {
190 size_t j = i + 1;
191 // ignore repeated values
192 for (; j < n && entries[j].value == entries[i].value; ++j) {}
193 data[size++] = Entry(entries[i].rmin, entries[i].rmax, entries[i].wmin,
194 entries[i].value);
195 i = j;
196 }
197 }
204 inline void CheckValid(RType eps) const {
205 for (size_t i = 0; i < size; ++i) {
206 data[i].CheckValid(eps);
207 if (i != 0) {
208 CHECK(data[i].rmin >= data[i - 1].rmin + data[i - 1].wmin) << "rmin range constraint";
209 CHECK(data[i].rmax >= data[i - 1].rmax + data[i].wmin) << "rmax range constraint";
210 }
211 }
212 }
213
220 void SetPrune(const WQSummary &src, size_t maxsize) {
221 if (src.size <= maxsize) {
222 this->CopyFrom(src); return;
223 }
224 const RType begin = src.data[0].rmax;
225 const RType range = src.data[src.size - 1].rmin - src.data[0].rmax;
226 const size_t n = maxsize - 1;
227 data[0] = src.data[0];
228 this->size = 1;
229 // lastidx is used to avoid duplicated records
230 size_t i = 1, lastidx = 0;
231 for (size_t k = 1; k < n; ++k) {
232 RType dx2 = 2 * ((k * range) / n + begin);
233 // find first i such that d < (rmax[i+1] + rmin[i+1]) / 2
234 while (i < src.size - 1
235 && dx2 >= src.data[i + 1].rmax + src.data[i + 1].rmin) ++i;
236 if (i == src.size - 1) break;
237 if (dx2 < src.data[i].RMinNext() + src.data[i + 1].RMaxPrev()) {
238 if (i != lastidx) {
239 data[size++] = src.data[i]; lastidx = i;
240 }
241 } else {
242 if (i + 1 != lastidx) {
243 data[size++] = src.data[i + 1]; lastidx = i + 1;
244 }
245 }
246 }
247 if (lastidx != src.size - 1) {
248 data[size++] = src.data[src.size - 1];
249 }
250 }
256 inline void SetCombine(const WQSummary &sa,
257 const WQSummary &sb) {
258 if (sa.size == 0) {
259 this->CopyFrom(sb); return;
260 }
261 if (sb.size == 0) {
262 this->CopyFrom(sa); return;
263 }
264 CHECK(sa.size > 0 && sb.size > 0);
265 const Entry *a = sa.data, *a_end = sa.data + sa.size;
266 const Entry *b = sb.data, *b_end = sb.data + sb.size;
267 // extended rmin value
268 RType aprev_rmin = 0, bprev_rmin = 0;
269 Entry *dst = this->data;
270 while (a != a_end && b != b_end) {
271 // duplicated value entry
272 if (a->value == b->value) {
273 *dst = Entry(a->rmin + b->rmin,
274 a->rmax + b->rmax,
275 a->wmin + b->wmin, a->value);
276 aprev_rmin = a->RMinNext();
277 bprev_rmin = b->RMinNext();
278 ++dst; ++a; ++b;
279 } else if (a->value < b->value) {
280 *dst = Entry(a->rmin + bprev_rmin,
281 a->rmax + b->RMaxPrev(),
282 a->wmin, a->value);
283 aprev_rmin = a->RMinNext();
284 ++dst; ++a;
285 } else {
286 *dst = Entry(b->rmin + aprev_rmin,
287 b->rmax + a->RMaxPrev(),
288 b->wmin, b->value);
289 bprev_rmin = b->RMinNext();
290 ++dst; ++b;
291 }
292 }
293 if (a != a_end) {
294 RType brmax = (b_end - 1)->rmax;
295 do {
296 *dst = Entry(a->rmin + bprev_rmin, a->rmax + brmax, a->wmin, a->value);
297 ++dst; ++a;
298 } while (a != a_end);
299 }
300 if (b != b_end) {
301 RType armax = (a_end - 1)->rmax;
302 do {
303 *dst = Entry(b->rmin + aprev_rmin, b->rmax + armax, b->wmin, b->value);
304 ++dst; ++b;
305 } while (b != b_end);
306 }
307 this->size = dst - data;
308 const RType tol = 10;
309 RType err_mingap, err_maxgap, err_wgap;
310 this->FixError(&err_mingap, &err_maxgap, &err_wgap);
311 if (err_mingap > tol || err_maxgap > tol || err_wgap > tol) {
312 LOG(INFO) << "mingap=" << err_mingap
313 << ", maxgap=" << err_maxgap
314 << ", wgap=" << err_wgap;
315 }
316 CHECK(size <= sa.size + sb.size) << "bug in combine";
317 }
318 // helper function to print the current content of sketch
319 inline void Print() const {
320 for (size_t i = 0; i < this->size; ++i) {
321 LOG(CONSOLE) << "[" << i << "] rmin=" << data[i].rmin
322 << ", rmax=" << data[i].rmax
323 << ", wmin=" << data[i].wmin
324 << ", v=" << data[i].value;
325 }
326 }
327 // try to fix rounding error
328 // and re-establish invariance
329 inline void FixError(RType *err_mingap,
330 RType *err_maxgap,
331 RType *err_wgap) const {
332 *err_mingap = 0;
333 *err_maxgap = 0;
334 *err_wgap = 0;
335 RType prev_rmin = 0, prev_rmax = 0;
336 for (size_t i = 0; i < this->size; ++i) {
337 if (data[i].rmin < prev_rmin) {
338 data[i].rmin = prev_rmin;
339 *err_mingap = std::max(*err_mingap, prev_rmin - data[i].rmin);
340 } else {
341 prev_rmin = data[i].rmin;
342 }
343 if (data[i].rmax < prev_rmax) {
344 data[i].rmax = prev_rmax;
345 *err_maxgap = std::max(*err_maxgap, prev_rmax - data[i].rmax);
346 }
347 RType rmin_next = data[i].RMinNext();
348 if (data[i].rmax < rmin_next) {
349 data[i].rmax = rmin_next;
350 *err_wgap = std::max(*err_wgap, data[i].rmax - rmin_next);
351 }
352 prev_rmax = data[i].rmax;
353 }
354 }
355};
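Each entry above carries a lower rank bound (rmin), an upper rank bound (rmax) and the weight of ties at its value (wmin); Queue::MakeSummary derives those bounds from raw weighted samples, and Query reads them back for a probe value. The following is a minimal usage sketch, not part of this header: it assumes the XGBoost source tree is on the include path, and the buffer management shown is the caller's responsibility.

#include <cstddef>
#include <cstdio>
#include <vector>

#include "quantile.h"  // internal XGBoost header; path shown for illustration only

int main() {
  using Summary = xgboost::common::WQSummary<float, float>;

  // The input queue needs pre-allocated storage; Push merges consecutive
  // duplicate values by accumulating their weights.
  Summary::Queue q;
  q.queue.resize(4);
  q.qtail = 0;
  q.Push(1.0f, 1.0f);
  q.Push(3.0f, 2.0f);
  q.Push(3.0f, 1.0f);  // folded into the previous 3.0f entry
  q.Push(7.0f, 1.0f);

  // MakeSummary sorts the queue and writes rank bounds into a caller-owned buffer.
  std::vector<Summary::Entry> buf(q.qtail);
  Summary s(buf.data(), 0);
  q.MakeSummary(&s);

  // Query returns the rank bounds for a value; istart lets a caller scan a
  // monotonically increasing sequence of probes without restarting the search.
  std::size_t istart = 0;
  auto e = s.Query(3.0f, istart);
  std::printf("value=%g rmin=%g rmax=%g wmin=%g\n", e.value, e.rmin, e.rmax, e.wmin);
  return 0;
}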
356
358template<typename DType, typename RType>
359struct WXQSummary : public WQSummary<DType, RType> {
360 // redefine entry type
361 using Entry = typename WQSummary<DType, RType>::Entry;
362 // constructor
363 WXQSummary(Entry *data, size_t size)
364 : WQSummary<DType, RType>(data, size) {}
365 // check if the block is large chunk
366 inline static bool CheckLarge(const Entry &e, RType chunk) {
367 return e.RMinNext() > e.RMaxPrev() + chunk;
368 }
369 // set prune
370 inline void SetPrune(const WQSummary<DType, RType> &src, size_t maxsize) {
371 if (src.size <= maxsize) {
372 this->CopyFrom(src); return;
373 }
374 RType begin = src.data[0].rmax;
375 // n is the number of points excluding the min/max points
376 size_t n = maxsize - 2, nbig = 0;
377 // this is the range of the data excluding the min/max points
378 RType range = src.data[src.size - 1].rmin - begin;
379 // prune off zero weights
380 if (range == 0.0f || maxsize <= 2) {
381 // special case, contain only two effective data pts
382 this->data[0] = src.data[0];
383 this->data[1] = src.data[src.size - 1];
384 this->size = 2;
385 return;
386 } else {
387 range = std::max(range, static_cast<RType>(1e-3f));
388 }
389 // Get a big enough chunk size, bigger than range / n
390 // (multiply by 2 is a safe factor)
391 const RType chunk = 2 * range / n;
392 // minimized range
393 RType mrange = 0;
394 {
395 // first scan, grab all the big chunk
396 // moving block index, exclude the two ends.
397 size_t bid = 0;
398 for (size_t i = 1; i < src.size - 1; ++i) {
399 // detect big chunk data point in the middle
400 // always save these data points.
401 if (CheckLarge(src.data[i], chunk)) {
402 if (bid != i - 1) {
403 // accumulate the range of the rest points
404 mrange += src.data[i].RMaxPrev() - src.data[bid].RMinNext();
405 }
406 bid = i; ++nbig;
407 }
408 }
409 if (bid != src.size - 2) {
410 mrange += src.data[src.size-1].RMaxPrev() - src.data[bid].RMinNext();
411 }
412 }
413 // assert: there cannot be more than n big data points
414 if (nbig >= n) {
415 // see what was the case
416 LOG(INFO) << " check quantile stats, nbig=" << nbig << ", n=" << n;
417 LOG(INFO) << " srcsize=" << src.size << ", maxsize=" << maxsize
418 << ", range=" << range << ", chunk=" << chunk;
419 src.Print();
420 CHECK(nbig < n) << "quantile: too many large chunk";
421 }
422 this->data[0] = src.data[0];
423 this->size = 1;
424 // The counter on the rest of points, to be selected equally from small chunks.
425 n = n - nbig;
426 // find the rest of point
427 size_t bid = 0, k = 1, lastidx = 0;
428 for (size_t end = 1; end < src.size; ++end) {
429 if (end == src.size - 1 || CheckLarge(src.data[end], chunk)) {
430 if (bid != end - 1) {
431 size_t i = bid;
432 RType maxdx2 = src.data[end].RMaxPrev() * 2;
433 for (; k < n; ++k) {
434 RType dx2 = 2 * ((k * mrange) / n + begin);
435 if (dx2 >= maxdx2) break;
436 while (i < end &&
437 dx2 >= src.data[i + 1].rmax + src.data[i + 1].rmin) ++i;
438 if (i == end) break;
439 if (dx2 < src.data[i].RMinNext() + src.data[i + 1].RMaxPrev()) {
440 if (i != lastidx) {
441 this->data[this->size++] = src.data[i]; lastidx = i;
442 }
443 } else {
444 if (i + 1 != lastidx) {
445 this->data[this->size++] = src.data[i + 1]; lastidx = i + 1;
446 }
447 }
448 }
449 }
450 if (lastidx != end) {
451 this->data[this->size++] = src.data[end];
452 lastidx = end;
453 }
454 bid = end;
455 // shift base by the gap
456 begin += src.data[bid].RMinNext() - src.data[bid].RMaxPrev();
457 }
458 }
459 }
460};
468template<typename DType, typename RType, class TSummary>
469class QuantileSketchTemplate {
470 public:
471 static float constexpr kFactor = 8.0;
472
473 public:
475 using Summary = TSummary;
477 using Entry = typename Summary::Entry;
479 struct SummaryContainer : public Summary {
480 std::vector<Entry> space;
481 SummaryContainer(const SummaryContainer &src) : Summary(nullptr, src.size) {
482 this->space = src.space;
483 this->data = dmlc::BeginPtr(this->space);
484 }
485 SummaryContainer() : Summary(nullptr, 0) {
486 }
488 inline void Reserve(size_t size) {
489 if (size > space.size()) {
490 space.resize(size);
491 this->data = dmlc::BeginPtr(space);
492 }
493 }
500 inline void Reduce(const Summary &src, size_t max_nbyte) {
501 this->Reserve((max_nbyte - sizeof(this->size)) / sizeof(Entry));
502 SummaryContainer temp;
503 temp.Reserve(this->size + src.size);
504 temp.SetCombine(*this, src);
505 this->SetPrune(temp, space.size());
506 }
508 inline static size_t CalcMemCost(size_t nentry) {
509 return sizeof(size_t) + sizeof(Entry) * nentry;
510 }
512 template<typename TStream>
513 inline void Save(TStream &fo) const { // NOLINT(*)
514 fo.Write(&(this->size), sizeof(this->size));
515 if (this->size != 0) {
516 fo.Write(this->data, this->size * sizeof(Entry));
517 }
518 }
520 template<typename TStream>
521 inline void Load(TStream &fi) { // NOLINT(*)
522 CHECK_EQ(fi.Read(&this->size, sizeof(this->size)), sizeof(this->size));
523 this->Reserve(this->size);
524 if (this->size != 0) {
525 CHECK_EQ(fi.Read(this->data, this->size * sizeof(Entry)),
526 this->size * sizeof(Entry));
527 }
528 }
529 };
535 inline void Init(size_t maxn, double eps) {
536 LimitSizeLevel(maxn, eps, &nlevel, &limit_size);
537 // lazy reserve the space, if there is only one value, no need to allocate space
538 inqueue.queue.resize(1);
539 inqueue.qtail = 0;
540 data.clear();
541 level.clear();
542 }
543
544 inline static void LimitSizeLevel
545 (size_t maxn, double eps, size_t* out_nlevel, size_t* out_limit_size) {
546 size_t& nlevel = *out_nlevel;
547 size_t& limit_size = *out_limit_size;
548 nlevel = 1;
549 while (true) {
550 limit_size = static_cast<size_t>(ceil(nlevel / eps)) + 1;
551 limit_size = std::min(maxn, limit_size);
552 size_t n = (1ULL << nlevel);
553 if (n * limit_size >= maxn) break;
554 ++nlevel;
555 }
556 // check invariant
557 size_t n = (1ULL << nlevel);
558 CHECK(n * limit_size >= maxn) << "invalid init parameter";
559 CHECK(nlevel <= std::max(static_cast<size_t>(1), static_cast<size_t>(limit_size * eps)))
560 << "invalid init parameter";
561 }
562
568 inline void Push(DType x, RType w = 1) {
569 if (w == static_cast<RType>(0)) return;
570 if (inqueue.qtail == inqueue.queue.size() && inqueue.queue[inqueue.qtail - 1].value != x) {
571 // jump from lazy one value to limit_size * 2
572 if (inqueue.queue.size() == 1) {
573 inqueue.queue.resize(limit_size * 2);
574 } else {
575 temp.Reserve(limit_size * 2);
576 inqueue.MakeSummary(&temp);
577 // cleanup queue
578 inqueue.qtail = 0;
579 this->PushTemp();
580 }
581 }
582 inqueue.Push(x, w);
583 }
584
585 inline void PushSummary(const Summary& summary) {
586 temp.Reserve(limit_size * 2);
587 temp.SetPrune(summary, limit_size * 2);
588 PushTemp();
589 }
590
592 inline void PushTemp() {
593 temp.Reserve(limit_size * 2);
594 for (size_t l = 1; true; ++l) {
595 this->InitLevel(l + 1);
596 // check if level l is empty
597 if (level[l].size == 0) {
598 level[l].SetPrune(temp, limit_size);
599 break;
600 } else {
601 // level 0 is actually temp space
602 level[0].SetPrune(temp, limit_size);
603 temp.SetCombine(level[0], level[l]);
604 if (temp.size > limit_size) {
605 // try next level
606 level[l].size = 0;
607 } else {
608 // if merged record is still smaller, no need to send to next level
609 level[l].CopyFrom(temp); break;
610 }
611 }
612 }
613 }
615 inline void GetSummary(SummaryContainer *out) {
616 if (level.size() != 0) {
617 out->Reserve(limit_size * 2);
618 } else {
619 out->Reserve(inqueue.queue.size());
620 }
621 inqueue.MakeSummary(out);
622 if (level.size() != 0) {
623 level[0].SetPrune(*out, limit_size);
624 for (size_t l = 1; l < level.size(); ++l) {
625 if (level[l].size == 0) continue;
626 if (level[0].size == 0) {
627 level[0].CopyFrom(level[l]);
628 } else {
629 out->SetCombine(level[0], level[l]);
630 level[0].SetPrune(*out, limit_size);
631 }
632 }
633 out->CopyFrom(level[0]);
634 } else {
635 if (out->size > limit_size) {
636 temp.Reserve(limit_size);
637 temp.SetPrune(*out, limit_size);
638 out->CopyFrom(temp);
639 }
640 }
641 }
642 // used for debug, check if the sketch is valid
643 inline void CheckValid(RType eps) const {
644 for (size_t l = 1; l < level.size(); ++l) {
645 level[l].CheckValid(eps);
646 }
647 }
648 // initialize level space to at least nlevel
649 inline void InitLevel(size_t nlevel) {
650 if (level.size() >= nlevel) return;
651 data.resize(limit_size * nlevel);
652 level.resize(nlevel, Summary(nullptr, 0));
653 for (size_t l = 0; l < level.size(); ++l) {
654 level[l].data = dmlc::BeginPtr(data) + l * limit_size;
655 }
656 }
657 // input data queue
658 typename Summary::Queue inqueue;
659 // number of levels
660 size_t nlevel;
661 // size of summary in each level
662 size_t limit_size;
663 // the summaries kept at each level
664 std::vector<Summary> level;
665 // content of the summary
666 std::vector<Entry> data;
667 // temporary summary, used for temp-merge
668 SummaryContainer temp;
669};
670
676template<typename DType, typename RType = unsigned>
677class WQuantileSketch :
678 public QuantileSketchTemplate<DType, RType, WQSummary<DType, RType> > {
679};
680
686template<typename DType, typename RType = unsigned>
687class WXQuantileSketch :
688 public QuantileSketchTemplate<DType, RType, WXQSummary<DType, RType> > {
689};
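QuantileSketchTemplate organizes summaries in levels: incoming values are buffered in inqueue, pruned to limit_size entries per level, and merged upward when a level overflows. As a numeric illustration of LimitSizeLevel, with maxn = 1,000,000 and eps = 0.01 the loop settles on nlevel = 10 and limit_size = 1001, since 1024 * 1001 >= 1,000,000 while 512 * 901 is not. Below is a hedged usage sketch of the WQuantileSketch alias declared above; it assumes the XGBoost source tree is on the include path, a non-empty input, and the function and variable names are illustrative only.

#include <vector>

#include "quantile.h"  // internal XGBoost header; path shown for illustration only

// Build an approximate quantile summary over unit-weighted samples.
void BuildSummary(std::vector<float> const& values) {
  using Sketch = xgboost::common::WQuantileSketch<float, float>;

  Sketch sketch;
  sketch.Init(values.size(), /*eps=*/0.01);  // expected element count and target rank error
  for (float v : values) {
    sketch.Push(v, 1.0f);  // unit weight per sample
  }

  // Collapse the level hierarchy into a single summary.
  Sketch::SummaryContainer out;
  sketch.GetSummary(&out);

  // Optionally prune to a fixed budget, e.g. 16 entries, for downstream use.
  Sketch::SummaryContainer pruned;
  pruned.Reserve(16);
  pruned.SetPrune(out, 16);
  // pruned.data[0 .. pruned.size) now approximates the quantiles of the input.
}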
690
691namespace detail {
692inline std::vector<float> UnrollGroupWeights(MetaInfo const &info) {
693 std::vector<float> const &group_weights = info.weights_.HostVector();
694 if (group_weights.empty()) {
695 return group_weights;
696 }
697
698 auto const &group_ptr = info.group_ptr_;
699 CHECK_GE(group_ptr.size(), 2);
700
701 auto n_groups = group_ptr.size() - 1;
702 CHECK_EQ(info.weights_.Size(), n_groups) << error::GroupWeight();
703
704 bst_row_t n_samples = info.num_row_;
705 std::vector<float> results(n_samples);
706 CHECK_EQ(group_ptr.back(), n_samples)
707 << error::GroupSize() << " the number of rows from the data.";
708 size_t cur_group = 0;
709 for (bst_row_t i = 0; i < n_samples; ++i) {
710 results[i] = group_weights[cur_group];
711 if (i == group_ptr[cur_group + 1]) {
712 cur_group++;
713 }
714 }
715 return results;
716}
717} // namespace detail
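UnrollGroupWeights turns one weight per query group (the ranking case) into one weight per row, so the sketches can consume ordinary instance weights. A standalone illustration of that expansion with hypothetical literal values follows; it does not use MetaInfo and resolves the group boundary directly rather than mirroring the loop above.

#include <cstddef>
#include <vector>

// With group boundaries {0, 3, 5} and per-group weights {0.5, 2.0}, rows 0-2
// belong to group 0 and rows 3-4 to group 1.
std::vector<float> UnrollExample() {
  std::vector<unsigned> group_ptr{0, 3, 5};
  std::vector<float> group_weights{0.5f, 2.0f};

  std::vector<float> row_weights(group_ptr.back());
  std::size_t cur_group = 0;
  for (std::size_t i = 0; i < row_weights.size(); ++i) {
    if (i == group_ptr[cur_group + 1]) {
      ++cur_group;  // crossed into the next group
    }
    row_weights[i] = group_weights[cur_group];
  }
  return row_weights;  // {0.5f, 0.5f, 0.5f, 2.0f, 2.0f}
}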
718
719class HistogramCuts;
720
721template <typename Batch, typename IsValid>
722std::vector<bst_row_t> CalcColumnSize(Batch const &batch, bst_feature_t const n_columns,
723 size_t const n_threads, IsValid &&is_valid) {
724 std::vector<std::vector<bst_row_t>> column_sizes_tloc(n_threads);
725 for (auto &column : column_sizes_tloc) {
726 column.resize(n_columns, 0);
727 }
728
729 ParallelFor(batch.Size(), n_threads, [&](omp_ulong i) {
730 auto &local_column_sizes = column_sizes_tloc.at(omp_get_thread_num());
731 auto const &line = batch.GetLine(i);
732 for (size_t j = 0; j < line.Size(); ++j) {
733 auto elem = line.GetElement(j);
734 if (is_valid(elem)) {
735 local_column_sizes[elem.column_idx]++;
736 }
737 }
738 });
739 // reduce to first thread
740 auto &entries_per_columns = column_sizes_tloc.front();
741 CHECK_EQ(entries_per_columns.size(), static_cast<size_t>(n_columns));
742 for (size_t i = 1; i < n_threads; ++i) {
743 CHECK_EQ(column_sizes_tloc[i].size(), static_cast<size_t>(n_columns));
744 for (size_t j = 0; j < n_columns; ++j) {
745 entries_per_columns[j] += column_sizes_tloc[i][j];
746 }
747 }
748 return entries_per_columns;
749}
750
751template <typename Batch, typename IsValid>
752std::vector<bst_feature_t> LoadBalance(Batch const &batch, size_t nnz, bst_feature_t n_columns,
753 size_t const nthreads, IsValid&& is_valid) {
754 /* Some sparse datasets have their mass concentrating on small number of features. To
755 * avoid waiting for a few threads running forever, we here distribute different number
756 * of columns to different threads according to number of entries.
757 */
758 size_t const total_entries = nnz;
759 size_t const entries_per_thread = DivRoundUp(total_entries, nthreads);
760
761 // Need to calculate the size for each batch.
762 std::vector<bst_row_t> entries_per_columns = CalcColumnSize(batch, n_columns, nthreads, is_valid);
763 std::vector<bst_feature_t> cols_ptr(nthreads + 1, 0);
764 size_t count{0};
765 size_t current_thread{1};
766
767 for (auto col : entries_per_columns) {
768 cols_ptr.at(current_thread)++; // add one column to thread
769 count += col;
770 CHECK_LE(count, total_entries);
771 if (count > entries_per_thread) {
772 current_thread++;
773 count = 0;
774 cols_ptr.at(current_thread) = cols_ptr[current_thread - 1];
775 }
776 }
777 // Idle threads.
778 for (; current_thread < cols_ptr.size() - 1; ++current_thread) {
779 cols_ptr[current_thread + 1] = cols_ptr[current_thread];
780 }
781 return cols_ptr;
782}
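LoadBalance therefore produces a pointer array over features: thread t owns the half-open column range [cols_ptr[t], cols_ptr[t+1]), chosen so that each thread sees roughly the same number of valid entries rather than the same number of columns. The following is a simplified, self-contained sketch of the same idea that starts from plain per-column counts instead of a Batch; all names below are illustrative, not part of this header.

#include <cstddef>
#include <vector>

std::vector<std::size_t> BalanceColumns(std::vector<std::size_t> const& entries_per_column,
                                        std::size_t n_threads) {
  std::size_t total = 0;
  for (auto n : entries_per_column) total += n;
  // Same role as DivRoundUp(total_entries, nthreads) above.
  std::size_t const per_thread = (total + n_threads - 1) / n_threads;

  // cols_ptr[t] .. cols_ptr[t + 1] is the half-open column range owned by thread t.
  std::vector<std::size_t> cols_ptr(n_threads + 1, 0);
  std::size_t count = 0;
  std::size_t current_thread = 1;
  for (auto n : entries_per_column) {
    ++cols_ptr[current_thread];  // assign this column to the current thread
    count += n;
    if (count > per_thread && current_thread < n_threads) {
      ++current_thread;
      count = 0;
      cols_ptr[current_thread] = cols_ptr[current_thread - 1];
    }
  }
  // Any threads left without columns inherit the previous boundary.
  for (; current_thread < n_threads; ++current_thread) {
    cols_ptr[current_thread + 1] = cols_ptr[current_thread];
  }
  return cols_ptr;  // e.g. {0, 1, 3, 3} for counts {100, 5, 5} and 3 threads
}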
783
787template <typename WQSketch>
788class SketchContainerImpl {
789 protected:
790 std::vector<WQSketch> sketches_;
791 std::vector<std::set<float>> categories_;
792 std::vector<FeatureType> const feature_types_;
793
794 std::vector<bst_row_t> columns_size_;
795 int32_t max_bins_;
796 bool use_group_ind_{false};
797 int32_t n_threads_;
798 bool has_categorical_{false};
799 Monitor monitor_;
800
801 public:
802 /* \brief Initialize necessary info.
803 *
804 * \param columns_size Size of each column.
805 * \param max_bins maximum number of bins for each feature.
806 * \param use_group Whether weights are assigned per query group instead of per data instance.
807 */
808 SketchContainerImpl(Context const *ctx, std::vector<bst_row_t> columns_size, int32_t max_bins,
809 common::Span<FeatureType const> feature_types, bool use_group);
810
811 static bool UseGroup(MetaInfo const &info) {
812 size_t const num_groups =
813 info.group_ptr_.size() == 0 ? 0 : info.group_ptr_.size() - 1;
814 // Use group index for weights?
815 bool const use_group_ind =
816 num_groups != 0 && (info.weights_.Size() != info.num_row_);
817 return use_group_ind;
818 }
819
820 static uint32_t SearchGroupIndFromRow(std::vector<bst_uint> const &group_ptr,
821 size_t const base_rowid) {
822 CHECK_LT(base_rowid, group_ptr.back())
823 << "Row: " << base_rowid << " is not found in any group.";
824 bst_group_t group_ind =
825 std::upper_bound(group_ptr.cbegin(), group_ptr.cend() - 1, base_rowid) -
826 group_ptr.cbegin() - 1;
827 return group_ind;
828 }
829 // Gather sketches from all workers.
830 void GatherSketchInfo(MetaInfo const& info,
831 std::vector<typename WQSketch::SummaryContainer> const &reduced,
832 std::vector<bst_row_t> *p_worker_segments,
833 std::vector<bst_row_t> *p_sketches_scan,
834 std::vector<typename WQSketch::Entry> *p_global_sketches);
835 // Merge sketches from all workers.
836 void AllReduce(MetaInfo const& info, std::vector<typename WQSketch::SummaryContainer> *p_reduced,
837 std::vector<int32_t> *p_num_cuts);
838
839 template <typename Batch, typename IsValid>
840 void PushRowPageImpl(Batch const &batch, size_t base_rowid, OptionalWeights weights, size_t nnz,
841 size_t n_features, bool is_dense, IsValid is_valid) {
842 auto thread_columns_ptr = LoadBalance(batch, nnz, n_features, n_threads_, is_valid);
843
844 dmlc::OMPException exc;
845#pragma omp parallel num_threads(n_threads_)
846 {
847 exc.Run([&]() {
848 auto tid = static_cast<uint32_t>(omp_get_thread_num());
849 auto const begin = thread_columns_ptr[tid];
850 auto const end = thread_columns_ptr[tid + 1];
851
852 // do not iterate if no columns are assigned to the thread
853 if (begin < end && end <= n_features) {
854 for (size_t ridx = 0; ridx < batch.Size(); ++ridx) {
855 auto const &line = batch.GetLine(ridx);
856 auto w = weights[ridx + base_rowid];
857 if (is_dense) {
858 for (size_t ii = begin; ii < end; ii++) {
859 auto elem = line.GetElement(ii);
860 if (is_valid(elem)) {
861 if (IsCat(feature_types_, ii)) {
862 categories_[ii].emplace(elem.value);
863 } else {
864 sketches_[ii].Push(elem.value, w);
865 }
866 }
867 }
868 } else {
869 for (size_t i = 0; i < line.Size(); ++i) {
870 auto const &elem = line.GetElement(i);
871 if (is_valid(elem) && elem.column_idx >= begin && elem.column_idx < end) {
872 if (IsCat(feature_types_, elem.column_idx)) {
873 categories_[elem.column_idx].emplace(elem.value);
874 } else {
875 sketches_[elem.column_idx].Push(elem.value, w);
876 }
877 }
878 }
879 }
880 }
881 }
882 });
883 }
884 exc.Rethrow();
885 }
886
887 /* \brief Push a CSR matrix. */
888 void PushRowPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian = {});
889
890 void MakeCuts(MetaInfo const& info, HistogramCuts* cuts);
891
892 private:
893 // Merge all categories from other workers.
894 void AllreduceCategories(MetaInfo const& info);
895};
896
897class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, float>> {
898 public:
899 using Super = SketchContainerImpl<WQuantileSketch<float, float>>;
900
901 public:
902 explicit HostSketchContainer(Context const *ctx, bst_bin_t max_bins, common::Span<FeatureType const> ft,
903 std::vector<size_t> columns_size, bool use_group);
904
905 template <typename Batch>
906 void PushAdapterBatch(Batch const &batch, size_t base_rowid, MetaInfo const &info, float missing);
907};
908
912struct SortedQuantile {
914 double sum_total{0.0};
916 double rmin, wmin;
918 bst_float last_fvalue;
920 double next_goal;
921 // pointer to the sketch to put things in
922 common::WXQuantileSketch<bst_float, bst_float> *sketch;
923 // initialize the space
924 inline void Init(unsigned max_size) {
925 next_goal = -1.0f;
926 rmin = wmin = 0.0f;
927 sketch->temp.Reserve(max_size + 1);
928 sketch->temp.size = 0;
929 }
936 inline void Push(bst_float fvalue, bst_float w, unsigned max_size) {
937 if (next_goal == -1.0f) {
938 next_goal = 0.0f;
939 last_fvalue = fvalue;
940 wmin = w;
941 return;
942 }
943 if (last_fvalue != fvalue) {
944 double rmax = rmin + wmin;
945 if (rmax >= next_goal && sketch->temp.size != max_size) {
946 if (sketch->temp.size == 0 ||
947 last_fvalue > sketch->temp.data[sketch->temp.size - 1].value) {
948 // push to sketch
949 sketch->temp.data[sketch->temp.size] =
950 common::WXQuantileSketch<bst_float, bst_float>::Entry(
951 static_cast<bst_float>(rmin), static_cast<bst_float>(rmax),
952 static_cast<bst_float>(wmin), last_fvalue);
953 CHECK_LT(sketch->temp.size, max_size) << "invalid maximum size max_size=" << max_size
954 << ", stemp.size" << sketch->temp.size;
955 ++sketch->temp.size;
956 }
957 if (sketch->temp.size == max_size) {
958 next_goal = sum_total * 2.0f + 1e-5f;
959 } else {
960 next_goal = static_cast<bst_float>(sketch->temp.size * sum_total / max_size);
961 }
962 } else {
963 if (rmax >= next_goal) {
964 LOG(DEBUG) << "INFO: rmax=" << rmax << ", sum_total=" << sum_total
965 << ", naxt_goal=" << next_goal << ", size=" << sketch->temp.size;
966 }
967 }
968 rmin = rmax;
969 wmin = w;
970 last_fvalue = fvalue;
971 } else {
972 wmin += w;
973 }
974 }
975
977 inline void Finalize(unsigned max_size) {
978 double rmax = rmin + wmin;
979 if (sketch->temp.size == 0 || last_fvalue > sketch->temp.data[sketch->temp.size - 1].value) {
980 CHECK_LE(sketch->temp.size, max_size)
981 << "Finalize: invalid maximum size, max_size=" << max_size
982 << ", stemp.size=" << sketch->temp.size;
983 // push to sketch
984 sketch->temp.data[sketch->temp.size] = common::WXQuantileSketch<bst_float, bst_float>::Entry(
985 static_cast<bst_float>(rmin), static_cast<bst_float>(rmax), static_cast<bst_float>(wmin),
986 last_fvalue);
987 ++sketch->temp.size;
988 }
989 sketch->PushTemp();
990 }
991};
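SortedQuantile is the streaming front end used when a column's data arrive already sorted: Push emits a summary entry whenever the accumulated rank crosses the next evenly spaced target, and Finalize flushes the last pending value into the owning sketch. In XGBoost it is driven by SortedSketchContainer below; the standalone sketch here only shows the intended call order for a single non-empty, pre-sorted column, and its names and parameters are illustrative, not part of this header.

#include <vector>

#include "quantile.h"  // internal XGBoost header; path shown for illustration only

void SketchSortedColumn(std::vector<float> const& sorted_values, unsigned max_size) {
  using WXQSketch = xgboost::common::WXQuantileSketch<float, float>;

  WXQSketch wxq;
  wxq.Init(sorted_values.size(), 2.0 / max_size);  // mirrors eps = 2.0 / max_bins below

  xgboost::common::SortedQuantile q;
  q.sketch = &wxq;
  q.sum_total = static_cast<double>(sorted_values.size());  // total weight with unit weights
  q.Init(max_size);
  for (float v : sorted_values) {
    q.Push(v, 1.0f, max_size);
  }
  q.Finalize(max_size);

  // The pruned per-column summary can now be read back from the sketch.
  WXQSketch::SummaryContainer out;
  wxq.GetSummary(&out);
}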
992
993class SortedSketchContainer : public SketchContainerImpl<WXQuantileSketch<float, float>> {
994 std::vector<SortedQuantile> sketches_;
995 using Super = SketchContainerImpl<WXQuantileSketch<float, float>>;
996
997 public:
998 explicit SortedSketchContainer(Context const *ctx, int32_t max_bins,
999 common::Span<FeatureType const> ft,
1000 std::vector<size_t> columns_size, bool use_group)
1001 : SketchContainerImpl{ctx, columns_size, max_bins, ft, use_group} {
1002 monitor_.Init(__func__);
1003 sketches_.resize(columns_size.size());
1004 size_t i = 0;
1005 for (auto &sketch : sketches_) {
1006 sketch.sketch = &Super::sketches_[i];
1007 sketch.Init(max_bins_);
1008 auto eps = 2.0 / max_bins;
1009 sketch.sketch->Init(columns_size_[i], eps);
1010 ++i;
1011 }
1012 }
1016 void PushColPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian);
1017};
1018} // namespace xgboost::common
1019#endif // XGBOOST_COMMON_QUANTILE_H_
Copyright 2020-2023, XGBoost Contributors.
OMP Exception class catches, saves and rethrows exception from OMP blocks.
Definition common.h:53
void Rethrow()
should be called from the main thread to rethrow the exception
Definition common.h:84
void Run(Function f, Parameters... params)
Parallel OMP blocks should be placed within Run to save exception.
Definition common.h:65
Meta information about dataset, always sit in memory.
Definition data.h:48
HostDeviceVector< bst_float > weights_
weights of each instance, optional
Definition data.h:69
std::vector< bst_group_t > group_ptr_
the index of begin and end of a group needed when the learning task is ranking.
Definition data.h:67
uint64_t num_row_
number of rows in the data
Definition data.h:54
In-memory storage unit of sparse batch, stored in CSR format.
Definition data.h:328
Definition hist_util.h:37
Definition quantile.h:897
template for all quantile sketch algorithm that uses merge/prune scheme
Definition quantile.h:469
TSummary Summary
type of summary type
Definition quantile.h:475
typename Summary::Entry Entry
the entry type
Definition quantile.h:477
void Init(size_t maxn, double eps)
initialize the quantile sketch, given the performance specification
Definition quantile.h:535
void GetSummary(SummaryContainer *out)
get the summary after finalize
Definition quantile.h:615
void PushTemp()
push up temp
Definition quantile.h:592
void Push(DType x, RType w=1)
add an element to a sketch
Definition quantile.h:568
Definition quantile.h:788
span class implementation, based on ISO++20 span<T>. The interface should be the same.
Definition span.h:424
Quantile sketch use WQSummary.
Definition quantile.h:678
Quantile sketch use WXQSummary.
Definition quantile.h:688
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition base.h:64
Copyright 2015-2023 by XGBoost Contributors.
defines console logging options for xgboost. Use to enforce unified print behavior.
detail namespace with internal helper functions
Definition json.hpp:249
T * BeginPtr(std::vector< T > &vec)
safely get the beginning address of a vector
Definition base.h:284
Copyright 2017-2023, XGBoost Contributors.
Definition span.h:77
uint32_t bst_feature_t
Type for data column (feature) index.
Definition base.h:101
std::uint32_t bst_group_t
Type for ranking group index.
Definition base.h:114
dmlc::omp_ulong omp_ulong
define unsigned long for openmp loop
Definition base.h:322
std::size_t bst_row_t
Type for data row index.
Definition base.h:110
int32_t bst_bin_t
Type for histogram bin index.
Definition base.h:103
float bst_float
float type, used for storing statistics
Definition base.h:97
Runtime context for XGBoost.
Definition context.h:84
Element from a sparse vector.
Definition data.h:216
Timing utility used to measure total method execution time over the lifetime of the containing object...
Definition timer.h:47
Definition optional_weight.h:12
same as summary, but use STL to backup the space
Definition quantile.h:479
void Reserve(size_t size)
reserve space for summary
Definition quantile.h:488
void Load(TStream &fi)
load data structure from input stream
Definition quantile.h:521
void Save(TStream &fo) const
save the data structure into stream
Definition quantile.h:513
static size_t CalcMemCost(size_t nentry)
return the number of bytes this data structure cost in serialization
Definition quantile.h:508
void Reduce(const Summary &src, size_t max_nbyte)
do elementwise combination of summary array this[i] = combine(this[i], src[i]) for each i
Definition quantile.h:500
Quantile structure accepts sorted data, extracted from histmaker.
Definition quantile.h:912
void Finalize(unsigned max_size)
push final unfinished value to the sketch
Definition quantile.h:977
void Push(bst_float fvalue, bst_float w, unsigned max_size)
push a new element to sketch
Definition quantile.h:936
bst_float last_fvalue
last seen feature value
Definition quantile.h:918
double next_goal
current size of sketch
Definition quantile.h:920
double rmin
statistics used in the sketch
Definition quantile.h:916
an entry in the sketch summary
Definition quantile.h:36
DType value
the value of data
Definition quantile.h:44
RType wmin
maximum weight
Definition quantile.h:42
XGBOOST_DEVICE RType RMaxPrev() const
Definition quantile.h:63
RType rmin
minimum rank
Definition quantile.h:38
XGBOOST_DEVICE RType RMinNext() const
Definition quantile.h:59
RType rmax
maximum rank
Definition quantile.h:40
void CheckValid(RType eps=0) const
debug function, check Valid
Definition quantile.h:54
input data queue before entering the summary
Definition quantile.h:76
experimental wsummary
Definition quantile.h:34
Entry Query(DType qvalue, size_t &istart) const
query qvalue, start from istart
Definition quantile.h:145
size_t size
number of elements in the summary
Definition quantile.h:125
RType MaxRank() const
Definition quantile.h:166
Entry * data
data field
Definition quantile.h:123
void SetPrune(const WQSummary &src, size_t maxsize)
set current summary to be pruned summary of src assume data field is already allocated to be at least...
Definition quantile.h:220
void SetCombine(const WQSummary &sa, const WQSummary &sb)
set current summary to be merged summary of sa and sb
Definition quantile.h:256
void CheckValid(RType eps) const
debug function, validate whether the summary run consistency check to check if it is a valid summary
Definition quantile.h:204
void CopyFrom(const WQSummary &src)
copy content from src
Definition quantile.h:173
RType MaxError() const
Definition quantile.h:132
try to do efficient pruning
Definition quantile.h:359
Copyright 2015-2023 by XGBoost Contributors.