40 template <
typename AdapterBatchT>
41 auto GetRowCounts(AdapterBatchT
const& batch,
float missing, int32_t n_threads) {
42 std::vector<size_t> valid_counts(batch.Size(), 0);
43 common::ParallelFor(batch.Size(), n_threads, [&](
size_t i) {
44 auto line = batch.GetLine(i);
45 for (size_t j = 0; j < line.Size(); ++j) {
46 data::COOTuple elem = line.GetElement(j);
47 if (data::IsValidFunctor {missing}(elem)) {
61 template <
typename Batch,
typename BinIdxType,
typename GetOffset,
typename IsVal
id>
64 IsValid&& is_valid,
size_t nbins, GetOffset&& get_offset) {
65 auto batch_size = batch.Size();
66 BinIdxType* index_data = index_data_span.data();
67 auto const& ptrs = cut.Ptrs();
68 auto const& values = cut.Values();
69 std::atomic<bool> valid{
true};
70 common::ParallelFor(batch_size, batch_threads, [&](
size_t i) {
71 auto line = batch.GetLine(i);
72 size_t ibegin = row_ptr[rbegin + i];
74 auto tid = omp_get_thread_num();
75 for (
size_t j = 0; j < line.Size(); ++j) {
76 data::COOTuple elem = line.GetElement(j);
78 if (XGBOOST_EXPECT((std::isinf(elem.value)),
false)) {
82 if (common::IsCat(ft, elem.column_idx)) {
83 bin_idx = cut.SearchCatBin(elem.value, elem.column_idx, ptrs, values);
85 bin_idx = cut.SearchBin(elem.value, elem.column_idx, ptrs, values);
87 index_data[ibegin + k] = get_offset(bin_idx, j);
88 ++hit_count_tloc_[tid * nbins + bin_idx];
94 CHECK(valid) << error::InfInData();
98 void GatherHitCount(int32_t n_threads,
bst_bin_t n_bins_total) {
99 CHECK_EQ(hit_count.size(), n_bins_total);
100 common::ParallelFor(n_bins_total, n_threads, [&](
bst_omp_uint idx) {
101 for (int32_t tid = 0; tid < n_threads; ++tid) {
102 hit_count[idx] += hit_count_tloc_[tid * n_bins_total + idx];
103 hit_count_tloc_[tid * n_bins_total + idx] = 0;
108 template <
typename Batch,
typename IsVal
id>
109 void PushBatchImpl(int32_t n_threads, Batch
const& batch,
size_t rbegin, IsValid&& is_valid,
110 common::Span<FeatureType const> ft) {
113 size_t batch_threads =
114 std::max(
static_cast<size_t>(1), std::min(batch.Size(),
static_cast<size_t>(n_threads)));
116 auto n_bins_total = cut.TotalBins();
117 const size_t n_index = row_ptr[rbegin + batch.Size()];
118 ResizeIndex(n_index, isDense_);
120 index.SetBinOffset(cut.Ptrs());
124 using T = decltype(dtype);
125 common::Span<T> index_data_span = {index.data<T>(), index.Size()};
126 SetIndexData(index_data_span, rbegin, ft, batch_threads, batch, is_valid, n_bins_total,
127 index.MakeCompressor<T>());
130 common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
132 SetIndexData(index_data_span, rbegin, ft, batch_threads, batch, is_valid, n_bins_total,
133 [](
auto idx,
auto) {
return idx; });
135 this->GatherHitCount(n_threads, n_bins_total);
154 [[nodiscard]]
bst_bin_t MaxNumBinPerFeat()
const {
155 return std::max(
static_cast<bst_bin_t>(cut.MaxCategory() + 1), max_numeric_bins_per_feat);
162 GHistIndexMatrix(Context
const* ctx, DMatrix* x, bst_bin_t max_bins_per_feat,
163 double sparse_thresh,
bool sorted_sketch, common::Span<float const> hess = {});
168 GHistIndexMatrix(MetaInfo
const& info, common::HistogramCuts&& cuts, bst_bin_t max_bin_per_feat);
173 GHistIndexMatrix(Context
const* ctx, MetaInfo
const& info, EllpackPage
const& page,
174 BatchParam
const& p);
179 GHistIndexMatrix(SparsePage
const& page, common::Span<FeatureType const> ft,
180 common::HistogramCuts cuts, int32_t max_bins_per_feat,
bool is_dense,
181 double sparse_thresh, int32_t n_threads);
184 template <
typename Batch>
185 void PushAdapterBatch(Context
const* ctx,
size_t rbegin,
size_t prev_sum, Batch
const& batch,
186 float missing, common::Span<FeatureType const> ft,
double sparse_thresh,
187 size_t n_samples_total) {
188 auto n_bins_total = cut.TotalBins();
189 hit_count_tloc_.clear();
190 hit_count_tloc_.resize(ctx->
Threads() * n_bins_total, 0);
192 auto n_threads = ctx->
Threads();
193 auto valid_counts = GetRowCounts(batch, missing, n_threads);
195 auto it = common::MakeIndexTransformIter([&](
size_t ridx) {
return valid_counts[ridx]; });
196 common::PartialSum(n_threads, it, it + batch.Size(), prev_sum, row_ptr.begin() + rbegin);
197 auto is_valid = data::IsValidFunctor{missing};
199 PushBatchImpl(ctx->
Threads(), batch, rbegin, is_valid, ft);
201 if (rbegin + batch.Size() == n_samples_total) {
203 CHECK(!std::isnan(sparse_thresh));
204 this->columns_ = std::make_unique<common::ColumnMatrix>(*
this, sparse_thresh);
209 template <
typename Batch>
210 void PushAdapterBatchColumns(Context
const* ctx, Batch
const& batch,
float missing,
/**
 * @brief Resize the bin index storage.
 *
 * Top-level const on by-value parameters is not part of a declaration's signature, so it is
 * omitted here; a definition may still declare them const.
 *
 * @param n_index Number of index entries. PushBatchImpl passes row_ptr[rbegin + batch.Size()],
 *                i.e. the entry offset just past the last row of the batch.
 * @param isDense Dense flag; PushBatchImpl forwards isDense_.
 */
void ResizeIndex(size_t n_index, bool isDense);
215 void GetFeatureCounts(
size_t* counts)
const {
216 auto nfeature = cut.Ptrs().size() - 1;
217 for (
unsigned fid = 0; fid < nfeature; ++fid) {
218 auto ibegin = cut.Ptrs()[fid];
219 auto iend = cut.Ptrs()[fid + 1];
220 for (
auto i = ibegin; i < iend; ++i) {
221 counts[fid] += hit_count[i];
226 [[nodiscard]]
bool IsDense()
const {
return isDense_; }
227 void SetDense(
bool is_dense) { isDense_ = is_dense; }
231 [[nodiscard]] std::size_t
RowIdx(
size_t ridx)
const {
return row_ptr[ridx - base_rowid]; }
233 [[nodiscard]]
bst_row_t Size()
const {
return row_ptr.empty() ? 0 : row_ptr.size() - 1; }
234 [[nodiscard]] bst_feature_t Features()
const {
return cut.Ptrs().size() - 1; }
236 [[nodiscard]]
bool ReadColumnPage(common::AlignedResourceReadStream* fi);
237 [[nodiscard]] std::size_t WriteColumnPage(common::AlignedFileWriteStream* fo)
const;
239 [[nodiscard]] common::ColumnMatrix
const& Transpose()
const;
241 [[nodiscard]]
bst_bin_t GetGindex(
size_t ridx,
size_t fidx)
const;
243 [[nodiscard]]
float GetFvalue(
size_t ridx,
size_t fidx,
bool is_cat)
const;
244 [[nodiscard]]
float GetFvalue(std::vector<std::uint32_t>
const& ptrs,
245 std::vector<float>
const& values, std::vector<float>
const& mins,
246 bst_row_t ridx, bst_feature_t fidx,
bool is_cat)
const;
248 [[nodiscard]] common::HistogramCuts& Cuts() {
return cut; }
249 [[nodiscard]] common::HistogramCuts
const& Cuts()
const {
return cut; }
252 std::unique_ptr<common::ColumnMatrix> columns_;
253 std::vector<size_t> hit_count_tloc_;
// In-memory storage unit of a sparse batch, stored in CSR format (definition: data.h:328).
// Context::Threads() const -> std::int32_t: returns the automatically chosen number of
// threads based on the nthread parameter and the system settings (definition: context.cc:203).