4#ifndef XGBOOST_TEST_PREDICTOR_H_
5#define XGBOOST_TEST_PREDICTOR_H_
13#include "../../../src/gbm/gbtree_model.h"
14#include "../helpers.h"
17inline gbm::GBTreeModel CreateTestModel(LearnerModelParam
const* param, Context
const* ctx,
18 size_t n_classes = 1) {
19 gbm::GBTreeModel model(param, ctx);
21 for (
size_t i = 0; i < n_classes; ++i) {
22 std::vector<std::unique_ptr<RegTree>> trees;
23 trees.push_back(std::unique_ptr<RegTree>(
new RegTree));
25 (*trees.back())[0].SetLeaf(1.5f);
26 (*trees.back()).Stat(0).sum_hess = 1.0f;
28 model.CommitModelGroup(std::move(trees), i);
34inline auto CreatePredictorForTest(Context
const* ctx) {
43template <
typename Page>
44void TestPredictionFromGradientIndex(Context
const* ctx,
size_t rows,
size_t cols,
45 std::shared_ptr<DMatrix> p_hist) {
46 constexpr size_t kClasses { 3 };
48 LearnerModelParam mparam{MakeMP(cols, .5, kClasses)};
51 std::unique_ptr<Predictor> predictor =
52 std::unique_ptr<Predictor>(CreatePredictorForTest(&cuda_ctx));
53 predictor->Configure({});
55 gbm::GBTreeModel model = CreateTestModel(&mparam, ctx, kClasses);
58 auto p_precise = RandomDataGenerator(rows, cols, 0).GenerateDMatrix();
60 PredictionCacheEntry approx_out_predictions;
61 predictor->InitOutPredictions(p_hist->Info(), &approx_out_predictions.predictions, model);
62 predictor->PredictBatch(p_hist.get(), &approx_out_predictions, model, 0);
64 PredictionCacheEntry precise_out_predictions;
65 predictor->InitOutPredictions(p_precise->Info(), &precise_out_predictions.predictions, model);
66 predictor->PredictBatch(p_precise.get(), &precise_out_predictions, model, 0);
68 for (
size_t i = 0; i < rows; ++i) {
69 CHECK_EQ(approx_out_predictions.predictions.HostVector()[i],
70 precise_out_predictions.predictions.HostVector()[i]);
78 auto p_dmat = RandomDataGenerator(rows, cols, 0).GenerateDMatrix();
79 PredictionCacheEntry precise_out_predictions;
80 predictor->InitOutPredictions(p_dmat->Info(), &precise_out_predictions.predictions, model);
81 predictor->PredictBatch(p_dmat.get(), &precise_out_predictions, model, 0);
82 CHECK(!p_dmat->PageExists<Page>());
87void TestTrainingPrediction(Context
const* ctx,
size_t rows,
size_t bins,
88 std::shared_ptr<DMatrix> p_full, std::shared_ptr<DMatrix> p_hist);
90void TestInplacePrediction(Context
const* ctx, std::shared_ptr<DMatrix> x,
bst_row_t rows,
93void TestPredictionWithLesserFeatures(Context
const* ctx);
95void TestPredictionDeviceAccess();
97void TestCategoricalPrediction(Context
const* ctx,
bool is_column_split);
99void TestCategoricalPredictionColumnSplit(Context
const* ctx);
101void TestPredictionWithLesserFeaturesColumnSplit(Context
const* ctx);
103void TestCategoricalPredictLeaf(Context
const* ctx,
bool is_column_split);
105void TestCategoricalPredictLeafColumnSplit(Context
const* ctx);
107void TestIterationRange(Context
const* ctx);
109void TestIterationRangeColumnSplit(Context
const* ctx);
111void TestSparsePrediction(Context
const* ctx,
float sparsity);
113void TestSparsePredictionColumnSplit(Context
const* ctx,
float sparsity);
115void TestVectorLeafPrediction(Context
const* ctx);
static Predictor * Create(std::string const &name, Context const *ctx)
Creates a new Predictor*.
Definition predictor.cc:27
Copyright 2014-2023, XGBoost Contributors.
namespace of xgboost
Definition base.h:90
Context MakeCUDACtx(std::int32_t device)
Make a context that uses CUDA if device >= 0.
Definition helpers.h:410
uint32_t bst_feature_t
Type for data column (feature) index.
Definition base.h:101
std::size_t bst_row_t
Type for data row index.
Definition base.h:110
Copyright 2017-2023 by Contributors.
bool IsCPU() const
Is XGBoost running on CPU?
Definition context.h:133