6#ifndef XGBOOST_DATA_ARRAY_INTERFACE_H_
7#define XGBOOST_DATA_ARRAY_INTERFACE_H_
19#include "../common/bitfield.h"
20#include "../common/common.h"
21#include "../common/error_msg.h"
24#include "xgboost/json.h"
27#include "xgboost/span.h"
29#if defined(XGBOOST_USE_CUDA)
// Error text reported when an input array's memory is not contiguous,
// e.g. when a mask's declared stride differs from its item size.
36 static char const *Contiguous() {
return "Memory should be contiguous."; }
37 static char const *TypestrFormat() {
38 return "`typestr' should be of format <endian><type><size of type in bytes>.";
40 static char const *Dimension(int32_t d) {
41 static std::string str;
44 str += std::to_string(d);
45 str +=
" dimensional array is valid.";
49 return "Only version <= 3 of `__cuda_array_interface__' and `__array_interface__' are "
52 static char const *OfType(std::string
const &type) {
53 static std::string str;
55 str +=
" should be of ";
61 static std::string TypeStr(
char c) {
70 return "Unsigned integer";
72 return "Floating point";
74 return "Complex floating point";
88 LOG(FATAL) <<
"Invalid type code: " << c <<
" in `typestr' of input array."
89 <<
"\nPlease verify the `__cuda_array_interface__/__array_interface__' "
90 <<
"of your input data complies to: "
91 <<
"https://docs.scipy.org/doc/numpy/reference/arrays.interface.html"
92 <<
"\nOr open an issue.";
97 static std::string UnSupportedType(
StringView typestr) {
98 return TypeStr(typestr[1]) +
"-" + typestr[2] +
" is not supported.";
// Runtime element-type code for an array interface buffer.
// Prefix: F = floating point, I = signed integer, U = unsigned integer.
// Suffix: element size in bytes (it matches the digit(s) of `typestr`, e.g. "<f4" -> kF4).
// kF16 corresponds to `long double` and requires sizeof(long double) == 16.
// kF2 is a half-precision type that is only usable in CUDA builds.
107 enum Type : std::int8_t { kF2, kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };
109 template <
typename PtrType>
110 static PtrType GetPtrFromArrayData(Object::Map
const &obj) {
111 auto data_it = obj.find(
"data");
112 if (data_it == obj.cend() || IsA<Null>(data_it->second)) {
113 LOG(FATAL) <<
"Empty data passed in.";
115 auto p_data =
reinterpret_cast<PtrType
>(
116 static_cast<size_t>(get<Integer const>(get<Array const>(data_it->second).at(0))));
120 static void Validate(Object::Map
const &array) {
121 auto version_it = array.find(
"version");
122 if (version_it == array.cend() || IsA<Null>(version_it->second)) {
123 LOG(FATAL) <<
"Missing `version' field for array interface";
125 if (get<Integer const>(version_it->second) > 3) {
126 LOG(FATAL) << ArrayInterfaceErrors::Version();
129 auto typestr_it = array.find(
"typestr");
130 if (typestr_it == array.cend() || IsA<Null>(typestr_it->second)) {
131 LOG(FATAL) <<
"Missing `typestr' field for array interface";
134 auto typestr = get<String const>(typestr_it->second);
135 CHECK(typestr.size() == 3 || typestr.size() == 4) << ArrayInterfaceErrors::TypestrFormat();
137 auto shape_it = array.find(
"shape");
138 if (shape_it == array.cend() || IsA<Null>(shape_it->second)) {
139 LOG(FATAL) <<
"Missing `shape' field for array interface";
141 auto data_it = array.find(
"data");
142 if (data_it == array.cend() || IsA<Null>(data_it->second)) {
143 LOG(FATAL) <<
"Missing `data' field for array interface";
149 static size_t ExtractMask(Object::Map
const &column,
151 auto &s_mask = *p_out;
152 auto const &mask_it = column.find(
"mask");
153 if (mask_it != column.cend() && !IsA<Null>(mask_it->second)) {
154 auto const &j_mask = get<Object const>(mask_it->second);
157 auto p_mask = GetPtrFromArrayData<RBitField8::value_type *>(j_mask);
159 auto j_shape = get<Array const>(j_mask.at(
"shape"));
160 CHECK_EQ(j_shape.size(), 1) << ArrayInterfaceErrors::Dimension(1);
161 auto typestr = get<String const>(j_mask.at(
"typestr"));
163 int64_t
const type_length = typestr.at(2) - 48;
165 if (typestr.at(1) ==
't') {
166 CHECK_EQ(type_length, 1) <<
"mask with bitfield type should be of 1 byte per bitfield.";
167 }
else if (typestr.at(1) ==
'i') {
168 CHECK_EQ(type_length, 1) <<
"mask with integer type should be of 1 byte per integer.";
170 LOG(FATAL) <<
"mask must be of integer type or bit field type.";
182 size_t const n_bits =
static_cast<size_t>(get<Integer>(j_shape.at(0)));
185 size_t const span_size = RBitField8::ComputeStorageSize(n_bits);
187 auto strides_it = j_mask.find(
"strides");
188 if (strides_it != j_mask.cend() && !IsA<Null>(strides_it->second)) {
189 auto strides = get<Array const>(strides_it->second);
190 CHECK_EQ(strides.size(), 1) << ArrayInterfaceErrors::Dimension(1);
191 CHECK_EQ(get<Integer>(strides.at(0)), type_length) << ArrayInterfaceErrors::Contiguous();
194 s_mask = {p_mask, span_size};
203 static void HandleRowVector(std::vector<size_t>
const &shape, std::vector<size_t> *p_out) {
205 if (shape.size() == 2 && D == 1) {
208 CHECK(m == 1 || n == 1);
223 static void ExtractShape(Object::Map
const &array,
size_t (&out_shape)[D]) {
224 auto const &j_shape = get<Array const>(array.at(
"shape"));
225 std::vector<size_t> shape_arr(j_shape.size(), 0);
226 std::transform(j_shape.cbegin(), j_shape.cend(), shape_arr.begin(),
227 [](
Json in) { return get<Integer const>(in); });
229 HandleRowVector<D>(shape_arr, &shape_arr);
232 for (i = 0; i < shape_arr.size(); ++i) {
233 CHECK_LT(i, D) << ArrayInterfaceErrors::Dimension(D);
234 out_shape[i] = shape_arr[i];
237 std::fill(out_shape + i, out_shape + D, 1);
245 size_t (&shape)[D],
size_t (&stride)[D]) {
246 auto strides_it = array.find(
"strides");
248 if (strides_it == array.cend() || IsA<Null>(strides_it->second)) {
250 linalg::detail::CalcStride(shape, stride);
259 auto const &j_shape = get<Array const>(array.at(
"shape"));
260 std::vector<size_t> shape_arr(j_shape.size(), 0);
261 std::transform(j_shape.cbegin(), j_shape.cend(), shape_arr.begin(),
262 [](
Json in) { return get<Integer const>(in); });
264 auto const &j_strides = get<Array const>(strides_it->second);
265 CHECK_EQ(j_strides.size(), j_shape.size()) <<
"stride and shape don't match.";
266 std::vector<size_t> stride_arr(j_strides.size(), 0);
267 std::transform(j_strides.cbegin(), j_strides.cend(), stride_arr.begin(),
268 [](
Json in) { return get<Integer const>(in); });
271 HandleRowVector<D>(shape_arr, &stride_arr);
273 for (i = 0; i < stride_arr.size(); ++i) {
276 CHECK_LT(i, D) << ArrayInterfaceErrors::Dimension(D);
278 stride[i] = stride_arr[i] / itemsize;
280 std::fill(stride + i, stride + D, 1);
282 size_t stride_tmp[D];
283 linalg::detail::CalcStride(shape, stride_tmp);
284 return std::equal(stride_tmp, stride_tmp + D, stride);
287 static void *ExtractData(Object::Map
const &array,
size_t size) {
289 void *p_data = ArrayInterfaceHandler::GetPtrFromArrayData<void *>(array);
291 CHECK_EQ(size, 0) <<
"Empty data with non-zero shape.";
308template <
typename T,
typename E =
void>
311#if defined(XGBOOST_USE_CUDA)
314 static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kF2;
319 static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kF4;
323 static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kF8;
327 std::enable_if_t<std::is_same<T, long double>::value && sizeof(long double) == 16>> {
328 static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kF16;
333 static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kU1;
337 static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kU2;
341 static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kU4;
345 static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kU8;
350 static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kI1;
354 static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kI2;
358 static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kI4;
362 static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kI8;
365#if !defined(XGBOOST_USE_CUDA)
387template <
int32_t D,
bool allow_mask = (D == 1)>
389 static_assert(D > 0,
"Invalid dimension for array interface.");
399 void Initialize(Object::Map
const &array) {
400 ArrayInterfaceHandler::Validate(array);
402 auto typestr = get<String const>(array.at(
"typestr"));
404 ArrayInterfaceHandler::ExtractShape(array, shape);
405 size_t itemsize = typestr[2] -
'0';
407 n = linalg::detail::CalcSize(shape);
409 data = ArrayInterfaceHandler::ExtractData(array, n);
410 static_assert(allow_mask ? D == 1 : D >= 1,
"Masked ndarray is not supported.");
412 auto alignment = this->ElementAlignment();
413 auto ptr =
reinterpret_cast<uintptr_t
>(this->data);
414 CHECK_EQ(ptr % alignment, 0) <<
"Input pointer misalignment.";
418 size_t n_bits = ArrayInterfaceHandler::ExtractMask(array, &s_mask);
423 CHECK_EQ(n_bits, n) <<
"Shape of bit mask doesn't match data shape. "
424 <<
"XGBoost doesn't support internal broadcasting.";
427 auto mask_it = array.find(
"mask");
428 CHECK(mask_it == array.cend() || IsA<Null>(mask_it->second))
429 <<
"Masked array is not yet supported.";
432 auto stream_it = array.find(
"stream");
433 if (stream_it != array.cend() && !IsA<Null>(stream_it->second)) {
434 int64_t stream = get<Integer const>(stream_it->second);
441 explicit ArrayInterface(Object::Map
const &array) { this->Initialize(array); }
444 if (IsA<Object>(array)) {
445 this->Initialize(get<Object const>(array));
448 if (IsA<Array>(array)) {
449 CHECK_EQ(get<Array const>(array).size(), 1)
450 <<
"Column: " << ArrayInterfaceErrors::Dimension(1);
451 this->Initialize(get<Object const>(get<Array const>(array)[0]));
461 using T = ArrayInterfaceHandler::Type;
462 if (typestr.size() == 4 && typestr[1] ==
'f' && typestr[2] ==
'1' && typestr[3] ==
'6') {
463 CHECK(
sizeof(
long double) == 16) << error::NoF128();
465 }
else if (typestr[1] ==
'f' && typestr[2] ==
'2') {
466#if defined(XGBOOST_USE_CUDA)
469 LOG(FATAL) <<
"Half type is not supported.";
471 }
else if (typestr[1] ==
'f' && typestr[2] ==
'4') {
473 }
else if (typestr[1] ==
'f' && typestr[2] ==
'8') {
475 }
else if (typestr[1] ==
'i' && typestr[2] ==
'1') {
477 }
else if (typestr[1] ==
'i' && typestr[2] ==
'2') {
479 }
else if (typestr[1] ==
'i' && typestr[2] ==
'4') {
481 }
else if (typestr[1] ==
'i' && typestr[2] ==
'8') {
483 }
else if (typestr[1] ==
'u' && typestr[2] ==
'1') {
485 }
else if (typestr[1] ==
'u' && typestr[2] ==
'2') {
487 }
else if (typestr[1] ==
'u' && typestr[2] ==
'4') {
489 }
else if (typestr[1] ==
'u' && typestr[2] ==
'8') {
492 LOG(FATAL) << ArrayInterfaceErrors::UnSupportedType(typestr);
// Extent of dimension `i`; `i` is not bounds-checked here.
// Dimensions beyond the input's rank are filled with 1 when the shape is extracted.
497 [[nodiscard]]
XGBOOST_DEVICE std::size_t Shape(
size_t i)
const {
return shape[i]; }
// Stride of dimension `i`; `i` is not bounds-checked here.
// NOTE(review): presumably measured in elements rather than bytes (stride
// extraction divides byte strides by the item size) -- confirm against Initialize.
498 [[nodiscard]]
XGBOOST_DEVICE std::size_t Stride(
size_t i)
const {
return strides[i]; }
500 template <
typename Fn>
501 XGBOOST_HOST_DEV_INLINE
decltype(
auto) DispatchCall(Fn func)
const {
502 using T = ArrayInterfaceHandler::Type;
505#if defined(XGBOOST_USE_CUDA)
506 return func(
reinterpret_cast<__half
const *
>(data));
510 return func(
reinterpret_cast<float const *
>(data));
512 return func(
reinterpret_cast<double const *
>(data));
517 return func(
reinterpret_cast<double const *
>(data));
521 return func(
reinterpret_cast<long double const *
>(data));
524 return func(
reinterpret_cast<int8_t
const *
>(data));
526 return func(
reinterpret_cast<int16_t
const *
>(data));
528 return func(
reinterpret_cast<int32_t
const *
>(data));
530 return func(
reinterpret_cast<int64_t
const *
>(data));
532 return func(
reinterpret_cast<uint8_t
const *
>(data));
534 return func(
reinterpret_cast<uint16_t
const *
>(data));
536 return func(
reinterpret_cast<uint32_t
const *
>(data));
538 return func(
reinterpret_cast<uint64_t
const *
>(data));
541 return func(
reinterpret_cast<uint64_t
const *
>(data));
545 return this->DispatchCall([](
auto *typed_data_ptr) {
546 return sizeof(std::remove_pointer_t<
decltype(typed_data_ptr)>);
549 [[nodiscard]]
XGBOOST_DEVICE std::size_t ElementAlignment()
const {
550 return this->DispatchCall([](
auto *typed_data_ptr) {
551 return std::alignment_of<std::remove_pointer_t<
decltype(typed_data_ptr)>>::value;
555 template <
typename T = float,
typename... Index>
556 XGBOOST_HOST_DEV_INLINE T operator()(Index &&...index)
const {
557 static_assert(
sizeof...(index) <= D,
"Invalid index.");
558 return this->DispatchCall([=](
auto const *p_values) -> T {
559 std::size_t offset = linalg::detail::Offset<0ul>(strides, 0ul, index...);
560#if defined(XGBOOST_USE_CUDA)
562 using Type = std::conditional_t<
564 std::remove_cv_t<std::remove_pointer_t<
decltype(p_values)>>>::value &&
565 std::is_same<std::size_t, std::remove_cv_t<T>>::value,
566 unsigned long long, T>;
567 return static_cast<T
>(
static_cast<Type
>(p_values[offset]));
569 return static_cast<T
>(p_values[offset]);
577 std::size_t strides[D]{0};
579 std::size_t shape[D]{0};
581 void const *data{
nullptr};
585 bool is_contiguous{
false};
587 ArrayInterfaceHandler::Type type{ArrayInterfaceHandler::kF16};
590template <std::
int32_t D,
typename Fn>
593 CHECK_EQ(array.valid.Capacity(), 0);
594 auto dispatch = [&](
auto t) {
595 using T = std::remove_const_t<
decltype(t)>
const;
602 std::numeric_limits<std::size_t>::max()},
603 array.shape, array.strides, device});
605 switch (
array.type) {
606 case ArrayInterfaceHandler::kF2: {
607#if defined(XGBOOST_USE_CUDA)
612 case ArrayInterfaceHandler::kF4: {
616 case ArrayInterfaceHandler::kF8: {
620 case ArrayInterfaceHandler::kF16: {
621 using T =
long double;
622 CHECK(
sizeof(
long double) == 16) << error::NoF128();
626 case ArrayInterfaceHandler::kI1: {
627 dispatch(std::int8_t{});
630 case ArrayInterfaceHandler::kI2: {
631 dispatch(std::int16_t{});
634 case ArrayInterfaceHandler::kI4: {
635 dispatch(std::int32_t{});
638 case ArrayInterfaceHandler::kI8: {
639 dispatch(std::int64_t{});
642 case ArrayInterfaceHandler::kU1: {
643 dispatch(std::uint8_t{});
646 case ArrayInterfaceHandler::kU2: {
647 dispatch(std::uint16_t{});
650 case ArrayInterfaceHandler::kU4: {
651 dispatch(std::uint32_t{});
654 case ArrayInterfaceHandler::kU8: {
655 dispatch(std::uint64_t{});
664template <
typename T,
int32_t D>
667 template <
typename... I>
669 static_assert(
sizeof...(ind) <= D,
"Invalid index.");
670 return array.template operator()<T>(ind...);
676 CHECK(!array.valid.Data()) <<
"Meta info " << key <<
" should be dense, found validity mask";
Utilities for consuming array interface.
Definition array_interface.h:105
static void HandleRowVector(std::vector< size_t > const &shape, std::vector< size_t > *p_out)
Handle vector inputs.
Definition array_interface.h:203
static bool IsCudaPtr(void const *ptr)
Whether the ptr is allocated by CUDA.
Definition array_interface.h:367
static bool ExtractStride(Object::Map const &array, size_t itemsize, size_t(&shape)[D], size_t(&stride)[D])
Extracts the optional ‘strides’ field and returns whether the array is C-contiguous.
Definition array_interface.h:244
static void SyncCudaStream(int64_t stream)
Sync the CUDA stream.
Definition array_interface.h:366
A type-erased view over the array interface protocol defined by NumPy.
Definition array_interface.h:388
Data structure representing JSON format.
Definition json.h:357
static Json Load(StringView str, std::ios::openmode mode=std::ios::in)
Decode the JSON object.
Definition json.cc:652
Span class implementation, based on the ISO C++20 std::span<T>. The interface should be the same.
Definition span.h:424
A tensor view with static type and dimension.
Definition linalg.h:293
Copyright 2015-2023 by XGBoost Contributors.
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition base.h:64
Copyright 2015-2023 by XGBoost Contributors.
Defines console logging options for XGBoost. Use them to enforce unified print behavior.
Copyright 2021-2023 by XGBoost Contributors.
@ array
array (ordered collection of values)
namespace of xgboost
Definition base.h:90
Definition array_interface.h:35
Definition string_view.h:15
Dispatch compile time type to runtime type.
Definition array_interface.h:309
Helper for type casting.
Definition array_interface.h:665