Medial Code Documentation
driver.h
/**
 * Copyright 2014-2023 by XGBoost Contributors.
 */
#ifndef XGBOOST_TREE_DRIVER_H_
#define XGBOOST_TREE_DRIVER_H_
#include <xgboost/span.h>
#include <functional>  // for std::function
#include <queue>
#include <vector>
#include "./param.h"

namespace xgboost {
namespace tree {

template <typename ExpandEntryT>
inline bool DepthWise(const ExpandEntryT& lhs, const ExpandEntryT& rhs) {
  return lhs.GetNodeId() > rhs.GetNodeId();  // favor small depth
}

template <typename ExpandEntryT>
inline bool LossGuide(const ExpandEntryT& lhs, const ExpandEntryT& rhs) {
  if (lhs.GetLossChange() == rhs.GetLossChange()) {
    return lhs.GetNodeId() > rhs.GetNodeId();  // favor small timestamp
  } else {
    return lhs.GetLossChange() < rhs.GetLossChange();  // favor large loss_chg
  }
}

// Drives execution of tree building on device
template <typename ExpandEntryT>
class Driver {
  using ExpandQueue =
      std::priority_queue<ExpandEntryT, std::vector<ExpandEntryT>,
                          std::function<bool(ExpandEntryT, ExpandEntryT)>>;

 public:
  explicit Driver(TrainParam param, std::size_t max_node_batch_size = 256)
      : param_(param),
        max_node_batch_size_(max_node_batch_size),
        queue_(param.grow_policy == TrainParam::kDepthWise ? DepthWise<ExpandEntryT>
                                                           : LossGuide<ExpandEntryT>) {}
  template <typename EntryIterT>
  void Push(EntryIterT begin, EntryIterT end) {
    for (auto it = begin; it != end; ++it) {
      const ExpandEntryT& e = *it;
      if (e.split.loss_chg > kRtEps) {
        queue_.push(e);
      }
    }
  }
  void Push(const std::vector<ExpandEntryT>& entries) {
    this->Push(entries.begin(), entries.end());
  }
  void Push(ExpandEntryT const& e) { queue_.push(e); }

  bool IsEmpty() {
    return queue_.empty();
  }

  // Can a child of this entry still be expanded?
  // Can be used to avoid extra work.
  bool IsChildValid(ExpandEntryT const& parent_entry) {
    if (param_.max_depth > 0 && parent_entry.depth + 1 >= param_.max_depth) return false;
    if (param_.max_leaves > 0 && num_leaves_ >= param_.max_leaves) return false;
    return true;
  }

  // Return the set of nodes to be expanded.
  // This set has no dependencies between entries, so they may be expanded in
  // parallel or asynchronously.
  std::vector<ExpandEntryT> Pop() {
    if (queue_.empty()) return {};
    // Return a single entry for loss guided mode
    if (param_.grow_policy == TrainParam::kLossGuide) {
      ExpandEntryT e = queue_.top();
      queue_.pop();

      if (e.IsValid(param_, num_leaves_)) {
        num_leaves_++;
        return {e};
      } else {
        return {};
      }
    }
    // Return nodes on the same level for depth-wise growth
    std::vector<ExpandEntryT> result;
    ExpandEntryT e = queue_.top();
    int level = e.depth;
    while (e.depth == level && !queue_.empty() && result.size() < max_node_batch_size_) {
      queue_.pop();
      if (e.IsValid(param_, num_leaves_)) {
        num_leaves_++;
        result.emplace_back(e);
      }

      if (!queue_.empty()) {
        e = queue_.top();
      }
    }
    return result;
  }

 private:
  TrainParam param_;
  bst_node_t num_leaves_ = 1;
  std::size_t max_node_batch_size_;
  ExpandQueue queue_;
};
}  // namespace tree
}  // namespace xgboost

#endif  // XGBOOST_TREE_DRIVER_H_
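
For illustration, here is a minimal sketch of how an updater might drive the expansion loop through Push, Pop, and IsChildValid. The ToyEntry type, the GrowTreeSketch function, the child node-id scheme, and the include path are hypothetical stand-ins invented for this example; the real updaters (e.g. the hist/GPU tree builders) supply their own ExpandEntry types and split-evaluation logic. Only the Driver interface itself comes from this header.

// Usage sketch only (not part of driver.h). ToyEntry, GrowTreeSketch, and the
// node-id scheme are hypothetical; real updaters use their own ExpandEntry types.
#include <vector>

#include "./driver.h"  // assumes this sketch lives next to driver.h

namespace xgboost {
namespace tree {

struct ToySplit {
  float loss_chg;  // estimated loss reduction of the best split
};

struct ToyEntry {
  int nid;         // tree node index
  int depth;       // depth of the node, root == 0
  ToySplit split;  // best split found for this node

  int GetNodeId() const { return nid; }
  float GetLossChange() const { return split.loss_chg; }
  // Driver::Pop() applies a candidate only if it still respects the limits.
  bool IsValid(TrainParam const& param, bst_node_t num_leaves) const {
    if (split.loss_chg <= kRtEps) return false;
    if (param.max_depth > 0 && depth == param.max_depth) return false;
    if (param.max_leaves > 0 && num_leaves == param.max_leaves) return false;
    return true;
  }
};

inline void GrowTreeSketch(TrainParam const& param) {
  Driver<ToyEntry> driver{param};
  driver.Push(ToyEntry{/*nid=*/0, /*depth=*/0, ToySplit{1.0f}});  // root candidate

  while (!driver.IsEmpty()) {
    // Entries within one batch are independent and may be processed in parallel.
    std::vector<ToyEntry> batch = driver.Pop();
    std::vector<ToyEntry> children;
    for (ToyEntry const& e : batch) {
      // Apply e.split to the tree here (updater-specific work, omitted).
      if (driver.IsChildValid(e)) {
        // A real updater would evaluate splits for both children; the zero
        // loss_chg used here simply lets the sketch terminate.
        children.push_back(ToyEntry{2 * e.nid + 1, e.depth + 1, ToySplit{0.0f}});
        children.push_back(ToyEntry{2 * e.nid + 2, e.depth + 1, ToySplit{0.0f}});
      }
    }
    driver.Push(children);  // candidates with loss_chg <= kRtEps are dropped
  }
}

}  // namespace tree
}  // namespace xgboost

Note that the iterator and vector overloads of Push discard candidates whose split.loss_chg is not above kRtEps, so the loop above terminates once no child yields a meaningful gain, while the single-entry Push(ExpandEntryT const&) bypasses that filter.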
Referenced declarations:
- xgboost: namespace of XGBoost (base.h)
- bst_node_t: std::int32_t, type for tree node index (base.h)
- kRtEps: constexpr bst_float, small eps gap for minimum split decision (base.h)
- TrainParam: training parameters for regression tree (param.h)