Medial Code Documentation
linalg_op.h
/**
 * Copyright 2021-2023 by XGBoost Contributors.
 */
#ifndef XGBOOST_COMMON_LINALG_OP_H_
#define XGBOOST_COMMON_LINALG_OP_H_
#include <cstdint>  // std::int32_t
#include <type_traits>

#include "common.h"
#include "threading_utils.h"
#include "transform_iterator.h"  // MakeIndexTransformIter
#include "xgboost/context.h"     // Context
#include "xgboost/linalg.h"

namespace xgboost {
namespace linalg {
// Applies fn(i, v) to every element of the view on the host and writes the result back.
// Contiguous views take the direct pointer path; strided views are addressed through
// UnravelIndex.
template <typename T, int32_t D, typename Fn>
void ElementWiseTransformHost(linalg::TensorView<T, D> t, int32_t n_threads, Fn&& fn) {
  if (t.Contiguous()) {
    auto ptr = t.Values().data();
    common::ParallelFor(t.Size(), n_threads, [&](size_t i) { ptr[i] = fn(i, ptr[i]); });
  } else {
    common::ParallelFor(t.Size(), n_threads, [&](size_t i) {
      auto& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape()));
      v = fn(i, v);
    });
  }
}

// Invokes fn(i, v) for its side effects only; the functor must return void.
template <typename T, int32_t D, typename Fn>
void ElementWiseKernelHost(linalg::TensorView<T, D> t, int32_t n_threads, Fn&& fn) {
  static_assert(std::is_void<std::result_of_t<Fn(size_t, T&)>>::value,
                "For function with return, use transform instead.");
  if (t.Contiguous()) {
    auto ptr = t.Values().data();
    common::ParallelFor(t.Size(), n_threads, [&](size_t i) { fn(i, ptr[i]); });
  } else {
    common::ParallelFor(t.Size(), n_threads, [&](size_t i) {
      auto& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape()));
      fn(i, v);
    });
  }
}

#if !defined(XGBOOST_USE_CUDA)
// In CPU-only builds the device variants are stubs that report missing GPU support.
template <typename T, int32_t D, typename Fn>
void ElementWiseKernelDevice(linalg::TensorView<T, D>, Fn&&, void* = nullptr) {
  common::AssertGPUSupport();
}

template <typename T, int32_t D, typename Fn>
void ElementWiseTransformDevice(linalg::TensorView<T, D>, Fn&&, void* = nullptr) {
  common::AssertGPUSupport();
}

template <typename T, int32_t D, typename Fn>
void ElementWiseKernel(Context const* ctx, linalg::TensorView<T, D> t, Fn&& fn) {
  if (!ctx->IsCPU()) {
    common::AssertGPUSupport();
  }
  ElementWiseKernelHost(t, ctx->Threads(), fn);
}
#endif  // !defined(XGBOOST_USE_CUDA)

// Index-transform iterators that expose a TensorView to standard algorithms,
// including views that are not contiguous in memory.
template <typename T, std::int32_t kDim>
auto cbegin(TensorView<T, kDim> const& v) {  // NOLINT
  auto it = common::MakeIndexTransformIter([&](size_t i) -> std::remove_cv_t<T> const& {
    return linalg::detail::Apply(v, linalg::UnravelIndex(i, v.Shape()));
  });
  return it;
}

template <typename T, std::int32_t kDim>
auto cend(TensorView<T, kDim> const& v) {  // NOLINT
  return cbegin(v) + v.Size();
}

template <typename T, std::int32_t kDim>
auto begin(TensorView<T, kDim>& v) {  // NOLINT
  auto it = common::MakeIndexTransformIter(
      [&](size_t i) -> T& { return linalg::detail::Apply(v, linalg::UnravelIndex(i, v.Shape())); });
  return it;
}

template <typename T, std::int32_t kDim>
auto end(TensorView<T, kDim>& v) {  // NOLINT
  return begin(v) + v.Size();
}
}  // namespace linalg
}  // namespace xgboost
#endif  // XGBOOST_COMMON_LINALG_OP_H_
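
A minimal usage sketch of the host helpers and iterator adapters defined above (not part of the header). It assumes the caller already holds a writable host TensorView<float, 2>, for instance one obtained from a linalg::Tensor, and that the include paths match XGBoost's source layout; FillScaleSum is a hypothetical helper written only for illustration.

#include <cstddef>  // std::size_t
#include <numeric>  // std::accumulate

#include "common/linalg_op.h"  // this header (path assumed relative to src/)
#include "xgboost/context.h"   // xgboost::Context
#include "xgboost/linalg.h"    // xgboost::linalg::TensorView

namespace xgboost {
// Fill `view` with ones, scale each element by its linear index, then sum it.
float FillScaleSum(Context const& ctx, linalg::TensorView<float, 2> view) {
  // The functor returns void, so ElementWiseKernelHost is the right helper.
  linalg::ElementWiseKernelHost(view, ctx.Threads(),
                                [](std::size_t, float& v) { v = 1.0f; });
  // The functor returns a value, so ElementWiseTransformHost writes it back.
  linalg::ElementWiseTransformHost(view, ctx.Threads(), [](std::size_t i, float v) {
    return v * static_cast<float>(i);
  });
  // cbegin/cend let standard algorithms walk the view, contiguous or not.
  return std::accumulate(linalg::cbegin(view), linalg::cend(view), 0.0f);
}
}  // namespace xgboost

The split mirrors the static_assert in the header: a void-returning functor goes to ElementWiseKernelHost, while a value-returning functor goes to ElementWiseTransformHost.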
Referenced symbols:
UnravelIndex: LINALG_HD auto UnravelIndex(size_t idx, common::Span<size_t const, D> shape). Turns a linear index into a multi-dimensional index. Definition: linalg.h, line 613.
xgboost: namespace of xgboost. Definition: base.h, line 90.
Context::Threads: std::int32_t Threads() const. Returns the automatically chosen number of threads based on the nthread parameter and the system setting. Definition: context.cc, line 203.
Context::IsCPU: bool IsCPU() const. Is XGBoost running on CPU? Definition: context.h, line 133.
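
For intuition about the UnravelIndex call that the strided branches rely on, here is a standalone sketch of row-major unraveling. It is written for illustration only (UnravelRowMajor is not an XGBoost function) and assumes the row-major layout that TensorView uses by default.

#include <array>
#include <cstddef>

// Convert a linear index into per-dimension coordinates for a row-major shape,
// mirroring what linalg::UnravelIndex provides (this is not the library implementation).
template <std::size_t D>
std::array<std::size_t, D> UnravelRowMajor(std::size_t idx, std::array<std::size_t, D> shape) {
  std::array<std::size_t, D> out{};
  for (std::size_t d = D; d-- > 0;) {
    out[d] = idx % shape[d];  // coordinate along dimension d
    idx /= shape[d];
  }
  return out;
}
// Example: UnravelRowMajor<2>(5, {2, 3}) yields {1, 2}, since 5 == 1 * 3 + 2.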