Medial Code Documentation
Loading...
Searching...
No Matches
categorical.h
Go to the documentation of this file.
1
5#ifndef XGBOOST_COMMON_CATEGORICAL_H_
6#define XGBOOST_COMMON_CATEGORICAL_H_
7
8#include <limits>
9
10#include "bitfield.h"
11#include "xgboost/base.h"
12#include "xgboost/data.h"
13#include "xgboost/span.h"
14
15namespace xgboost {
16namespace common {
17
18using CatBitField = LBitField32;
19using KCatBitField = CLBitField32;
20
21// Cast the categorical type.
22template <typename T>
23XGBOOST_DEVICE bst_cat_t AsCat(T const& v) {
24 return static_cast<bst_cat_t>(v);
25}
26
27/* \brief Whether is fidx a categorical feature.
28 *
29 * \param ft Feature type for all features.
30 * \param fidx Feature index.
31 * \return Whether feature pointed by fidx is categorical feature.
32 */
33inline XGBOOST_DEVICE bool IsCat(Span<FeatureType const> ft, bst_feature_t fidx) {
34 return !ft.empty() && ft[fidx] == FeatureType::kCategorical;
35}
36
37constexpr inline bst_cat_t OutOfRangeCat() {
38 // See the round trip assert in `InvalidCat`.
39 return static_cast<bst_cat_t>(16777217) - static_cast<bst_cat_t>(1);
40}
41
42inline XGBOOST_DEVICE bool InvalidCat(float cat) {
43 constexpr auto kMaxCat = OutOfRangeCat();
44 static_assert(static_cast<bst_cat_t>(static_cast<float>(kMaxCat)) == kMaxCat);
45 static_assert(static_cast<bst_cat_t>(static_cast<float>(kMaxCat + 1)) != kMaxCat + 1);
46 static_assert(static_cast<float>(kMaxCat + 1) == kMaxCat);
47 return cat < 0 || cat >= kMaxCat;
48}
49
56 KCatBitField const s_cats(cats);
57 if (XGBOOST_EXPECT(InvalidCat(cat), false)) {
58 return true;
59 }
60
61 auto pos = KCatBitField::ToBitPos(cat);
62 // If the input category is larger than the size of the bit field, it implies that the
63 // category is not chosen. Otherwise the bit field would have the category instead of
64 // being smaller than the category value.
65 if (pos.int_pos >= cats.size()) {
66 return true;
67 }
68 return !s_cats.Check(AsCat(cat));
69}
70
71inline void InvalidCategory() {
72 // OutOfRangeCat() can be accurately represented, but everything after it will be
73 // rounded toward it, so we use >= for comparison check. As a result, we require input
74 // values to be less than this last representable value.
75 auto str = std::to_string(OutOfRangeCat());
76 LOG(FATAL) << "Invalid categorical value detected. Categorical value should be non-negative, "
77 "less than total number of categories in training data and less than " +
78 str;
79}
80
81inline void CheckMaxCat(float max_cat, size_t n_categories) {
82 CHECK_GE(max_cat + 1, n_categories)
83 << "Maximum cateogry should not be lesser than the total number of categories.";
84}
85
89XGBOOST_DEVICE inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot) {
90 bool use_one_hot = n_cats < max_cat_to_onehot;
91 return use_one_hot;
92}
93
94struct IsCatOp {
95 XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; }
96};
97} // namespace common
98} // namespace xgboost
99
100#endif // XGBOOST_COMMON_CATEGORICAL_H_
Copyright 2019-2023, XGBoost Contributors.
Span class implementation, based on the ISO C++20 std::span<T>. The interface should be the same.
Definition span.h:424
Copyright 2015-2023 by XGBoost Contributors.
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition base.h:64
Copyright 2015-2023 by XGBoost Contributors.
XGBOOST_DEVICE bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot)
Whether we should use one-hot encoding for categorical data.
Definition categorical.h:89
XGBOOST_DEVICE bool Decision(common::Span< CatBitField::value_type const > cats, float cat)
Whether it should traverse to the left branch of a tree.
Definition categorical.h:55
namespace of xgboost
Definition base.h:90
uint32_t bst_feature_t
Type for data column (feature) index.
Definition base.h:101
int32_t bst_cat_t
Categorical value type.
Definition base.h:99
A non-owning type with auxiliary methods defined for manipulating bits.
Definition bitfield.h:61
Definition categorical.h:94