Medial Code Documentation
Loading...
Searching...
No Matches
objective.h
Go to the documentation of this file.
1
7#ifndef XGBOOST_OBJECTIVE_H_
8#define XGBOOST_OBJECTIVE_H_
9
10#include <dmlc/registry.h>
11#include <xgboost/base.h>
12#include <xgboost/data.h>
14#include <xgboost/model.h>
15#include <xgboost/task.h>
16
17#include <cstdint> // std::int32_t
18#include <functional>
19#include <string>
20#include <utility>
21#include <vector>
22
23namespace xgboost {
24
// Forward declarations: both types are only referenced through pointers in
// this header (Context const* in ObjFunction, RegTree* in UpdateTreeLeaf),
// so their full definitions are not needed here.
class RegTree;
struct Context;
27
29class ObjFunction : public Configurable {
30 protected:
31 Context const* ctx_;
32
33 public:
34 static constexpr float DefaultBaseScore() { return 0.5f; }
35
36 public:
38 ~ObjFunction() override = default;
43 virtual void Configure(const std::vector<std::pair<std::string, std::string> >& args) = 0;
51 virtual void GetGradient(const HostDeviceVector<bst_float>& preds,
52 const MetaInfo& info,
53 int iteration,
54 HostDeviceVector<GradientPair>* out_gpair) = 0;
55
57 virtual const char* DefaultEvalMetric() const = 0;
61 virtual Json DefaultMetricConfig() const { return Json{Null{}}; }
62
63 // the following functions are optional, most of time default implementation is good enough
69
76 this->PredTransform(io_preds);
77 }
84 virtual bst_float ProbToMargin(bst_float base_score) const {
85 return base_score;
86 }
93 virtual void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) const;
97 virtual struct ObjInfo Task() const = 0;
102 virtual bst_target_t Targets(MetaInfo const& info) const {
103 if (info.labels.Shape(1) > 1) {
104 LOG(FATAL) << "multioutput is not supported by current objective function";
105 }
106 return 1;
107 }
108
124 virtual void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& /*position*/,
125 MetaInfo const& /*info*/, float /*learning_rate*/,
126 HostDeviceVector<float> const& /*prediction*/,
127 std::int32_t /*group_idx*/, RegTree* /*p_tree*/) const {}
128
134 static ObjFunction* Create(const std::string& name, Context const* ctx);
135};
136
141 : public dmlc::FunctionRegEntryBase<ObjFunctionReg,
142 std::function<ObjFunction* ()> > {
143};
144
/*!
 * \brief Register an objective function under a name.
 *
 * Expands to a namespace-scope static ObjFunctionReg& bound to the entry that
 * the registry returned by dmlc::Registry<ObjFunctionReg>::Get() creates for
 * Name. No trailing semicolon is emitted, so the use site terminates the
 * statement — presumably after chaining builder calls on the returned entry
 * (see the dmlc registry documentation).
 *
 * \param UniqueId Token pasted into the generated variable name to keep it unique.
 * \param Name     Registry name of the objective.
 */
#define XGBOOST_REGISTER_OBJECTIVE(UniqueId, Name)                     \
  static DMLC_ATTRIBUTE_UNUSED ::xgboost::ObjFunctionReg &             \
  __make_ ## ObjFunctionReg ## _ ## UniqueId ## __ =                   \
      ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->__REGISTER__(Name)
161} // namespace xgboost
162#endif // XGBOOST_OBJECTIVE_H_
Common base class for function registry.
Definition registry.h:151
Definition host_device_vector.h:87
Definition json.h:296
Data structure representing JSON format.
Definition json.h:357
Meta information about the dataset; always resides in memory.
Definition data.h:48
linalg::Tensor< float, 2 > labels
label of each instance
Definition data.h:60
interface of objective function
Definition objective.h:29
virtual void InitEstimation(MetaInfo const &info, linalg::Tensor< float, 1 > *base_score) const
Make an initial estimation of the prediction.
Definition objective.cc:35
virtual void PredTransform(HostDeviceVector< bst_float > *) const
Transform prediction values; this is only called during prediction.
Definition objective.h:68
virtual bst_float ProbToMargin(bst_float base_score) const
Transform a probability value back to margin; this is used to transform the user-set base_score back to margin.
Definition objective.h:84
virtual const char * DefaultEvalMetric() const =0
virtual void GetGradient(const HostDeviceVector< bst_float > &preds, const MetaInfo &info, int iteration, HostDeviceVector< GradientPair > *out_gpair)=0
Get gradient over each of predictions, given existing information.
virtual void EvalTransform(HostDeviceVector< bst_float > *io_preds)
Transform prediction values; this is only called during evaluation and, by default, redirects to PredTransform.
Definition objective.h:75
virtual void Configure(const std::vector< std::pair< std::string, std::string > > &args)=0
Configure the objective with the specified parameters.
virtual Json DefaultMetricConfig() const
Return the configuration for the default metric.
Definition objective.h:61
virtual struct ObjInfo Task() const =0
Return task of this objective.
~ObjFunction() override=default
virtual destructor
virtual void UpdateTreeLeaf(HostDeviceVector< bst_node_t > const &, MetaInfo const &, float, HostDeviceVector< float > const &, std::int32_t, RegTree *) const
Update the leaf values after a tree is built.
Definition objective.h:124
static ObjFunction * Create(const std::string &name, Context const *ctx)
Create an objective function according to name.
Definition objective.cc:20
virtual bst_target_t Targets(MetaInfo const &info) const
Return number of targets for input matrix.
Definition objective.h:102
Defines the regression tree, the most common tree model.
Definition tree_model.h:158
A tensor storage.
Definition linalg.h:742
A device-and-host vector abstraction layer.
Copyright 2015-2023 by XGBoost Contributors.
Copyright 2015-2023 by XGBoost Contributors.
Defines the abstract interface for different components in XGBoost.
namespace of xgboost
Definition base.h:90
std::uint32_t bst_target_t
Type for indexing into output targets.
Definition base.h:118
float bst_float
float type, used for storing statistics
Definition base.h:97
Registry utility that helps to build registry singletons.
Definition model.h:31
Runtime context for XGBoost.
Definition context.h:84
Registry entry for objective factory functions.
Definition objective.h:142
A struct returned by an objective that determines the task at hand. The struct is not used by any algorithm...
Definition task.h:24