Medial Code Documentation
Loading...
Searching...
No Matches
predict_fn.h
1
4#ifndef XGBOOST_PREDICTOR_PREDICT_FN_H_
5#define XGBOOST_PREDICTOR_PREDICT_FN_H_
6#include "../common/categorical.h"
8
9namespace xgboost::predictor {
11template <bool has_categorical>
12XGBOOST_DEVICE bool GetDecision(RegTree::Node const &node, bst_node_t nid, float fvalue,
14 if (has_categorical && common::IsCat(cats.split_type, nid)) {
15 auto node_categories = cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size);
16 return common::Decision(node_categories, fvalue);
17 } else {
18 return fvalue < node.SplitCond();
19 }
20}
21
22template <bool has_missing, bool has_categorical>
23inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bst_node_t nid,
24 float fvalue, bool is_missing,
26 if (has_missing && is_missing) {
27 return node.DefaultChild();
28 } else {
29 return node.LeftChild() + !GetDecision<has_categorical>(node, nid, fvalue, cats);
30 }
31}
32
33template <bool has_missing, bool has_categorical>
34inline XGBOOST_DEVICE bst_node_t GetNextNodeMulti(MultiTargetTree const &tree,
35 bst_node_t const nidx, float fvalue,
36 bool is_missing,
37 RegTree::CategoricalSplitMatrix const &cats) {
38 if (has_missing && is_missing) {
39 return tree.DefaultChild(nidx);
40 } else {
41 if (has_categorical && common::IsCat(cats.split_type, nidx)) {
42 auto node_categories =
43 cats.categories.subspan(cats.node_ptr[nidx].beg, cats.node_ptr[nidx].size);
44 return common::Decision(node_categories, fvalue) ? tree.LeftChild(nidx)
45 : tree.RightChild(nidx);
46 } else {
47 return tree.LeftChild(nidx) + !(fvalue < tree.SplitCond(nidx));
48 }
49 }
50}
51
52} // namespace xgboost::predictor
53#endif // XGBOOST_PREDICTOR_PREDICT_FN_H_
tree node
Definition tree_model.h:166
XGBOOST_DEVICE int LeftChild() const
index of left child
Definition tree_model.h:181
XGBOOST_DEVICE int DefaultChild() const
index of default child when feature is missing
Definition tree_model.h:185
XGBOOST_DEVICE SplitCondT SplitCond() const
Definition tree_model.h:199
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition base.h:64
XGBOOST_DEVICE bool Decision(common::Span< CatBitField::value_type const > cats, float cat)
Whether it should traverse to the left branch of a tree.
Definition categorical.h:55
Copyright 2017-2023 by XGBoost Contributors.
Definition predictor_oneapi.cc:22
XGBOOST_DEVICE bool GetDecision(RegTree::Node const &node, bst_node_t nid, float fvalue, RegTree::CategoricalSplitMatrix const &cats)
Whether it should traverse to the left branch of a tree.
Definition predict_fn.h:12
std::int32_t bst_node_t
Type for tree node index.
Definition base.h:112
CSR-like matrix for categorical splits.
Definition tree_model.h:655
Copyright 2014-2023 by Contributors.