Medial Code Documentation
Loading...
Searching...
No Matches
bitfield.h
Go to the documentation of this file.
1
5#ifndef XGBOOST_COMMON_BITFIELD_H_
6#define XGBOOST_COMMON_BITFIELD_H_
7
8#include <algorithm>
9#include <bitset>
10#include <cinttypes>
11#include <iostream>
12#include <sstream>
13#include <string>
14#include <vector>
15
16#if defined(__CUDACC__)
17#include <thrust/copy.h>
18#include <thrust/device_ptr.h>
19#include "device_helpers.cuh"
20#endif // defined(__CUDACC__)
21
22#include "xgboost/span.h"
23#include "common.h"
24
25namespace xgboost {
26
#if defined(__CUDACC__)
// Word type accepted by the 64-bit overloads of the CUDA atomic bitwise intrinsics.
using BitFieldAtomicType = unsigned long long;  // NOLINT

/*!
 * \brief Atomically replace `*address` with `*address | val`.
 *
 * Delegates to the native `atomicOr` intrinsic. It performs the read-modify-write as one
 * hardware atomic, instead of an `atomicCAS` retry loop that can spin while other threads
 * keep modifying the same word.
 *
 * \param address Pointer to the word to update.
 * \param val     Bits to OR into the word.
 * \return The value stored at `address` immediately before the update.
 */
__forceinline__ __device__ BitFieldAtomicType AtomicOr(BitFieldAtomicType* address,
                                                       BitFieldAtomicType val) {
  return atomicOr(address, val);
}

/*!
 * \brief Atomically replace `*address` with `*address & val`.
 *
 * Delegates to the native `atomicAnd` intrinsic (single hardware atomic, no CAS loop).
 *
 * \param address Pointer to the word to update.
 * \param val     Mask to AND into the word.
 * \return The value stored at `address` immediately before the update.
 */
__forceinline__ __device__ BitFieldAtomicType AtomicAnd(BitFieldAtomicType* address,
                                                        BitFieldAtomicType val) {
  return atomicAnd(address, val);
}
#endif  // defined(__CUDACC__)
52
/*!
 * \brief A non-owning view over an array of unsigned words, with auxiliary methods for
 *        setting, clearing and testing individual bits.
 *
 * \tparam VT        Unsigned type of each storage word (enforced by a static_assert).
 * \tparam Direction Policy type providing `static Pos Shift(Pos)`, which maps a logical
 *                   bit position to the physical bit inside a word.
 * \tparam IsConst   When true, the viewed storage is `VT const`.
 */
template <typename VT, typename Direction, bool IsConst = false>
struct BitFieldContainer {
  using value_type = std::conditional_t<IsConst, VT const, VT>;  // NOLINT
  using size_type = size_t;                                      // NOLINT
  using index_type = size_t;                                     // NOLINT
  using pointer = value_type*;                                   // NOLINT

  /*! \brief Number of bits held by one storage word. */
  static index_type constexpr kValueSize = sizeof(value_type) * 8;
  static index_type constexpr kOne = 1;  // force correct type.

  /*! \brief Location of a bit: index of the word, and index of the bit inside that word. */
  struct Pos {
    index_type int_pos{0};
    index_type bit_pos{0};
  };

 private:
  value_type* bits_{nullptr};
  size_type n_values_{0};
  static_assert(!std::is_signed<VT>::value, "Must use an unsigned type as the underlying storage.");

  // Non-const storage type, used to build single-bit masks.
  using StorageT = std::remove_const_t<VT>;

  // Word with only bit `bit_pos` set; requires `bit_pos < kValueSize`. The shift is done
  // in the storage type rather than in `size_t` (the type of `kOne`), so it stays well
  // defined when `size_t` is narrower than `VT` (e.g. 64-bit words on a 32-bit host).
  XGBOOST_DEVICE static StorageT BitMask(index_type bit_pos) {
    return static_cast<StorageT>(StorageT{1} << bit_pos);
  }

 public:
  /*!
   * \brief Split a logical bit index into (word index, bit index within the word).
   *        The direction policy is not applied here.
   */
  XGBOOST_DEVICE static Pos ToBitPos(index_type pos) {
    Pos pos_v;
    pos_v.int_pos = pos / kValueSize;
    pos_v.bit_pos = pos % kValueSize;
    return pos_v;
  }

  BitFieldContainer() = default;
  /*! \brief View the words in `bits`; the memory is not owned or copied. */
  XGBOOST_DEVICE explicit BitFieldContainer(common::Span<value_type> bits)
      : bits_{bits.data()}, n_values_{bits.size()} {}
  BitFieldContainer(BitFieldContainer const& other) = default;
  BitFieldContainer(BitFieldContainer&& other) = default;
  BitFieldContainer& operator=(BitFieldContainer const& that) = default;
  BitFieldContainer& operator=(BitFieldContainer&& that) = default;

  /*! \brief The viewed storage words as a span. */
  XGBOOST_DEVICE auto Bits() { return common::Span<value_type>{bits_, NumValues()}; }
  XGBOOST_DEVICE auto Bits() const { return common::Span<value_type const>{bits_, NumValues()}; }

  /*!
   * \brief Compute the size of the needed memory allocation for `size` bits. The returned
   *        value is a number of `BitFieldContainer::value_type` elements (rounded up).
   */
  XGBOOST_DEVICE static size_t ComputeStorageSize(index_type size) {
    return common::DivRoundUp(size, kValueSize);
  }

#if defined(__CUDA_ARCH__)
  /*!
   * \brief Word-wise OR of `rhs` into this field over the common prefix of both.
   *
   * Must be called by every thread of a 1-D launch. Threads cover the words in a
   * grid-stride loop, so every word is processed regardless of the launch size. The
   * per-word update is not atomic: concurrent writers to the same words must be
   * synchronised by the caller.
   */
  __device__ BitFieldContainer& operator|=(BitFieldContainer const& rhs) {
    size_t const min_size = min(NumValues(), rhs.NumValues());
    size_t const stride = static_cast<size_t>(gridDim.x) * blockDim.x;
    for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; i < min_size;
         i += stride) {
      Data()[i] |= rhs.Data()[i];
    }
    return *this;
  }
#else
  /*! \brief Word-wise OR of `rhs` into this field over the common prefix of both. */
  BitFieldContainer& operator|=(BitFieldContainer const& rhs) {
    size_t const min_size = std::min(NumValues(), rhs.NumValues());
    for (size_t i = 0; i < min_size; ++i) {
      Data()[i] |= rhs.Data()[i];
    }
    return *this;
  }
#endif  // defined(__CUDA_ARCH__)

#if defined(__CUDA_ARCH__)
  /*!
   * \brief Word-wise AND of `rhs` into this field over the common prefix of both.
   *
   * Same launch contract as the device `operator|=`: every thread calls it, and words are
   * covered by a grid-stride loop. The per-word update is not atomic.
   */
  __device__ BitFieldContainer& operator&=(BitFieldContainer const& rhs) {
    size_t const min_size = min(NumValues(), rhs.NumValues());
    size_t const stride = static_cast<size_t>(gridDim.x) * blockDim.x;
    for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; i < min_size;
         i += stride) {
      Data()[i] &= rhs.Data()[i];
    }
    return *this;
  }
