Medial Code Documentation
Loading...
Searching...
No Matches
adaptive.h
1
4#pragma once
5
6#include <algorithm>
7#include <cstdint> // std::int32_t
8#include <limits>
9#include <vector> // std::vector
10
11#include "../collective/aggregator.h"
12#include "../collective/communicator-inl.h"
13#include "../common/common.h"
14#include "xgboost/base.h" // bst_node_t
15#include "xgboost/context.h" // Context
16#include "xgboost/data.h" // MetaInfo
17#include "xgboost/host_device_vector.h" // HostDeviceVector
18#include "xgboost/tree_model.h" // RegTree
19
20namespace xgboost {
21namespace obj {
22namespace detail {
23inline void FillMissingLeaf(std::vector<bst_node_t> const& maybe_missing,
24 std::vector<bst_node_t>* p_nidx, std::vector<size_t>* p_nptr) {
25 auto& h_node_idx = *p_nidx;
26 auto& h_node_ptr = *p_nptr;
27
28 for (auto leaf : maybe_missing) {
29 if (std::binary_search(h_node_idx.cbegin(), h_node_idx.cend(), leaf)) {
30 continue;
31 }
32 auto it = std::upper_bound(h_node_idx.cbegin(), h_node_idx.cend(), leaf);
33 auto pos = it - h_node_idx.cbegin();
34 h_node_idx.insert(h_node_idx.cbegin() + pos, leaf);
35 h_node_ptr.insert(h_node_ptr.cbegin() + pos, h_node_ptr[pos]);
36 }
37}
38
39inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_node_t> const& nidx,
40 MetaInfo const& info, float learning_rate, RegTree* p_tree) {
41 auto& tree = *p_tree;
42 auto& quantiles = *p_quantiles;
43 auto const& h_node_idx = nidx;
44
45 size_t n_leaf = collective::GlobalMax(info, h_node_idx.size());
46 CHECK(quantiles.empty() || quantiles.size() == n_leaf);
47 if (quantiles.empty()) {
48 quantiles.resize(n_leaf, std::numeric_limits<float>::quiet_NaN());
49 }
50
51 // number of workers that have valid quantiles
52 std::vector<int32_t> n_valids(quantiles.size());
53 std::transform(quantiles.cbegin(), quantiles.cend(), n_valids.begin(),
54 [](float q) { return static_cast<int32_t>(!std::isnan(q)); });
55 collective::GlobalSum(info, &n_valids);
56 // convert to 0 for all reduce
57 std::replace_if(
58 quantiles.begin(), quantiles.end(), [](float q) { return std::isnan(q); }, 0.f);
59 // use the mean value
60 collective::GlobalSum(info, &quantiles);
61 for (size_t i = 0; i < n_leaf; ++i) {
62 if (n_valids[i] > 0) {
63 quantiles[i] /= static_cast<float>(n_valids[i]);
64 } else {
65 // Use original leaf value if no worker can provide the quantile.
66 quantiles[i] = tree[h_node_idx[i]].LeafValue();
67 }
68 }
69
70 for (size_t i = 0; i < nidx.size(); ++i) {
71 auto nidx = h_node_idx[i];
72 auto q = quantiles[i];
73 CHECK(tree[nidx].IsLeaf());
74 tree[nidx].SetLeaf(q * learning_rate);
75 }
76}
77
78inline std::size_t IdxY(MetaInfo const& info, bst_group_t group_idx) {
79 std::size_t y_idx{0};
80 if (info.labels.Shape(1) > 1) {
81 y_idx = group_idx;
82 }
83 CHECK_LE(y_idx, info.labels.Shape(1));
84 return y_idx;
85}
86
87void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> position,
88 std::int32_t group_idx, MetaInfo const& info, float learning_rate,
89 HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree);
90
91void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& position,
92 std::int32_t group_idx, MetaInfo const& info, float learning_rate,
93 HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree);
94} // namespace detail
95
96inline void UpdateTreeLeaf(Context const* ctx, HostDeviceVector<bst_node_t> const& position,
97 std::int32_t group_idx, MetaInfo const& info, float learning_rate,
98 HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree) {
99 if (ctx->IsCPU()) {
100 detail::UpdateTreeLeafHost(ctx, position.ConstHostVector(), group_idx, info, learning_rate,
101 predt, alpha, p_tree);
102 } else {
103 position.SetDevice(ctx->gpu_id);
104 detail::UpdateTreeLeafDevice(ctx, position.ConstDeviceSpan(), group_idx, info, learning_rate,
105 predt, alpha, p_tree);
106 }
107}
108} // namespace obj
109} // namespace xgboost
Copyright 2014-2023, XGBoost Contributors.
A device-and-host vector abstraction layer.
Copyright 2015-2023 by XGBoost Contributors.
Copyright 2015-2023 by XGBoost Contributors.
detail namespace with internal helper functions
Definition json.hpp:249
void GlobalSum(MetaInfo const &info, T *values, size_t size)
Find the global sum of the given values across all workers.
Definition aggregator.h:91
T GlobalMax(MetaInfo const &info, T value)
Find the global max of the given value across all workers.
Definition aggregator.h:72
namespace of xgboost
Definition base.h:90
std::uint32_t bst_group_t
Type for ranking group index.
Definition base.h:114
bool IsCPU() const
Is XGBoost running on CPU?
Definition context.h:133
Copyright 2014-2023 by Contributors.