Medial Code Documentation
Loading...
Searching...
No Matches
optional_weight.h
1
4#ifndef XGBOOST_COMMON_OPTIONAL_WEIGHT_H_
5#define XGBOOST_COMMON_OPTIONAL_WEIGHT_H_
6#include "xgboost/base.h" // XGBOOST_DEVICE
7#include "xgboost/context.h" // Context
8#include "xgboost/host_device_vector.h" // HostDeviceVector
9#include "xgboost/span.h" // Span
10
11namespace xgboost::common {
13 Span<float const> weights;
14 float dft{1.0f}; // fixme: make this compile time constant
15
// Wrap an existing span of weights. `dft` keeps its in-class initializer (1.0f).
16 explicit OptionalWeights(Span<float const> w) : weights{w} {}
// Use a single fallback value `w` instead of a span; `weights` is left default-constructed
// (presumably an empty span, which is what operator[] checks for -- confirm against span.h).
17 explicit OptionalWeights(float w) : dft{w} {}
18
// Weight for element `i`: the fallback `dft` when the span is empty, otherwise `weights[i]`.
// No range check on `i` is done here; any checking is left to Span::operator[].
// XGBOOST_DEVICE tags the function as usable by device code.
19 XGBOOST_DEVICE float operator[](size_t i) const { return weights.empty() ? dft : weights[i]; }
// True when the wrapped span holds no elements, i.e. operator[] falls back to `dft`.
20 [[nodiscard]] auto Empty() const { return weights.empty(); }
// Number of elements in the wrapped span; the fallback value `dft` does not contribute.
21 [[nodiscard]] auto Size() const { return weights.size(); }
22};
23
// Build an OptionalWeights view over `weights` that matches where `ctx` says XGBoost is running.
//
// ctx:     runtime context; queried via IsCUDA()/IsCPU() and read for `gpu_id`.
// weights: host/device vector of weights; only const accessors and SetDevice() are called on it.
// Returns an OptionalWeights wrapping the host span on CPU, or the device span otherwise.
24inline OptionalWeights MakeOptionalWeights(Context const* ctx,
25 HostDeviceVector<float> const& weights) {
// On CUDA, bind the vector to the context's GPU before asking for a device span.
// NOTE(review): presumably required so ConstDeviceSpan() refers to that device -- confirm
// against host_device_vector.h.
26 if (ctx->IsCUDA()) {
27 weights.SetDevice(ctx->gpu_id);
28 }
// The span is chosen by IsCPU(), not by IsCUDA(): anything that is not CPU gets the device span.
29 return OptionalWeights{ctx->IsCPU() ? weights.ConstHostSpan() : weights.ConstDeviceSpan()};
30}
31} // namespace xgboost::common
32#endif // XGBOOST_COMMON_OPTIONAL_WEIGHT_H_
Definition host_device_vector.h:87
Span class implementation, based on the ISO C++20 span<T>. The interface should be the same.
Definition span.h:424
Copyright 2014-2023, XGBoost Contributors.
A device-and-host vector abstraction layer.
Copyright 2015-2023 by XGBoost Contributors.
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition base.h:64
Copyright 2017-2023, XGBoost Contributors.
Definition span.h:77
Runtime context for XGBoost.
Definition context.h:84
bool IsCPU() const
Is XGBoost running on CPU?
Definition context.h:133
bool IsCUDA() const
Is XGBoost running on a CUDA device?
Definition context.h:137
Definition optional_weight.h:12