Medial Code Documentation
Loading...
Searching...
No Matches
test_partitioner.h
1
4#ifndef XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_
5#define XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_
6#include <xgboost/context.h> // for Context
7#include <xgboost/linalg.h> // for Constant, Vector
8#include <xgboost/logging.h> // for CHECK
9#include <xgboost/tree_model.h> // for RegTree
10
11#include <vector> // for vector
12
13#include "../../../src/tree/hist/expand_entry.h" // for CPUExpandEntry, MultiExpandEntry
14
15namespace xgboost::tree {
16inline void GetSplit(RegTree *tree, float split_value, std::vector<CPUExpandEntry> *candidates) {
17 CHECK(!tree->IsMultiTarget());
18 tree->ExpandNode(
19 /*nid=*/RegTree::kRoot, /*split_index=*/0, /*split_value=*/split_value,
20 /*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
21 /*left_sum=*/0.0f,
22 /*right_sum=*/0.0f);
23 candidates->front().split.split_value = split_value;
24 candidates->front().split.sindex = 0;
25 candidates->front().split.sindex |= (1U << 31);
26}
27
28inline void GetMultiSplitForTest(RegTree *tree, float split_value,
29 std::vector<MultiExpandEntry> *candidates) {
30 CHECK(tree->IsMultiTarget());
31 auto n_targets = tree->NumTargets();
32 Context ctx;
33 linalg::Vector<float> base_weight{linalg::Constant(&ctx, 0.0f, n_targets)};
34 linalg::Vector<float> left_weight{linalg::Constant(&ctx, 0.0f, n_targets)};
35 linalg::Vector<float> right_weight{linalg::Constant(&ctx, 0.0f, n_targets)};
36
37 tree->ExpandNode(/*nidx=*/RegTree::kRoot, /*split_index=*/0, /*split_value=*/split_value,
38 /*default_left=*/true, base_weight.HostView(), left_weight.HostView(),
39 right_weight.HostView());
40 candidates->front().split.split_value = split_value;
41 candidates->front().split.sindex = 0;
42 candidates->front().split.sindex |= (1U << 31);
43}
44} // namespace xgboost::tree
45#endif // XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_
Copyright 2014-2023, XGBoost Contributors.
Defines console logging options for XGBoost; use it to enforce unified print behavior.
Copyright 2021-2023 by XGBoost Contributors.
auto Constant(Context const *ctx, T v, Index &&...index)
Create an array filled with the value v.
Definition linalg.h:958
Copyright 2021-2023 by XGBoost Contributors.
Definition tree_updater.h:25
Copyright 2014-2023 by Contributors.