Medial Code Documentation
Loading...
Searching...
No Matches
span.h
1
29#ifndef XGBOOST_SPAN_H_
30#define XGBOOST_SPAN_H_
31
32#include <xgboost/base.h>
33#include <xgboost/logging.h>
34
35#include <cinttypes> // size_t
36#include <cstdio>
37#include <iterator>
38#include <limits> // numeric_limits
39#include <type_traits>
40#include <utility> // for move
41
42#if defined(__CUDACC__)
43#include <cuda_runtime.h>
44#endif // defined(__CUDACC__)
45
63#if defined(_MSC_VER) && _MSC_VER < 1910
64
65#define __span_noexcept
66
67#pragma push_macro("constexpr")
68#define constexpr /*constexpr*/
69
70#else
71
72#define __span_noexcept noexcept
73
74#endif // defined(_MSC_VER) && _MSC_VER < 1910
75
76namespace xgboost {
77namespace common {
78
79#if defined(__CUDA_ARCH__)
80// Usual logging facility is not available inside device code.
81
82#if defined(_MSC_VER)
83
84// Windows CUDA doesn't have __assert_fail.
85#define CUDA_KERNEL_CHECK(cond) \
86 do { \
87 if (XGBOOST_EXPECT(!(cond), false)) { \
88 asm("trap;"); \
89 } \
90 } while (0)
91
92#else // defined(_MSC_VER)
93
94#define __ASSERT_STR_HELPER(x) #x
95
96#define CUDA_KERNEL_CHECK(cond) \
97 (XGBOOST_EXPECT((cond), true) \
98 ? static_cast<void>(0) \
99 : __assert_fail(__ASSERT_STR_HELPER((cond)), __FILE__, __LINE__, __PRETTY_FUNCTION__))
100
101#endif // defined(_MSC_VER)
102
103#define KERNEL_CHECK CUDA_KERNEL_CHECK
104
105#define SPAN_CHECK KERNEL_CHECK
106
107#else // ------------------------------ not CUDA ----------------------------
108
109#if defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1
110
111#define KERNEL_CHECK(cond)
112
113#define SPAN_CHECK(cond) KERNEL_CHECK(cond)
114
115#else
116
117#define KERNEL_CHECK(cond) (XGBOOST_EXPECT((cond), true) ? static_cast<void>(0) : std::terminate())
118
119#define SPAN_CHECK(cond) KERNEL_CHECK(cond)
120
121#endif // defined(XGBOOST_STRICT_R_MODE)
122
123#endif // __CUDA_ARCH__
124
125#define SPAN_LT(lhs, rhs) SPAN_CHECK((lhs) < (rhs))
126
namespace detail {
// Signed type used for iterator/span differences. std::ptrdiff_t is kept when
// it is the very same type as std::int64_t; otherwise std::int64_t is used
// directly, so the alias is 64-bit wide either way.
using ptrdiff_t = std::conditional_t<  // NOLINT
    std::is_same<std::ptrdiff_t, std::int64_t>::value, std::ptrdiff_t, std::int64_t>;
}  // namespace detail
138
139#if defined(_MSC_VER) && _MSC_VER < 1910
140constexpr const std::size_t
141dynamic_extent = std::numeric_limits<std::size_t>::max(); // NOLINT
142#else
143constexpr std::size_t dynamic_extent = std::numeric_limits<std::size_t>::max(); // NOLINT
144#endif // defined(_MSC_VER) && _MSC_VER < 1910
145
146enum class byte : unsigned char {}; // NOLINT
147
148template <class ElementType, std::size_t Extent>
149class Span;
150
151namespace detail {
152
153template <typename SpanType, bool IsConst>
155 using ElementType = typename SpanType::element_type;
156
157 public:
158 using iterator_category = std::random_access_iterator_tag; // NOLINT
159 using value_type = typename SpanType::value_type; // NOLINT
160 using difference_type = detail::ptrdiff_t; // NOLINT
161
162 using reference = typename std::conditional< // NOLINT
163 IsConst, const ElementType, ElementType>::type&;
164 using pointer = typename std::add_pointer<reference>::type; // NOLINT
165
166 constexpr SpanIterator() = default;
167
169 const SpanType* _span,
170 typename SpanType::index_type _idx) __span_noexcept :
171 span_(_span), index_(_idx) {}
172
174 template <bool B, typename std::enable_if<!B && IsConst>::type* = nullptr>
175 XGBOOST_DEVICE constexpr SpanIterator( // NOLINT
176 const SpanIterator<SpanType, B>& other_) __span_noexcept
177 : SpanIterator(other_.span_, other_.index_) {}
178
179 XGBOOST_DEVICE reference operator*() const {
180 SPAN_CHECK(index_ < span_->size());
181 return *(span_->data() + index_);
182 }
183 XGBOOST_DEVICE reference operator[](difference_type n) const {
184 return *(*this + n);
185 }
186
187 XGBOOST_DEVICE pointer operator->() const {
188 SPAN_CHECK(index_ != span_->size());
189 return span_->data() + index_;
190 }
191
192 XGBOOST_DEVICE SpanIterator& operator++() {
193 SPAN_CHECK(index_ != span_->size());
194 index_++;
195 return *this;
196 }
197
198 XGBOOST_DEVICE SpanIterator operator++(int) {
199 auto ret = *this;
200 ++(*this);
201 return ret;
202 }
203
204 XGBOOST_DEVICE SpanIterator& operator--() {
205 SPAN_CHECK(index_ != 0 && index_ <= span_->size());
206 index_--;
207 return *this;
208 }
209
210 XGBOOST_DEVICE SpanIterator operator--(int) {
211 auto ret = *this;
212 --(*this);
213 return ret;
214 }
215
216 XGBOOST_DEVICE SpanIterator operator+(difference_type n) const {
217 auto ret = *this;
218 return ret += n;
219 }
220
221 XGBOOST_DEVICE SpanIterator& operator+=(difference_type n) {
222 SPAN_CHECK((index_ + n) <= span_->size());
223 index_ += n;
224 return *this;
225 }
226
227 XGBOOST_DEVICE difference_type operator-(SpanIterator rhs) const {
228 SPAN_CHECK(span_ == rhs.span_);
229 return index_ - rhs.index_;
230 }
231
232 XGBOOST_DEVICE SpanIterator operator-(difference_type n) const {
233 auto ret = *this;
234 return ret -= n;
235 }
236
237 XGBOOST_DEVICE SpanIterator& operator-=(difference_type n) {
238 return *this += -n;
239 }
240
241 // friends
242 XGBOOST_DEVICE constexpr friend bool operator==(
243 SpanIterator _lhs, SpanIterator _rhs) __span_noexcept {
244 return _lhs.span_ == _rhs.span_ && _lhs.index_ == _rhs.index_;
245 }
246
247 XGBOOST_DEVICE constexpr friend bool operator!=(
248 SpanIterator _lhs, SpanIterator _rhs) __span_noexcept {
249 return !(_lhs == _rhs);
250 }
251
252 XGBOOST_DEVICE constexpr friend bool operator<(
253 SpanIterator _lhs, SpanIterator _rhs) __span_noexcept {
254 return _lhs.index_ < _rhs.index_;
255 }
256
257 XGBOOST_DEVICE constexpr friend bool operator<=(
258 SpanIterator _lhs, SpanIterator _rhs) __span_noexcept {
259 return !(_rhs < _lhs);
260 }
261
262 XGBOOST_DEVICE constexpr friend bool operator>(
263 SpanIterator _lhs, SpanIterator _rhs) __span_noexcept {
264 return _rhs < _lhs;
265 }
266
267 XGBOOST_DEVICE constexpr friend bool operator>=(
268 SpanIterator _lhs, SpanIterator _rhs) __span_noexcept {
269 return !(_rhs > _lhs);
270 }
271
272 protected:
273 const SpanType *span_ { nullptr };
274 typename SpanType::index_type index_ { 0 };
275};
276
277
278// It's tempting to use constexpr instead of structs to do the following meta
279// programming. But remember that we are supporting MSVC 2013 here.
280
288template <std::size_t Extent, std::size_t Offset, std::size_t Count>
289struct ExtentValue : public std::integral_constant<
290 std::size_t, Count != dynamic_extent ?
291 Count : (Extent != dynamic_extent ? Extent - Offset : Extent)> {};
292
297template <typename T, std::size_t Extent>
298struct ExtentAsBytesValue : public std::integral_constant<
299 std::size_t,
300 Extent == dynamic_extent ?
301 Extent : sizeof(T) * Extent> {};
302
303template <std::size_t From, std::size_t To>
304struct IsAllowedExtentConversion : public std::integral_constant<
305 bool, From == To || From == dynamic_extent || To == dynamic_extent> {};
306
// True when a pointer to an array of `From` converts to a pointer to an array
// of `To`. The base class is the integral_constant exposed by
// std::is_convertible for those two pointer types.
template <class From, class To>
struct IsAllowedElementTypeConversion
    : public std::is_convertible<From (*)[], To (*)[]>::type {};
310
311template <class T>
312struct IsSpanOracle : std::false_type {};
313
314template <class T, std::size_t Extent>
315struct IsSpanOracle<Span<T, Extent>> : std::true_type {};
316
317template <class T>
318struct IsSpan : public IsSpanOracle<typename std::remove_cv<T>::type> {};
319
320// Re-implement std algorithms here to adapt them for CUDA.
321template <typename T>
322struct Less {
323 XGBOOST_DEVICE constexpr bool operator()(const T& _x, const T& _y) const {
324 return _x < _y;
325 }
326};
327
328template <typename T>
329struct Greater {
330 XGBOOST_DEVICE constexpr bool operator()(const T& _x, const T& _y) const {
331 return _x > _y;
332 }
333};
334
335template <class InputIt1, class InputIt2,
336 class Compare =
338XGBOOST_DEVICE bool LexicographicalCompare(InputIt1 first1, InputIt1 last1,
339 InputIt2 first2, InputIt2 last2) {
340 Compare comp;
341 for (; first1 != last1 && first2 != last2; ++first1, ++first2) {
342 if (comp(*first1, *first2)) {
343 return true;
344 }
345 if (comp(*first2, *first1)) {
346 return false;
347 }
348 }
349 return first1 == last1 && first2 != last2;
350}
351
352} // namespace detail
353
354
422template <typename T,
423 std::size_t Extent = dynamic_extent>
424class Span {
425 public:
426 using element_type = T; // NOLINT
427 using value_type = typename std::remove_cv<T>::type; // NOLINT
428 using index_type = std::size_t; // NOLINT
429 using difference_type = detail::ptrdiff_t; // NOLINT
430 using pointer = T*; // NOLINT
431 using reference = T&; // NOLINT
432
433 using iterator = detail::SpanIterator<Span<T, Extent>, false>; // NOLINT
434 using const_iterator = const detail::SpanIterator<Span<T, Extent>, true>; // NOLINT
435 using reverse_iterator = std::reverse_iterator<iterator>; // NOLINT
436 using const_reverse_iterator = const std::reverse_iterator<const_iterator>; // NOLINT
437
438 // constructors
439 constexpr Span() __span_noexcept = default;
440
441 XGBOOST_DEVICE Span(pointer _ptr, index_type _count) :
442 size_(_count), data_(_ptr) {
443 SPAN_CHECK(!(Extent != dynamic_extent && _count != Extent));
444 SPAN_CHECK(_ptr || _count == 0);
445 }
446
447 XGBOOST_DEVICE Span(pointer _first, pointer _last) :
448 size_(_last - _first), data_(_first) {
449 SPAN_CHECK(data_ || size_ == 0);
450 }
451
452 template <std::size_t N>
453 XGBOOST_DEVICE constexpr Span(element_type (&arr)[N]) // NOLINT
454 __span_noexcept : size_(N), data_(&arr[0]) {}
455
456 template <class Container,
457 class = typename std::enable_if<
458 !std::is_const<element_type>::value &&
460 std::is_convertible<typename Container::pointer, pointer>::value &&
461 std::is_convertible<typename Container::pointer,
462 decltype(std::declval<Container>().data())>::value>::type>
463 Span(Container& _cont) : // NOLINT
464 size_(_cont.size()), data_(_cont.data()) {
465 static_assert(!detail::IsSpan<Container>::value, "Wrong constructor of Span is called.");
466 }
467
468 template <class Container,
469 class = typename std::enable_if<
470 std::is_const<element_type>::value &&
472 std::is_convertible<typename Container::pointer, pointer>::value &&
473 std::is_convertible<typename Container::pointer,
474 decltype(std::declval<Container>().data())>::value>::type>
475 Span(const Container& _cont) : size_(_cont.size()), // NOLINT
476 data_(_cont.data()) {
477 static_assert(!detail::IsSpan<Container>::value, "Wrong constructor of Span is called.");
478 }
479
480 template <class U, std::size_t OtherExtent,
481 class = typename std::enable_if<
484 XGBOOST_DEVICE constexpr Span(const Span<U, OtherExtent>& _other) // NOLINT
485 __span_noexcept : size_(_other.size()), data_(_other.data()) {}
486
487 XGBOOST_DEVICE constexpr Span(const Span& _other)
488 __span_noexcept : size_(_other.size()), data_(_other.data()) {}
489
490 XGBOOST_DEVICE Span& operator=(const Span& _other) __span_noexcept {
491 size_ = _other.size();
492 data_ = _other.data();
493 return *this;
494 }
495
496 XGBOOST_DEVICE ~Span() __span_noexcept {}; // NOLINT
497
498 XGBOOST_DEVICE constexpr iterator begin() const __span_noexcept { // NOLINT
499 return {this, 0};
500 }
501
502 XGBOOST_DEVICE constexpr iterator end() const __span_noexcept { // NOLINT
503 return {this, size()};
504 }
505
506 XGBOOST_DEVICE constexpr const_iterator cbegin() const __span_noexcept { // NOLINT
507 return {this, 0};
508 }
509
510 XGBOOST_DEVICE constexpr const_iterator cend() const __span_noexcept { // NOLINT
511 return {this, size()};
512 }
513
514 constexpr reverse_iterator rbegin() const __span_noexcept { // NOLINT
515 return reverse_iterator{end()};
516 }
517
518 constexpr reverse_iterator rend() const __span_noexcept { // NOLINT
519 return reverse_iterator{begin()};
520 }
521
522 XGBOOST_DEVICE constexpr const_reverse_iterator crbegin() const __span_noexcept { // NOLINT
523 return const_reverse_iterator{cend()};
524 }
525
526 XGBOOST_DEVICE constexpr const_reverse_iterator crend() const __span_noexcept { // NOLINT
527 return const_reverse_iterator{cbegin()};
528 }
529
530 // element access
531
532 XGBOOST_DEVICE reference front() const { // NOLINT
533 return (*this)[0];
534 }
535
536 XGBOOST_DEVICE reference back() const { // NOLINT
537 return (*this)[size() - 1];
538 }
539
540 XGBOOST_DEVICE reference operator[](index_type _idx) const {
541 SPAN_LT(_idx, size());
542 return data()[_idx];
543 }
544
545 XGBOOST_DEVICE reference operator()(index_type _idx) const {
546 return this->operator[](_idx);
547 }
548
549 XGBOOST_DEVICE constexpr pointer data() const __span_noexcept { // NOLINT
550 return data_;
551 }
552
553 // Observers
554 XGBOOST_DEVICE constexpr index_type size() const __span_noexcept { // NOLINT
555 return size_;
556 }
557 XGBOOST_DEVICE constexpr index_type size_bytes() const __span_noexcept { // NOLINT
558 return size() * sizeof(T);
559 }
560
561 XGBOOST_DEVICE constexpr bool empty() const __span_noexcept { // NOLINT
562 return size() == 0;
563 }
564
565 // Subviews
566 template <std::size_t Count>
567 XGBOOST_DEVICE Span<element_type, Count> first() const { // NOLINT
568 SPAN_CHECK(Count <= size());
569 return {data(), Count};
570 }
571
573 std::size_t _count) const {
574 SPAN_CHECK(_count <= size());
575 return {data(), _count};
576 }
577
578 template <std::size_t Count>
579 XGBOOST_DEVICE Span<element_type, Count> last() const { // NOLINT
580 SPAN_CHECK(Count <= size());
581 return {data() + size() - Count, Count};
582 }
583
585 std::size_t _count) const {
586 SPAN_CHECK(_count <= size());
587 return subspan(size() - _count, _count);
588 }
589
594 template <std::size_t Offset,
595 std::size_t Count = dynamic_extent>
596 XGBOOST_DEVICE auto subspan() const -> // NOLINT
597 Span<element_type,
598 detail::ExtentValue<Extent, Offset, Count>::value> {
599 SPAN_CHECK((Count == dynamic_extent) ?
600 (Offset <= size()) : (Offset + Count <= size()));
601 return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count};
602 }
603
605 index_type _offset,
606 index_type _count = dynamic_extent) const {
607 SPAN_CHECK((_count == dynamic_extent) ?
608 (_offset <= size()) : (_offset + _count <= size()));
609 return {data() + _offset, _count ==
610 dynamic_extent ? size() - _offset : _count};
611 }
612
613 private:
614 index_type size_ { 0 };
615 pointer data_ { nullptr };
616};
617
618template <class T, std::size_t X, class U, std::size_t Y>
619XGBOOST_DEVICE bool operator==(Span<T, X> l, Span<U, Y> r) {
620 if (l.size() != r.size()) {
621 return false;
622 }
623 for (auto l_beg = l.cbegin(), r_beg = r.cbegin(); l_beg != l.cend();
624 ++l_beg, ++r_beg) {
625 if (*l_beg != *r_beg) {
626 return false;
627 }
628 }
629 return true;
630}
631
632template <class T, std::size_t X, class U, std::size_t Y>
633XGBOOST_DEVICE constexpr bool operator!=(Span<T, X> l, Span<U, Y> r) {
634 return !(l == r);
635}
636
637template <class T, std::size_t X, class U, std::size_t Y>
638XGBOOST_DEVICE constexpr bool operator<(Span<T, X> l, Span<U, Y> r) {
639 return detail::LexicographicalCompare(l.begin(), l.end(),
640 r.begin(), r.end());
641}
642
643template <class T, std::size_t X, class U, std::size_t Y>
644XGBOOST_DEVICE constexpr bool operator<=(Span<T, X> l, Span<U, Y> r) {
645 return !(l > r);
646}
647
648template <class T, std::size_t X, class U, std::size_t Y>
649XGBOOST_DEVICE constexpr bool operator>(Span<T, X> l, Span<U, Y> r) {
650 return detail::LexicographicalCompare<
651 typename Span<T, X>::iterator, typename Span<U, Y>::iterator,
652 detail::Greater<typename Span<T, X>::element_type>>(l.begin(), l.end(),
653 r.begin(), r.end());
654}
655
656template <class T, std::size_t X, class U, std::size_t Y>
657XGBOOST_DEVICE constexpr bool operator>=(Span<T, X> l, Span<U, Y> r) {
658 return !(l < r);
659}
660
661template <class T, std::size_t E>
662XGBOOST_DEVICE auto as_bytes(Span<T, E> s) __span_noexcept -> // NOLINT
663 Span<const byte, detail::ExtentAsBytesValue<T, E>::value> {
664 return {reinterpret_cast<const byte*>(s.data()), s.size_bytes()};
665}
666
667template <class T, std::size_t E>
668XGBOOST_DEVICE auto as_writable_bytes(Span<T, E> s) __span_noexcept -> // NOLINT
669 Span<byte, detail::ExtentAsBytesValue<T, E>::value> {
670 return {reinterpret_cast<byte*>(s.data()), s.size_bytes()};
671}
672
676template <typename It>
677class IterSpan {
678 public:
679 using value_type = typename std::iterator_traits<It>::value_type; // NOLINT
680 using index_type = std::size_t; // NOLINT
681 using iterator = It; // NOLINT
682
683 private:
684 It it_;
685 index_type size_{0};
686
687 public:
688 IterSpan() = default;
689 XGBOOST_DEVICE IterSpan(It it, index_type size) : it_{std::move(it)}, size_{size} {}
691 : it_{span.data()}, size_{span.size()} {}
692
693 [[nodiscard]] XGBOOST_DEVICE index_type size() const noexcept { return size_; } // NOLINT
694 [[nodiscard]] XGBOOST_DEVICE decltype(auto) operator[](index_type i) const { return it_[i]; }
695 [[nodiscard]] XGBOOST_DEVICE decltype(auto) operator[](index_type i) { return it_[i]; }
696 [[nodiscard]] XGBOOST_DEVICE bool empty() const noexcept { return size() == 0; } // NOLINT
697 [[nodiscard]] XGBOOST_DEVICE It data() const noexcept { return it_; } // NOLINT
698 [[nodiscard]] XGBOOST_DEVICE IterSpan<It> subspan( // NOLINT
699 index_type _offset, index_type _count = dynamic_extent) const {
700 SPAN_CHECK((_count == dynamic_extent) ? (_offset <= size()) : (_offset + _count <= size()));
701 return {data() + _offset, _count == dynamic_extent ? size() - _offset : _count};
702 }
703 [[nodiscard]] XGBOOST_DEVICE constexpr iterator begin() const noexcept { // NOLINT
704 return {this, 0};
705 }
706 [[nodiscard]] XGBOOST_DEVICE constexpr iterator end() const noexcept { // NOLINT
707 return {this, size()};
708 }
709};
710} // namespace common
711} // namespace xgboost
712
713#if defined(_MSC_VER) &&_MSC_VER < 1910
714#undef constexpr
715#pragma pop_macro("constexpr")
716#undef __span_noexcept
717#endif // _MSC_VER < 1910
718
719#endif // XGBOOST_SPAN_H_
A simple custom Span type that uses a general iterator instead of a pointer.
Definition span.h:677
Span class implementation, based on ISO C++20 std::span<T>. The interface should be the same.
Definition span.h:424
XGBOOST_DEVICE auto subspan() const -> Span< element_type, detail::ExtentValue< Extent, Offset, Count >::value >
Definition span.h:596
Copyright 2015-2023 by XGBoost Contributors.
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition base.h:64
Defines console logging options for xgboost. Used to enforce unified print behavior.
detail namespace with internal helper functions
Definition json.hpp:249
bool operator<(const value_t lhs, const value_t rhs) noexcept
comparison operator for JSON types
Definition json.hpp:2889
namespace of xgboost
Definition base.h:90
Definition span.h:318
Definition span.h:322