#else
  /*! \brief Word-wise AND of `rhs` into this field over the common prefix of both. */
  BitFieldContainer& operator&=(BitFieldContainer const& rhs) {
    size_t const min_size = std::min(NumValues(), rhs.NumValues());
    for (size_t i = 0; i < min_size; ++i) {
      Data()[i] &= rhs.Data()[i];
    }
    return *this;
  }
#endif  // defined(__CUDA_ARCH__)

#if defined(__CUDA_ARCH__)
  /*! \brief Atomically set the bit at logical index `pos`. */
  __device__ void Set(index_type pos) noexcept(true) {
    Pos const pos_v = Direction::Shift(ToBitPos(pos));
    value_type& word = Data()[pos_v.int_pos];
    using Type = typename dh::detail::AtomicDispatcher<sizeof(value_type)>::Type;
    atomicOr(reinterpret_cast<Type*>(&word), BitMask(pos_v.bit_pos));
  }
  /*! \brief Atomically clear the bit at logical index `pos`. */
  __device__ void Clear(index_type pos) noexcept(true) {
    Pos const pos_v = Direction::Shift(ToBitPos(pos));
    value_type& word = Data()[pos_v.int_pos];
    using Type = typename dh::detail::AtomicDispatcher<sizeof(value_type)>::Type;
    atomicAnd(reinterpret_cast<Type*>(&word), static_cast<StorageT>(~BitMask(pos_v.bit_pos)));
  }
