Medial Code Documentation
Loading...
Searching...
No Matches
tree_updater.h
Go to the documentation of this file.
1
8#ifndef XGBOOST_TREE_UPDATER_H_
9#define XGBOOST_TREE_UPDATER_H_
10
11#include <dmlc/registry.h>
12#include <xgboost/base.h> // for Args, GradientPair
13#include <xgboost/data.h> // DMatrix
14#include <xgboost/host_device_vector.h> // for HostDeviceVector
15#include <xgboost/linalg.h> // for VectorView
16#include <xgboost/model.h> // for Configurable
17#include <xgboost/span.h> // for Span
18#include <xgboost/tree_model.h> // for RegTree
19
20#include <functional> // for function
21#include <string> // for string
22#include <vector> // for vector
23
24namespace xgboost {
25namespace tree {
26struct TrainParam;
27}
28
29class Json;
30struct Context;
31struct ObjInfo;
32
36class TreeUpdater : public Configurable {
37 protected:
38 Context const* ctx_ = nullptr;
39
40 public:
41 explicit TreeUpdater(const Context* ctx) : ctx_(ctx) {}
43 ~TreeUpdater() override = default;
48 virtual void Configure(const Args& args) = 0;
55 [[nodiscard]] virtual bool CanModifyTree() const { return false; }
60 [[nodiscard]] virtual bool HasNodePosition() const { return false; }
74 virtual void Update(tree::TrainParam const* param, HostDeviceVector<GradientPair>* gpair,
76 const std::vector<RegTree*>& out_trees) = 0;
77
88 virtual bool UpdatePredictionCache(const DMatrix* /*data*/,
89 linalg::MatrixView<float> /*out_preds*/) {
90 return false;
91 }
92
93 [[nodiscard]] virtual char const* Name() const = 0;
94
101 static TreeUpdater* Create(const std::string& name, Context const* ctx, ObjInfo const* task);
102};
103
109 TreeUpdaterReg, std::function<TreeUpdater*(Context const* ctx, ObjInfo const* task)>> {};
110
123#define XGBOOST_REGISTER_TREE_UPDATER(UniqueId, Name) \
124 static DMLC_ATTRIBUTE_UNUSED ::xgboost::TreeUpdaterReg& \
125 __make_ ## TreeUpdaterReg ## _ ## UniqueId ## __ = \
126 ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->__REGISTER__(Name)
127
128} // namespace xgboost
129#endif // XGBOOST_TREE_UPDATER_H_
Common base class for function registry.
Definition registry.h:151
Internal data structure used by XGBoost during training.
Definition data.h:509
Definition host_device_vector.h:87
Data structure representing JSON format.
Definition json.h:357
Interface of the tree update module, which performs updates of a tree.
Definition tree_updater.h:36
~TreeUpdater() override=default
virtual destructor
virtual void Update(tree::TrainParam const *param, HostDeviceVector< GradientPair > *gpair, DMatrix *data, common::Span< HostDeviceVector< bst_node_t > > out_position, const std::vector< RegTree * > &out_trees)=0
Perform an update to the tree models.
static TreeUpdater * Create(const std::string &name, Context const *ctx, ObjInfo const *task)
Create a tree updater given name.
Definition tree_updater.cc:17
virtual bool UpdatePredictionCache(const DMatrix *, linalg::MatrixView< float >)
Determines whether the updater has enough knowledge about a given dataset to quickly update the prediction cache.
Definition tree_updater.h:88
virtual bool CanModifyTree() const
Whether this updater can be used for updating existing trees.
Definition tree_updater.h:55
virtual bool HasNodePosition() const
Whether the out_position in Update is valid. This determines whether adaptive trees can be used.
Definition tree_updater.h:60
virtual void Configure(const Args &args)=0
Initialize the updater with given arguments.
Span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition span.h:424
A tensor view with static type and dimension.
Definition linalg.h:293
A device-and-host vector abstraction layer.
Copyright 2015-2023 by XGBoost Contributors.
Copyright 2015-2023 by XGBoost Contributors.
Copyright 2021-2023 by XGBoost Contributors.
Defines the abstract interface for different components in XGBoost.
namespace of xgboost
Definition base.h:90
Registry utility that helps to build registry singletons.
Definition model.h:31
Runtime context for XGBoost.
Definition context.h:84
A struct returned by objective, which determines task at hand. The struct is not used by any algorith...
Definition task.h:24
Registry entry for tree updater.
Definition tree_updater.h:109
Training parameters for regression trees.
Definition param.h:28
Copyright 2014-2023 by Contributors.