7#ifndef XGBOOST_COMMON_QUANTILE_H_
8#define XGBOOST_COMMON_QUANTILE_H_
23#include "optional_weight.h"
24#include "threading_utils.h"
33template<
typename DType,
typename RType>
55 CHECK(
rmin >= 0 &&
rmax >= 0 &&
wmin >= 0) <<
"nonneg constraint";
56 CHECK(
rmax-
rmin -
wmin > -eps) <<
"relation constraint: min/max";
67 friend std::ostream& operator<<(std::ostream& os,
Entry const& e) {
68 os <<
"rmin: " << e.rmin <<
", "
69 <<
"rmax: " << e.rmax <<
", "
70 <<
"wmin: " << e.wmin <<
", "
71 <<
"value: " << e.value;
// Construct a queue entry holding a data value together with its weight.
86 QEntry(DType value, RType weight)
87 : value(value), weight(weight) {}
89 inline bool operator<(
const QEntry &b)
const {
90 return value < b.value;
94 std::vector<QEntry> queue;
98 inline void Push(DType x, RType w) {
99 if (qtail == 0 || queue[qtail - 1].value != x) {
100 queue[qtail++] =
QEntry(x, w);
102 queue[qtail - 1].weight += w;
105 inline void MakeSummary(WQSummary *out) {
106 std::sort(queue.begin(), queue.begin() + qtail);
111 for (
size_t i = 0; i < qtail;) {
113 RType w = queue[i].weight;
114 while (j < qtail && queue[j].value == queue[i].value) {
115 w += queue[j].weight; ++j;
117 out->data[out->size++] = Entry(wsum, wsum + w, w, queue[i].value);
134 for (
size_t i = 1; i <
size; ++i) {
135 res = std::max(
data[i].RMaxPrev() -
data[i - 1].RMinNext(), res);
136 res = std::max(
data[i].rmax -
data[i].rmin -
data[i].wmin, res);
146 while (istart < size && qvalue >
data[istart].value) {
149 if (istart ==
size) {
151 return Entry(rmax, rmax, 0.0f, qvalue);
153 if (qvalue ==
data[istart].value) {
157 return Entry(0.0f, 0.0f, 0.0f, qvalue);
159 return Entry(
data[istart - 1].RMinNext(),
160 data[istart].RMaxPrev(),
175 CHECK_EQ(src.
size, 0);
180 CHECK_EQ(this->size, 0);
181 CHECK_EQ(src.
size, 0);
187 inline void MakeFromSorted(
const Entry* entries,
size_t n) {
189 for (
size_t i = 0; i < n;) {
192 for (; j < n && entries[j].value == entries[i].value; ++j) {}
193 data[
size++] =
Entry(entries[i].rmin, entries[i].rmax, entries[i].wmin,
205 for (
size_t i = 0; i <
size; ++i) {
208 CHECK(
data[i].rmin >=
data[i - 1].rmin +
data[i - 1].wmin) <<
"rmin range constraint";
209 CHECK(
data[i].rmax >=
data[i - 1].rmax +
data[i].wmin) <<
"rmax range constraint";
221 if (src.
size <= maxsize) {
224 const RType begin = src.
data[0].
rmax;
226 const size_t n = maxsize - 1;
230 size_t i = 1, lastidx = 0;
231 for (
size_t k = 1; k < n; ++k) {
232 RType dx2 = 2 * ((k * range) / n + begin);
234 while (i < src.
size - 1
236 if (i == src.
size - 1)
break;
242 if (i + 1 != lastidx) {
247 if (lastidx != src.
size - 1) {
268 RType aprev_rmin = 0, bprev_rmin = 0;
270 while (a != a_end && b != b_end) {
294 RType brmax = (b_end - 1)->rmax;
298 }
while (a != a_end);
301 RType armax = (a_end - 1)->rmax;
305 }
while (b != b_end);
307 this->size = dst -
data;
308 const RType tol = 10;
309 RType err_mingap, err_maxgap, err_wgap;
310 this->FixError(&err_mingap, &err_maxgap, &err_wgap);
311 if (err_mingap > tol || err_maxgap > tol || err_wgap > tol) {
312 LOG(INFO) <<
"mingap=" << err_mingap
313 <<
", maxgap=" << err_maxgap
314 <<
", wgap=" << err_wgap;
319 inline void Print()
const {
320 for (
size_t i = 0; i < this->
size; ++i) {
321 LOG(CONSOLE) <<
"[" << i <<
"] rmin=" <<
data[i].
rmin
329 inline void FixError(RType *err_mingap,
331 RType *err_wgap)
const {
335 RType prev_rmin = 0, prev_rmax = 0;
336 for (
size_t i = 0; i < this->
size; ++i) {
337 if (
data[i].rmin < prev_rmin) {
339 *err_mingap = std::max(*err_mingap, prev_rmin -
data[i].rmin);
343 if (
data[i].rmax < prev_rmax) {
345 *err_maxgap = std::max(*err_maxgap, prev_rmax -
data[i].rmax);
348 if (
data[i].rmax < rmin_next) {
350 *err_wgap = std::max(*err_wgap,
data[i].rmax - rmin_next);
358template<
typename DType,
typename RType>
366 inline static bool CheckLarge(
const Entry &e, RType chunk) {
367 return e.RMinNext() > e.RMaxPrev() + chunk;
371 if (src.
size <= maxsize) {
376 size_t n = maxsize - 2, nbig = 0;
380 if (range == 0.0f || maxsize <= 2) {
382 this->data[0] = src.
data[0];
383 this->data[1] = src.
data[src.
size - 1];
387 range = std::max(range,
static_cast<RType
>(1e-3f));
391 const RType chunk = 2 * range / n;
398 for (
size_t i = 1; i < src.
size - 1; ++i) {
401 if (CheckLarge(src.
data[i], chunk)) {
409 if (bid != src.
size - 2) {
416 LOG(INFO) <<
" check quantile stats, nbig=" << nbig <<
", n=" << n;
417 LOG(INFO) <<
" srcsize=" << src.
size <<
", maxsize=" << maxsize
418 <<
", range=" << range <<
", chunk=" << chunk;
420 CHECK(nbig < n) <<
"quantile: too many large chunk";
422 this->data[0] = src.
data[0];
427 size_t bid = 0, k = 1, lastidx = 0;
428 for (
size_t end = 1; end < src.
size; ++end) {
429 if (end == src.
size - 1 || CheckLarge(src.
data[end], chunk)) {
430 if (bid != end - 1) {
434 RType dx2 = 2 * ((k * mrange) / n + begin);
435 if (dx2 >= maxdx2)
break;
441 this->data[this->size++] = src.
data[i]; lastidx = i;
444 if (i + 1 != lastidx) {
445 this->data[this->size++] = src.
data[i + 1]; lastidx = i + 1;
450 if (lastidx != end) {
451 this->data[this->size++] = src.
data[end];
468template<
typename DType,
typename RType,
class TSummary>
471 static float constexpr kFactor = 8.0;
477 using Entry =
typename Summary::Entry;
480 std::vector<Entry> space;
482 this->space = src.space;
489 if (size > space.size()) {
501 this->
Reserve((max_nbyte -
sizeof(this->size)) /
sizeof(
Entry));
503 temp.
Reserve(this->size + src.size);
504 temp.SetCombine(*
this, src);
505 this->SetPrune(temp, space.size());
509 return sizeof(size_t) +
sizeof(
Entry) * nentry;
512 template<
typename TStream>
513 inline void Save(TStream &fo)
const {
514 fo.Write(&(this->size),
sizeof(this->size));
515 if (this->size != 0) {
516 fo.Write(this->data, this->size *
sizeof(
Entry));
520 template<
typename TStream>
521 inline void Load(TStream &fi) {
522 CHECK_EQ(fi.Read(&this->size,
sizeof(this->size)),
sizeof(this->size));
524 if (this->size != 0) {
525 CHECK_EQ(fi.Read(this->data, this->size *
sizeof(
Entry)),
526 this->size *
sizeof(
Entry));
535 inline void Init(
size_t maxn,
double eps) {
536 LimitSizeLevel(maxn, eps, &nlevel, &limit_size);
538 inqueue.queue.resize(1);
544 inline static void LimitSizeLevel
545 (
size_t maxn,
double eps,
size_t* out_nlevel,
size_t* out_limit_size) {
546 size_t& nlevel = *out_nlevel;
547 size_t& limit_size = *out_limit_size;
550 limit_size =
static_cast<size_t>(ceil(nlevel / eps)) + 1;
551 limit_size = std::min(maxn, limit_size);
552 size_t n = (1ULL << nlevel);
553 if (n * limit_size >= maxn)
break;
557 size_t n = (1ULL << nlevel);
558 CHECK(n * limit_size >= maxn) <<
"invalid init parameter";
559 CHECK(nlevel <= std::max(
static_cast<size_t>(1),
static_cast<size_t>(limit_size * eps)))
560 <<
"invalid init parameter";
568 inline void Push(DType x, RType w = 1) {
569 if (w ==
static_cast<RType
>(0))
return;
570 if (inqueue.qtail == inqueue.queue.size() && inqueue.queue[inqueue.qtail - 1].value != x) {
572 if (inqueue.queue.size() == 1) {
573 inqueue.queue.resize(limit_size * 2);
576 inqueue.MakeSummary(&temp);
585 inline void PushSummary(
const Summary& summary) {
587 temp.SetPrune(summary, limit_size * 2);
594 for (
size_t l = 1;
true; ++l) {
595 this->InitLevel(l + 1);
597 if (level[l].size == 0) {
598 level[l].SetPrune(temp, limit_size);
602 level[0].SetPrune(temp, limit_size);
603 temp.SetCombine(level[0], level[l]);
604 if (temp.size > limit_size) {
609 level[l].CopyFrom(temp);
break;
616 if (level.size() != 0) {
617 out->Reserve(limit_size * 2);
619 out->Reserve(inqueue.queue.size());
621 inqueue.MakeSummary(out);
622 if (level.size() != 0) {
623 level[0].SetPrune(*out, limit_size);
624 for (
size_t l = 1; l < level.size(); ++l) {
625 if (level[l].size == 0)
continue;
626 if (level[0].size == 0) {
627 level[0].CopyFrom(level[l]);
629 out->SetCombine(level[0], level[l]);
630 level[0].SetPrune(*out, limit_size);
633 out->CopyFrom(level[0]);
635 if (out->size > limit_size) {
637 temp.SetPrune(*out, limit_size);
643 inline void CheckValid(RType eps)
const {
644 for (
size_t l = 1; l < level.size(); ++l) {
645 level[l].CheckValid(eps);
649 inline void InitLevel(
size_t nlevel) {
650 if (level.size() >= nlevel)
return;
651 data.resize(limit_size * nlevel);
652 level.resize(nlevel,
Summary(
nullptr, 0));
653 for (
size_t l = 0; l < level.size(); ++l) {
658 typename Summary::Queue inqueue;
664 std::vector<Summary> level;
666 std::vector<Entry> data;
668 SummaryContainer temp;
676template<
typename DType,
typename RType =
unsigned>
686template<
typename DType,
typename RType =
unsigned>
692inline std::vector<float> UnrollGroupWeights(
MetaInfo const &info) {
693 std::vector<float>
const &group_weights = info.
weights_.HostVector();
694 if (group_weights.empty()) {
695 return group_weights;
699 CHECK_GE(group_ptr.size(), 2);
701 auto n_groups = group_ptr.size() - 1;
702 CHECK_EQ(info.
weights_.Size(), n_groups) << error::GroupWeight();
705 std::vector<float> results(n_samples);
706 CHECK_EQ(group_ptr.back(), n_samples)
707 << error::GroupSize() <<
" the number of rows from the data.";
708 size_t cur_group = 0;
709 for (
bst_row_t i = 0; i < n_samples; ++i) {
710 results[i] = group_weights[cur_group];
711 if (i == group_ptr[cur_group + 1]) {
721template <
typename Batch,
typename IsVal
id>
722std::vector<bst_row_t> CalcColumnSize(Batch
const &batch,
bst_feature_t const n_columns,
723 size_t const n_threads, IsValid &&is_valid) {
724 std::vector<std::vector<bst_row_t>> column_sizes_tloc(n_threads);
725 for (
auto &column : column_sizes_tloc) {
726 column.resize(n_columns, 0);
729 ParallelFor(batch.Size(), n_threads, [&](
omp_ulong i) {
730 auto &local_column_sizes = column_sizes_tloc.at(omp_get_thread_num());
731 auto const &line = batch.GetLine(i);
732 for (size_t j = 0; j < line.Size(); ++j) {
733 auto elem = line.GetElement(j);
734 if (is_valid(elem)) {
735 local_column_sizes[elem.column_idx]++;
740 auto &entries_per_columns = column_sizes_tloc.front();
741 CHECK_EQ(entries_per_columns.size(),
static_cast<size_t>(n_columns));
742 for (
size_t i = 1; i < n_threads; ++i) {
743 CHECK_EQ(column_sizes_tloc[i].size(),
static_cast<size_t>(n_columns));
744 for (
size_t j = 0; j < n_columns; ++j) {
745 entries_per_columns[j] += column_sizes_tloc[i][j];
748 return entries_per_columns;
751template <
typename Batch,
typename IsVal
id>
752std::vector<bst_feature_t> LoadBalance(Batch
const &batch,
size_t nnz, bst_feature_t n_columns,
753 size_t const nthreads, IsValid&& is_valid) {
758 size_t const total_entries = nnz;
759 size_t const entries_per_thread = DivRoundUp(total_entries, nthreads);
762 std::vector<bst_row_t> entries_per_columns = CalcColumnSize(batch, n_columns, nthreads, is_valid);
763 std::vector<bst_feature_t> cols_ptr(nthreads + 1, 0);
765 size_t current_thread{1};
767 for (
auto col : entries_per_columns) {
768 cols_ptr.at(current_thread)++;
770 CHECK_LE(count, total_entries);
771 if (count > entries_per_thread) {
774 cols_ptr.at(current_thread) = cols_ptr[current_thread - 1];
778 for (; current_thread < cols_ptr.size() - 1; ++current_thread) {
779 cols_ptr[current_thread + 1] = cols_ptr[current_thread];
787template <
typename WQSketch>
790 std::vector<WQSketch> sketches_;
791 std::vector<std::set<float>> categories_;
792 std::vector<FeatureType>
const feature_types_;
794 std::vector<bst_row_t> columns_size_;
796 bool use_group_ind_{
false};
798 bool has_categorical_{
false};
811 static bool UseGroup(
MetaInfo const &info) {
812 size_t const num_groups =
815 bool const use_group_ind =
817 return use_group_ind;
820 static uint32_t SearchGroupIndFromRow(std::vector<bst_uint>
const &group_ptr,
821 size_t const base_rowid) {
822 CHECK_LT(base_rowid, group_ptr.back())
823 <<
"Row: " << base_rowid <<
" is not found in any group.";
825 std::upper_bound(group_ptr.cbegin(), group_ptr.cend() - 1, base_rowid) -
826 group_ptr.cbegin() - 1;
830 void GatherSketchInfo(
MetaInfo const& info,
831 std::vector<typename WQSketch::SummaryContainer>
const &reduced,
832 std::vector<bst_row_t> *p_worker_segments,
833 std::vector<bst_row_t> *p_sketches_scan,
834 std::vector<typename WQSketch::Entry> *p_global_sketches);
836 void AllReduce(
MetaInfo const& info, std::vector<typename WQSketch::SummaryContainer> *p_reduced,
837 std::vector<int32_t> *p_num_cuts);
839 template <
typename Batch,
typename IsVal
id>
840 void PushRowPageImpl(Batch
const &batch,
size_t base_rowid,
OptionalWeights weights,
size_t nnz,
841 size_t n_features,
bool is_dense, IsValid is_valid) {
842 auto thread_columns_ptr = LoadBalance(batch, nnz, n_features, n_threads_, is_valid);
845#pragma omp parallel num_threads(n_threads_)
848 auto tid =
static_cast<uint32_t
>(omp_get_thread_num());
849 auto const begin = thread_columns_ptr[tid];
850 auto const end = thread_columns_ptr[tid + 1];
853 if (begin < end && end <= n_features) {
854 for (
size_t ridx = 0; ridx < batch.Size(); ++ridx) {
855 auto const &line = batch.GetLine(ridx);
856 auto w = weights[ridx + base_rowid];
858 for (
size_t ii = begin; ii < end; ii++) {
859 auto elem = line.GetElement(ii);
860 if (is_valid(elem)) {
861 if (IsCat(feature_types_, ii)) {
862 categories_[ii].emplace(elem.value);
864 sketches_[ii].Push(elem.value, w);
869 for (
size_t i = 0; i < line.Size(); ++i) {
870 auto const &elem = line.GetElement(i);
871 if (is_valid(elem) && elem.column_idx >= begin && elem.column_idx < end) {
872 if (IsCat(feature_types_, elem.column_idx)) {
873 categories_[elem.column_idx].emplace(elem.value);
875 sketches_[elem.column_idx].Push(elem.value, w);
894 void AllreduceCategories(
MetaInfo const& info);
903 std::vector<size_t> columns_size,
bool use_group);
905 template <
typename Batch>
906 void PushAdapterBatch(Batch
const &batch,
size_t base_rowid,
MetaInfo const &info,
float missing);
914 double sum_total{0.0};
924 inline void Init(
unsigned max_size) {
927 sketch->temp.
Reserve(max_size + 1);
928 sketch->temp.size = 0;
937 if (next_goal == -1.0f) {
939 last_fvalue = fvalue;
943 if (last_fvalue != fvalue) {
944 double rmax = rmin + wmin;
945 if (rmax >= next_goal && sketch->temp.size != max_size) {
946 if (sketch->temp.size == 0 ||
947 last_fvalue > sketch->temp.data[sketch->temp.size - 1].value) {
949 sketch->temp.data[sketch->temp.size] =
952 static_cast<bst_float>(wmin), last_fvalue);
953 CHECK_LT(sketch->temp.size, max_size) <<
"invalid maximum size max_size=" << max_size
954 <<
", stemp.size" << sketch->temp.size;
957 if (sketch->temp.size == max_size) {
958 next_goal = sum_total * 2.0f + 1e-5f;
960 next_goal =
static_cast<bst_float>(sketch->temp.size * sum_total / max_size);
963 if (rmax >= next_goal) {
964 LOG(DEBUG) <<
"INFO: rmax=" << rmax <<
", sum_total=" << sum_total
965 <<
", naxt_goal=" << next_goal <<
", size=" << sketch->temp.size;
970 last_fvalue = fvalue;
978 double rmax = rmin + wmin;
979 if (sketch->temp.size == 0 || last_fvalue > sketch->temp.data[sketch->temp.size - 1].value) {
980 CHECK_LE(sketch->temp.size, max_size)
981 <<
"Finalize: invalid maximum size, max_size=" << max_size
982 <<
", stemp.size=" << sketch->temp.size;
994 std::vector<SortedQuantile> sketches_;
1000 std::vector<size_t> columns_size,
bool use_group)
1002 monitor_.Init(__func__);
1003 sketches_.resize(columns_size.size());
1005 for (
auto &sketch : sketches_) {
1006 sketch.sketch = &Super::sketches_[i];
1007 sketch.Init(max_bins_);
1008 auto eps = 2.0 / max_bins;
1009 sketch.sketch->Init(columns_size_[i], eps);
Copyright 2020-2023, XGBoost Contributors.
OMP Exception class catches, saves and rethrows exception from OMP blocks.
Definition common.h:53
void Rethrow()
should be called from the main thread to rethrow the saved exception
Definition common.h:84
void Run(Function f, Parameters... params)
Parallel OMP blocks should be placed within Run to save exception.
Definition common.h:65
In-memory storage unit of sparse batch, stored in CSR format.
Definition data.h:328
Definition hist_util.h:37
Definition quantile.h:897
template for all quantile sketch algorithm that uses merge/prune scheme
Definition quantile.h:469
TSummary Summary
type of summary type
Definition quantile.h:475
typename Summary::Entry Entry
the entry type
Definition quantile.h:477
void Init(size_t maxn, double eps)
initialize the quantile sketch, given the performance specification
Definition quantile.h:535
void GetSummary(SummaryContainer *out)
get the summary after finalize
Definition quantile.h:615
void PushTemp()
push up temp
Definition quantile.h:592
void Push(DType x, RType w=1)
add an element to a sketch
Definition quantile.h:568
Definition quantile.h:788
Definition quantile.h:993
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition span.h:424
Quantile sketch that uses WQSummary.
Definition quantile.h:678
Quantile sketch that uses WXQSummary.
Definition quantile.h:688
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition base.h:64
Copyright 2015-2023 by XGBoost Contributors.
defines console logging options for xgboost. Use to enforce unified print behavior.
detail namespace with internal helper functions
Definition json.hpp:249
T * BeginPtr(std::vector< T > &vec)
safely get the beginning address of a vector
Definition base.h:284
Copyright 2017-2023, XGBoost Contributors.
Definition span.h:77
uint32_t bst_feature_t
Type for data column (feature) index.
Definition base.h:101
std::uint32_t bst_group_t
Type for ranking group index.
Definition base.h:114
dmlc::omp_ulong omp_ulong
define unsigned long for openmp loop
Definition base.h:322
std::size_t bst_row_t
Type for data row index.
Definition base.h:110
int32_t bst_bin_t
Type for histogram bin index.
Definition base.h:103
float bst_float
float type, used for storing statistics
Definition base.h:97
Runtime context for XGBoost.
Definition context.h:84
Element from a sparse vector.
Definition data.h:216
Timing utility used to measure total method execution time over the lifetime of the containing object...
Definition timer.h:47
Definition optional_weight.h:12
same as the summary, but uses an STL vector to back its storage
Definition quantile.h:479
void Reserve(size_t size)
reserve space for summary
Definition quantile.h:488
void Load(TStream &fi)
load data structure from input stream
Definition quantile.h:521
void Save(TStream &fo) const
save the data structure into stream
Definition quantile.h:513
static size_t CalcMemCost(size_t nentry)
return the number of bytes this data structure cost in serialization
Definition quantile.h:508
void Reduce(const Summary &src, size_t max_nbyte)
do elementwise combination of summary array this[i] = combine(this[i], src[i]) for each i
Definition quantile.h:500
Quantile structure accepts sorted data, extracted from histmaker.
Definition quantile.h:912
void Finalize(unsigned max_size)
push final unfinished value to the sketch
Definition quantile.h:977
void Push(bst_float fvalue, bst_float w, unsigned max_size)
push a new element to sketch
Definition quantile.h:936
bst_float last_fvalue
last seen feature value
Definition quantile.h:918
double next_goal
rank threshold that rmax must reach before the next entry is added to the sketch
Definition quantile.h:920
double rmin
statistics used in the sketch
Definition quantile.h:916
an entry in the sketch summary
Definition quantile.h:36
DType value
the value of data
Definition quantile.h:44
RType wmin
weight of the data value itself (total weight of all inputs equal to value)
Definition quantile.h:42
XGBOOST_DEVICE RType RMaxPrev() const
Definition quantile.h:63
RType rmin
minimum rank
Definition quantile.h:38
XGBOOST_DEVICE RType RMinNext() const
Definition quantile.h:59
RType rmax
maximum rank
Definition quantile.h:40
void CheckValid(RType eps=0) const
debug function: check that the entry is valid
Definition quantile.h:54
input data queue before entering the summary
Definition quantile.h:76
experimental wsummary
Definition quantile.h:34
Entry Query(DType qvalue, size_t &istart) const
query qvalue, start from istart
Definition quantile.h:145
size_t size
number of elements in the summary
Definition quantile.h:125
RType MaxRank() const
Definition quantile.h:166
Entry * data
data field
Definition quantile.h:123
void SetPrune(const WQSummary &src, size_t maxsize)
set the current summary to be the pruned summary of src; assumes the data field is already allocated to be at least...
Definition quantile.h:220
void SetCombine(const WQSummary &sa, const WQSummary &sb)
set current summary to be merged summary of sa and sb
Definition quantile.h:256
void CheckValid(RType eps) const
debug function: run a consistency check to validate that this is a valid summary
Definition quantile.h:204
void CopyFrom(const WQSummary &src)
copy content from src
Definition quantile.h:173
RType MaxError() const
Definition quantile.h:132
try to do efficient pruning
Definition quantile.h:359
Copyright 2015-2023 by XGBoost Contributors.