// NOTE(review): fragment of a priority-queue type alias (the `using NAME =`
// line is outside this view). Holds candidate expand entries; the comparator
// is a std::function so the ordering policy can be chosen at runtime in the
// constructor below (depth-wise vs loss-guided) — confirm against full source.
32 std::priority_queue<ExpandEntryT, std::vector<ExpandEntryT>,
33 std::function<bool(ExpandEntryT, ExpandEntryT)>>;
// NOTE(review): tail of a constructor's member-initializer list; the signature
// is outside this view. The queue comparator is fixed once at construction:
// DepthWise ordering when grow_policy == kDepthWise, LossGuide ordering
// otherwise.
38 max_node_batch_size_(max_node_batch_size),
39 queue_(param.grow_policy == TrainParam::kDepthWise ? DepthWise<ExpandEntryT>
40 : LossGuide<ExpandEntryT>) {}
// Push a range of candidate entries onto the queue.
// NOTE(review): the body is truncated in this view — the visible guard only
// admits entries whose split improves loss by more than kRtEps; what happens
// inside that branch (and to filtered-out entries) is not visible here.
41 template <
typename EntryIterT>
42 void Push(EntryIterT begin, EntryIterT end) {
43 for (
auto it = begin; it != end; ++it) {
44 const ExpandEntryT& e = *it;
// Skip splits whose loss change is within numerical noise (kRtEps).
45 if (e.split.loss_chg >
kRtEps) {
50 void Push(
const std::vector<ExpandEntryT> &entries) {
51 this->Push(entries.begin(), entries.end());
53 void Push(ExpandEntryT
const& e) { queue_.push(e); }
// NOTE(review): body of an emptiness accessor whose signature is outside this
// view; true when no candidate splits remain queued.
56 return queue_.empty();
// Whether a child of `parent_entry` may still be expanded under the
// configured limits.
// NOTE(review): the success path (presumably a final `return true;`) is
// outside this view — confirm against the full source.
61 bool IsChildValid(ExpandEntryT
const& parent_entry) {
// A child would sit at depth parent_entry.depth + 1; reject once that reaches
// max_depth (a max_depth of 0 disables the depth limit).
62 if (param_.max_depth > 0 && parent_entry.depth + 1 >= param_.max_depth)
return false;
// Likewise reject once the leaf budget is exhausted (0 disables it).
63 if (param_.max_leaves > 0 && num_leaves_ >= param_.max_leaves)
return false;
// Pop the next batch of nodes to expand.
// NOTE(review): heavily truncated in this view (several interior lines are
// missing); the comments below describe only what the visible code shows.
70 std::vector<ExpandEntryT> Pop() {
// Nothing queued: return an empty batch.
71 if (queue_.empty())
return {};
// Loss-guided growth takes the single best candidate off the top of the
// queue and checks it against the current limits.
73 if (param_.grow_policy == TrainParam::kLossGuide) {
74 ExpandEntryT e = queue_.top();
77 if (e.IsValid(param_, num_leaves_)) {
// Depth-wise growth: collect valid entries that share the same tree level,
// up to max_node_batch_size_ per call.
85 std::vector<ExpandEntryT> result;
86 ExpandEntryT e = queue_.top();
// `level` is defined on a line not visible here — presumably e.depth of the
// first popped entry; TODO confirm.
88 while (e.depth == level && !queue_.empty() && result.size() < max_node_batch_size_) {
90 if (e.IsValid(param_, num_leaves_)) {
92 result.emplace_back(e);
// Refresh `e` from the queue for the next loop-condition check.
95 if (!queue_.empty()) {
// Upper bound on how many entries a single Pop() may return (used by the
// depth-wise batch loop above).
105 std::size_t max_node_batch_size_;