#else
  /*! \brief Set the bit at logical index `pos` (not thread-safe on the host). */
  void Set(index_type pos) noexcept(true) {
    Pos const pos_v = Direction::Shift(ToBitPos(pos));
    value_type& word = Data()[pos_v.int_pos];
    word |= BitMask(pos_v.bit_pos);
  }
  /*! \brief Clear the bit at logical index `pos` (not thread-safe on the host). */
  void Clear(index_type pos) noexcept(true) {
    Pos const pos_v = Direction::Shift(ToBitPos(pos));
    value_type& word = Data()[pos_v.int_pos];
    word &= static_cast<StorageT>(~BitMask(pos_v.bit_pos));
  }
#endif  // defined(__CUDA_ARCH__)

  /*!
   * \brief Test the bit at `pos_v`, which is given before the direction policy is applied
   *        (i.e. as returned by `ToBitPos`).
   */
  XGBOOST_DEVICE bool Check(Pos pos_v) const noexcept(true) {
    pos_v = Direction::Shift(pos_v);
    assert(pos_v.int_pos < NumValues());
    value_type const word = Data()[pos_v.int_pos];
    return (word & BitMask(pos_v.bit_pos)) != 0;
  }
  /*! \brief Test the bit at logical index `pos`. */
  [[nodiscard]] XGBOOST_DEVICE bool Check(index_type pos) const noexcept(true) {
    Pos pos_v = ToBitPos(pos);
    return Check(pos_v);
  }
  /*! \brief Returns the total number of bits that can be viewed. */
  [[nodiscard]] XGBOOST_DEVICE size_type Capacity() const noexcept(true) {
    return kValueSize * NumValues();
  }
  /*! \brief Number of storage units used in this bit field. */
  [[nodiscard]] XGBOOST_DEVICE size_type NumValues() const noexcept(true) { return n_values_; }

  /*! \brief Raw pointer to the first viewed storage word. */
  XGBOOST_DEVICE pointer Data() const noexcept(true) { return bits_; }

  /*! \brief Print the storage size followed by every word as a binary string, one per line. */
  inline friend std::ostream& operator<<(std::ostream& os,
                                         BitFieldContainer<VT, Direction, IsConst> field) {
    os << "Bits "
       << "storage size: " << field.NumValues() << "\n";
    for (typename common::Span<value_type>::index_type i = 0; i < field.NumValues(); ++i) {
      std::bitset<BitFieldContainer<VT, Direction, IsConst>::kValueSize> bset(field.Data()[i]);
      os << bset << "\n";
    }
    return os;
  }
};
214
// Direction policy in which bits start from the leftmost (most significant) bit of each
// storage word: logical bit 0 is the most significant bit of word 0.
template <typename VT, bool IsConst = false>
struct LBitsPolicy : public BitFieldContainer<VT, LBitsPolicy<VT, IsConst>, IsConst> {
  using Container = BitFieldContainer<VT, LBitsPolicy<VT, IsConst>, IsConst>;
  using Pos = typename Container::Pos;
  using value_type = typename Container::value_type;  // NOLINT

  /*! \brief Mirror the in-word bit index so that index 0 maps to the most significant bit. */
  XGBOOST_DEVICE static Pos Shift(Pos pos) {
    pos.bit_pos = Container::kValueSize - pos.bit_pos - Container::kOne;
    return pos;
  }
};
227
// Direction policy in which bits start from the rightmost (least significant) bit of each
// storage word, while words are still indexed in increasing order (first word first).
template <typename VT>
struct RBitsPolicy : public BitFieldContainer<VT, RBitsPolicy<VT>> {
  using Container = BitFieldContainer<VT, RBitsPolicy<VT>>;
  using Pos = typename Container::Pos;
  using value_type = typename Container::value_type;  // NOLINT

  /*! \brief Identity mapping: the in-word bit index is used as-is. */
  XGBOOST_DEVICE static Pos Shift(Pos pos) {
    return pos;
  }
};
240
241// Format: <Const><Direction>BitField<size of underlying type in bits>, underlying type
242// must be unsigned.
245
248} // namespace xgboost
249
250#endif // XGBOOST_COMMON_BITFIELD_H_
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition base.h:64
namespace of xgboost
Definition base.h:90
Definition bitfield.h:70
A non-owning type with auxiliary methods defined for manipulating bits.
Definition bitfield.h:61
XGBOOST_DEVICE size_type NumValues() const noexcept(true)
Number of storage units used in this bit field.
Definition bitfield.h:199
XGBOOST_DEVICE size_type Capacity() const noexcept(true)
Returns the total number of bits that can be viewed.
Definition bitfield.h:193
Definition bitfield.h:217
Definition bitfield.h:231
Copyright 2015-2023 by XGBoost Contributors.