Medial Code Documentation
Loading...
Searching...
No Matches
io_utils.h
1
4#ifndef XGBOOST_TREE_IO_UTILS_H_
5#define XGBOOST_TREE_IO_UTILS_H_
6#include <string> // for string
7#include <type_traits> // for enable_if_t, is_same, conditional_t
8#include <vector> // for vector
9
10#include "xgboost/json.h" // for Json
11
12namespace xgboost {
13template <bool typed>
14using FloatArrayT = std::conditional_t<typed, F32Array const, Array const>;
15template <bool typed>
16using U8ArrayT = std::conditional_t<typed, U8Array const, Array const>;
17template <bool typed>
18using I32ArrayT = std::conditional_t<typed, I32Array const, Array const>;
19template <bool typed>
20using I64ArrayT = std::conditional_t<typed, I64Array const, Array const>;
21template <bool typed, bool feature_is_64>
22using IndexArrayT = std::conditional_t<feature_is_64, I64ArrayT<typed>, I32ArrayT<typed>>;
23
24// typed array, not boolean
25template <typename JT, typename T>
26std::enable_if_t<!std::is_same<T, Json>::value && !std::is_same<JT, Boolean>::value, T> GetElem(
27 std::vector<T> const& arr, size_t i) {
28 return arr[i];
29}
30// typed array boolean
31template <typename JT, typename T>
32std::enable_if_t<!std::is_same<T, Json>::value && std::is_same<T, uint8_t>::value &&
33 std::is_same<JT, Boolean>::value,
34 bool>
35GetElem(std::vector<T> const& arr, size_t i) {
36 return arr[i] == 1;
37}
38// json array
39template <typename JT, typename T>
40std::enable_if_t<
41 std::is_same<T, Json>::value,
42 std::conditional_t<std::is_same<JT, Integer>::value, int64_t,
43 std::conditional_t<std::is_same<Boolean, JT>::value, bool, float>>>
44GetElem(std::vector<T> const& arr, size_t i) {
45 if (std::is_same<JT, Boolean>::value && !IsA<Boolean>(arr[i])) {
46 return get<Integer const>(arr[i]) == 1;
47 }
48 return get<JT const>(arr[i]);
49}
50
/**
 * \brief Named string keys for tree fields.
 *
 * Presumably these are the member names used when (de)serializing a tree —
 * NOTE(review): confirm against callers.
 */
namespace tree_field {
inline const std::string kLossChg = "loss_changes";
inline const std::string kSumHess = "sum_hessian";
inline const std::string kBaseWeight = "base_weights";

inline const std::string kSplitIdx = "split_indices";
inline const std::string kSplitCond = "split_conditions";
inline const std::string kDftLeft = "default_left";

inline const std::string kParent = "parents";
inline const std::string kLeft = "left_children";
inline const std::string kRight = "right_children";
}  // namespace tree_field
64} // namespace xgboost
65#endif // XGBOOST_TREE_IO_UTILS_H_
namespace of xgboost
Definition base.h:90