Medial Code Documentation
Loading...
Searching...
No Matches
predictor.h
Go to the documentation of this file.
1
7#pragma once
#include <xgboost/base.h>
#include <xgboost/cache.h>    // for DMatrixCache
#include <xgboost/context.h>  // for Context
#include <xgboost/context.h>
#include <xgboost/data.h>

#include <cstddef>     // for size_t
#include <cstdint>     // for uint32_t, int32_t
#include <functional>  // for function
#include <memory>      // for shared_ptr
#include <string>
#include <utility>  // for make_pair
#include <vector>
20
21// Forward declarations
// Declared but not defined here: this header only refers to GBTreeModel through
// const references, so the complete type is not required.
namespace xgboost {
namespace gbm {
struct GBTreeModel;
}  // namespace gbm
}  // namespace xgboost
25
26namespace xgboost {
31 // A storage for caching prediction values
33 // The version of current cache, corresponding number of layers of trees
34 std::uint32_t version{0};
35
36 PredictionCacheEntry() = default;
42 void Update(std::uint32_t v) { version += v; }
43 void Reset() { version = 0; }
44};
45
49class PredictionContainer : public DMatrixCache<PredictionCacheEntry> {
50 // We cache up to 64 DMatrix for all threads
51 std::size_t static constexpr DefaultSize() { return 64; }
52
53 public:
55 PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, std::int32_t device) {
56 auto p_cache = this->CacheItem(m);
57 if (device != Context::kCpuId) {
58 p_cache->predictions.SetDevice(device);
59 }
60 return *p_cache;
61 }
62};
63
72class Predictor {
73 protected:
74 Context const* ctx_;
75
76 public:
77 explicit Predictor(Context const* ctx) : ctx_{ctx} {}
78
79 virtual ~Predictor() = default;
80
86 virtual void Configure(Args const&);
87
95 void InitOutPredictions(const MetaInfo& info, HostDeviceVector<bst_float>* out_predt,
96 const gbm::GBTreeModel& model) const;
97
108 virtual void PredictBatch(DMatrix* dmat, PredictionCacheEntry* out_preds,
109 const gbm::GBTreeModel& model, uint32_t tree_begin,
110 uint32_t tree_end = 0) const = 0;
111
125 virtual bool InplacePredict(std::shared_ptr<DMatrix> p_fmat, const gbm::GBTreeModel& model,
126 float missing, PredictionCacheEntry* out_preds,
127 uint32_t tree_begin = 0, uint32_t tree_end = 0) const = 0;
141 virtual void PredictInstance(const SparsePage::Inst& inst,
142 std::vector<bst_float>* out_preds,
143 const gbm::GBTreeModel& model,
144 unsigned tree_end = 0,
145 bool is_column_split = false) const = 0;
146
157 virtual void PredictLeaf(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
158 const gbm::GBTreeModel& model,
159 unsigned tree_end = 0) const = 0;
160
176 virtual void
178 const gbm::GBTreeModel &model, unsigned tree_end = 0,
179 std::vector<bst_float> const *tree_weights = nullptr,
180 bool approximate = false, int condition = 0,
181 unsigned condition_feature = 0) const = 0;
182
183 virtual void PredictInteractionContributions(
184 DMatrix *dmat, HostDeviceVector<bst_float> *out_contribs,
185 const gbm::GBTreeModel &model, unsigned tree_end = 0,
186 std::vector<bst_float> const *tree_weights = nullptr,
187 bool approximate = false) const = 0;
188
195 static Predictor* Create(std::string const& name, Context const* ctx);
196};
197
202 : public dmlc::FunctionRegEntryBase<PredictorReg, std::function<Predictor*(Context const*)>> {};
203
/**
 * \brief Register a predictor under `Name`.
 *
 * Expands to a `static` reference of type `::xgboost::PredictorReg&`, bound to the entry
 * returned by the predictor registry's `__REGISTER__(Name)`. The variable name is built from
 * `UniqueId`, so every use within one scope needs a distinct `UniqueId`. The expansion has no
 * trailing semicolon; the caller finishes the statement.
 */
