11#include "../collective/aggregator.h"
12#include "../collective/communicator-inl.h"
13#include "../common/common.h"
// Inserts the leaves listed in `maybe_missing` into the node-index vector `*p_nidx`,
// keeping `*p_nptr` aligned with it by inserting a matching entry at the same position.
//
// NOTE(review): this listing has original line numbers fused onto each line and omits
// some original lines (the body of the `if` below and the closing braces are not shown).
// The comments here describe only the visible statements.
23inline void FillMissingLeaf(std::vector<bst_node_t>
const& maybe_missing,
24 std::vector<bst_node_t>* p_nidx, std::vector<size_t>* p_nptr) {
25 auto& h_node_idx = *p_nidx;
26 auto& h_node_ptr = *p_nptr;
28 for (
auto leaf : maybe_missing) {
// std::binary_search / std::upper_bound require `h_node_idx` to be sorted ascending.
// Presumably a leaf that is already present is skipped here; the body of this `if`
// is not visible in the listing -- TODO confirm.
29 if (std::binary_search(h_node_idx.cbegin(), h_node_idx.cend(), leaf)) {
// Position of the first element greater than `leaf`; inserting there keeps the order.
32 auto it = std::upper_bound(h_node_idx.cbegin(), h_node_idx.cend(), leaf);
33 auto pos = it - h_node_idx.cbegin();
34 h_node_idx.insert(h_node_idx.cbegin() + pos, leaf);
// Duplicates the value currently stored at `pos` in the pointer vector, so both vectors
// grow by one element at the same index. Presumably this gives the new leaf an empty
// segment -- TODO confirm against how `*p_nptr` is consumed.
35 h_node_ptr.insert(h_node_ptr.cbegin() + pos, h_node_ptr[pos]);
// Finalises per-leaf quantile values and writes them, scaled by `learning_rate`, into
// the tree's leaf nodes.
//
// NOTE(review): this listing omits several original lines. The declarations of `tree`
// and `n_leaf`, the name of the call whose arguments appear after the `n_valids`
// transform, and the `else` keyword before the LeafValue() fallback are all missing.
// Presumably `tree` is `*p_tree` and `n_leaf` is the number of entries in `nidx`
// -- TODO confirm.
39inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_node_t>
const& nidx,
40 MetaInfo
const& info,
float learning_rate, RegTree* p_tree) {
42 auto& quantiles = *p_quantiles;
43 auto const& h_node_idx = nidx;
// The caller must pass either no quantiles at all or exactly one per leaf.
46 CHECK(quantiles.empty() || quantiles.size() == n_leaf);
47 if (quantiles.empty()) {
// An empty input is expanded to `n_leaf` entries, each marked invalid with NaN.
48 quantiles.resize(n_leaf, std::numeric_limits<float>::quiet_NaN());
// n_valids[i] is 1 when quantiles[i] is a number and 0 when it is NaN.
52 std::vector<int32_t> n_valids(quantiles.size());
53 std::transform(quantiles.cbegin(), quantiles.cend(), n_valids.begin(),
54 [](
float q) { return static_cast<int32_t>(!std::isnan(q)); });
// The callee for the arguments below is on a line not shown; the predicate and the
// trailing 0.f suggest NaN entries are replaced with zero -- TODO confirm.
58 quantiles.begin(), quantiles.end(), [](
float q) { return std::isnan(q); }, 0.f);
61 for (
size_t i = 0; i < n_leaf; ++i) {
62 if (n_valids[i] > 0) {
// Divide by the valid count. Presumably both vectors were summed across workers on
// lines not shown, making this an average -- TODO confirm.
63 quantiles[i] /=
static_cast<float>(n_valids[i]);
// Presumably the `else` branch (keyword not visible): fall back to the leaf value
// already stored in the tree.
66 quantiles[i] = tree[h_node_idx[i]].LeafValue();
70 for (
size_t i = 0; i < nidx.size(); ++i) {
// Note: this local `nidx` shadows the function parameter of the same name inside the
// loop body.
71 auto nidx = h_node_idx[i];
72 auto q = quantiles[i];
73 CHECK(tree[nidx].IsLeaf());
// Store the final value scaled by the learning rate.
74 tree[nidx].SetLeaf(q * learning_rate);
// Produces an index `y_idx` derived from `group_idx` and the second dimension of
// `info.labels`.
//
// NOTE(review): the declaration of `y_idx`, the statement inside the `if`, and the
// return statement are not visible in this listing. Presumably `y_idx` starts at 0 and
// is set to `group_idx` when there is more than one label column -- TODO confirm.
78inline std::size_t IdxY(MetaInfo
const& info,
bst_group_t group_idx) {
80 if (info.labels.Shape(1) > 1) {
// NOTE(review): CHECK_LE accepts y_idx == info.labels.Shape(1). If `y_idx` is used as a
// zero-based column index, a strict less-than check may be intended -- TODO confirm.
83 CHECK_LE(y_idx, info.labels.Shape(1));
// Declaration only; the definition is not part of this listing.
// Variant of the leaf update that receives the positions as a
// `common::Span<bst_node_t const>`. UpdateTreeLeaf calls it with
// `position.ConstDeviceSpan()`.
87void UpdateTreeLeafDevice(Context
const* ctx, common::Span<bst_node_t const> position,
88 std::int32_t group_idx, MetaInfo
const& info,
float learning_rate,
89 HostDeviceVector<float>
const& predt,
float alpha, RegTree* p_tree);
// Declaration only; the definition is not part of this listing.
// Variant of the leaf update that receives the positions as a
// `std::vector<bst_node_t>`. UpdateTreeLeaf calls it with
// `position.ConstHostVector()`.
91void UpdateTreeLeafHost(Context
const* ctx, std::vector<bst_node_t>
const& position,
92 std::int32_t group_idx, MetaInfo
const& info,
float learning_rate,
93 HostDeviceVector<float>
const& predt,
float alpha, RegTree* p_tree);
// Entry point that forwards its arguments to detail::UpdateTreeLeafHost or
// detail::UpdateTreeLeafDevice.
//
// NOTE(review): the lines that choose between the two calls are not visible in this
// listing. Presumably the host path is taken when `ctx->IsCPU()` is true and the device
// path otherwise -- TODO confirm.
96inline void UpdateTreeLeaf(Context
const* ctx, HostDeviceVector<bst_node_t>
const& position,
97 std::int32_t group_idx, MetaInfo
const& info,
float learning_rate,
98 HostDeviceVector<float>
const& predt,
float alpha, RegTree* p_tree) {
// Host path: pass the positions as a read-only host vector.
100 detail::UpdateTreeLeafHost(ctx, position.ConstHostVector(), group_idx, info, learning_rate,
101 predt, alpha, p_tree);
// Device path: set the vector's device to `ctx->gpu_id`, then pass a read-only
// device span.
103 position.SetDevice(ctx->gpu_id);
104 detail::UpdateTreeLeafDevice(ctx, position.ConstDeviceSpan(), group_idx, info, learning_rate,
105 predt, alpha, p_tree);
Copyright 2014-2023, XGBoost Contributors.
A device-and-host vector abstraction layer.
Copyright 2015-2023 by XGBoost Contributors.
Copyright 2015-2023 by XGBoost Contributors.
detail namespace with internal helper functions
Definition json.hpp:249
void GlobalSum(MetaInfo const &info, T *values, size_t size)
Find the global sum of the given values across all workers.
Definition aggregator.h:91
T GlobalMax(MetaInfo const &info, T value)
Find the global max of the given value across all workers.
Definition aggregator.h:72
namespace of xgboost
Definition base.h:90
std::uint32_t bst_group_t
Type for ranking group index.
Definition base.h:114
bool IsCPU() const
Is XGBoost running on CPU?
Definition context.h:133
Copyright 2014-2023 by Contributors.