4#ifndef XGBOOST_TREE_HIST_SAMPLER_H_
5#define XGBOOST_TREE_HIST_SAMPLER_H_
11#include "../../common/random.h"
/*! \brief Multiplier of the increment-free linear congruential engine used for row sampling. */
static constexpr std::uint64_t kBase = 16807;
/*! \brief Modulus of that engine: 2^63. */
static constexpr std::uint64_t kMod = std::uint64_t{1} << 63;
/*!
 * \brief Engine computing x_{n+1} = kBase * x_n mod kMod.
 *
 * The increment is zero, so advancing the state by k steps is a multiplication by
 * kBase^k (mod kMod); SimpleSkip relies on this to give each thread its own offset.
 */
using EngineT = std::linear_congruential_engine<std::uint64_t, kBase, 0, kMod>;
31 static std::uint64_t SimpleSkip(std::uint64_t exponent, std::uint64_t initial_seed,
32 std::uint64_t base, std::uint64_t mod) {
33 CHECK_LE(exponent, mod);
34 std::uint64_t result = 1;
35 while (exponent > 0) {
36 if (exponent % 2 == 1) {
37 result = (result * base) % mod;
39 base = (base * base) % mod;
40 exponent = exponent >> 1;
43 return (result * initial_seed) % mod;
50 CHECK(out.Contiguous());
51 CHECK_EQ(param.sampling_method, TrainParam::kUniform)
52 <<
"Only uniform sampling is supported, gradient-based sampling is only support by GPU Hist.";
54 if (param.subsample >= 1.0) {
60#if XGBOOST_CUSTOMIZE_GLOBAL_PRNG
61 std::bernoulli_distribution coin_flip(param.subsample);
62 CHECK_EQ(out.Shape(1), 1) <<
"Multi-target with sampling for R is not yet supported.";
63 for (
size_t i = 0; i < n_samples; ++i) {
64 if (!(out(i, 0).GetHess() >= 0.0f && coin_flip(rnd)) || out(i, 0).GetGrad() == 0.0f) {
69 std::uint64_t initial_seed = rnd();
71 auto n_threads =
static_cast<size_t>(ctx->
Threads());
72 std::size_t
const discard_size = n_samples / n_threads;
73 std::bernoulli_distribution coin_flip(param.subsample);
76#pragma omp parallel num_threads(n_threads)
79 const size_t tid = omp_get_thread_num();
80 const size_t ibegin = tid * discard_size;
81 const size_t iend = (tid == (n_threads - 1)) ? n_samples : ibegin + discard_size;
83 const uint64_t displaced_seed = RandomReplace::SimpleSkip(
84 ibegin, initial_seed, RandomReplace::kBase, RandomReplace::kMod);
85 RandomReplace::EngineT eng(displaced_seed);
86 std::size_t n_targets = out.Shape(1);
88 for (std::size_t i = ibegin; i < iend; ++i) {
89 if (!coin_flip(eng)) {
90 for (std::size_t j = 0; j < n_targets; ++j) {
96 for (std::size_t i = ibegin; i < iend; ++i) {
97 if (!coin_flip(eng)) {
OMP Exception class catches, saves and rethrows exceptions from OMP blocks.
Definition common.h:53
void Rethrow()
Should be called from the main thread to rethrow the saved exception.
Definition common.h:84
void Run(Function f, Parameters... params)
Parallel OMP blocks should be placed within Run so that their exceptions are saved.
Definition common.h:65
A tensor view with static type and dimension.
Definition linalg.h:293
Copyright 2014-2023, XGBoost Contributors.
Copyright 2015-2023 by XGBoost Contributors.
Copyright 2015-2023 by XGBoost Contributors.
Copyright 2021-2023 by XGBoost Contributors.
GlobalRandomEngine & GlobalRandom()
Global singleton of a random engine. This random engine is thread-local and only visible to the current thread.
Definition common.cc:23
namespace of xgboost
Definition base.h:90
std::size_t bst_row_t
Type for data row index.
Definition base.h:110
detail::GradientPairInternal< float > GradientPair
Gradient statistics pair, usually needed in gradient boosting.
Definition base.h:256
Runtime context for XGBoost.
Definition context.h:84
std::int32_t Threads() const
Returns the automatically chosen number of threads based on the nthread parameter and the system settings.
Definition context.cc:203
Training parameters for regression trees.
Definition param.h:28