#define XGBOOST_REGISTER_PREDICTOR(UniqueId, Name)         \
  static DMLC_ATTRIBUTE_UNUSED ::xgboost::PredictorReg&     \
      __make_PredictorReg_##UniqueId##__ =                  \
          ::dmlc::Registry<::xgboost::PredictorReg>::Get()->__REGISTER__(Name)
208} // namespace xgboost
Definition svm.cpp:71
Common base class for function registry.
Definition registry.h:151
Thread-aware FIFO cache for DMatrix related data.
Definition cache.h:26
std::shared_ptr< PredictionCacheEntry > CacheItem(std::shared_ptr< DMatrix > m, Args const &... args)
Cache a new DMatrix if it's not in the cache already.
Definition cache.h:145
Internal data structure used by XGBoost during training.
Definition data.h:509
Definition host_device_vector.h:87
Meta information about the dataset; it always resides in memory.
Definition data.h:48
A container for managed prediction caches.
Definition predictor.h:49
Performs prediction on individual training instances or batches of instances for GBTree.
Definition predictor.h:72
static Predictor * Create(std::string const &name, Context const *ctx)
Creates a new Predictor*.
Definition predictor.cc:27
void InitOutPredictions(const MetaInfo &info, HostDeviceVector< bst_float > *out_predt, const gbm::GBTreeModel &model) const
Initialize output prediction.
Definition predictor.cc:46
virtual void Configure(Args const &)
Configure and register input matrices in prediction cache.
Definition predictor.cc:25
virtual void PredictContribution(DMatrix *dmat, HostDeviceVector< bst_float > *out_contribs, const gbm::GBTreeModel &model, unsigned tree_end=0, std::vector< bst_float > const *tree_weights=nullptr, bool approximate=false, int condition=0, unsigned condition_feature=0) const =0
feature contributions to individual predictions; the output will be a vector of length (nfeats + 1) *...
virtual void PredictLeaf(DMatrix *dmat, HostDeviceVector< bst_float > *out_preds, const gbm::GBTreeModel &model, unsigned tree_end=0) const =0
predict the leaf index of each tree; the output will be an nsample * ntree vector. This is only valid in ...
virtual bool InplacePredict(std::shared_ptr< DMatrix > p_fmat, const gbm::GBTreeModel &model, float missing, PredictionCacheEntry *out_preds, uint32_t tree_begin=0, uint32_t tree_end=0) const =0
Inplace prediction.
virtual void PredictInstance(const SparsePage::Inst &inst, std::vector< bst_float > *out_preds, const gbm::GBTreeModel &model, unsigned tree_end=0, bool is_column_split=false) const =0
online prediction function: predicts the score for one instance at a time. NOTE: use the batch prediction interface ...
virtual void PredictBatch(DMatrix *dmat, PredictionCacheEntry *out_preds, const gbm::GBTreeModel &model, uint32_t tree_begin, uint32_t tree_end=0) const =0
Generate batch predictions for a given feature matrix.
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition span.h:424
Copyright 2014-2023, XGBoost Contributors.
A device-and-host vector abstraction layer.
Copyright 2015-2023 by XGBoost Contributors.
Copyright 2015-2023 by XGBoost Contributors.
Copyright 2019-2023, XGBoost Contributors.
Definition linear_updater.h:23
namespace of xgboost
Definition base.h:90
Runtime context for XGBoost.
Definition context.h:84
Holds the cached predictions associated with an input matrix, together with a version counter.
Definition predictor.h:30
void Update(std::uint32_t v)
Update the cache entry by number of versions.
Definition predictor.h:42
Registry entry for predictor.
Definition predictor.h:202
Definition gbtree_model.h:84