#ifndef LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_

#include <LightGBM/objective_function.h>

#include <algorithm>
#include <cmath>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include "binary_objective.hpp"
19 num_class_ = config.num_class;
24 for (
auto str : strs) {
25 auto tokens = Common::Split(str.c_str(),
':');
26 if (tokens.size() == 2) {
27 if (tokens[0] == std::string(
"num_class")) {
28 Common::Atoi(tokens[1].c_str(), &num_class_);
33 Log::Fatal(
"Objective should contain num_class field");
42 label_ = metadata.
label();
44 label_int_.resize(num_data_);
45 class_init_probs_.resize(num_class_, 0.0);
46 double sum_weight = 0.0;
47 for (
int i = 0; i < num_data_; ++i) {
48 label_int_[i] =
static_cast<int>(label_[i]);
49 if (label_int_[i] < 0 || label_int_[i] >= num_class_) {
50 Log::Fatal(
"Label must be in [0, %d), but found %d in label", num_class_, label_int_[i]);
52 if (weights_ ==
nullptr) {
53 class_init_probs_[label_int_[i]] += 1.0;
55 class_init_probs_[label_int_[i]] += weights_[i];
56 sum_weight += weights_[i];
59 if (weights_ ==
nullptr) {
60 sum_weight = num_data_;
62 for (
int i = 0; i < num_class_; ++i) {
63 class_init_probs_[i] /= sum_weight;
68 if (weights_ ==
nullptr) {
69 std::vector<double> rec;
70 #pragma omp parallel for schedule(static) private(rec)
72 rec.resize(num_class_);
73 for (
int k = 0; k < num_class_; ++k) {
74 size_t idx =
static_cast<size_t>(num_data_) * k + i;
75 rec[k] =
static_cast<double>(score[idx]);
77 Common::Softmax(&rec);
78 for (
int k = 0; k < num_class_; ++k) {
80 size_t idx =
static_cast<size_t>(num_data_) * k + i;
81 if (label_int_[i] == k) {
82 gradients[idx] =
static_cast<score_t>(p - 1.0f);
84 gradients[idx] =
static_cast<score_t>(p);
86 hessians[idx] =
static_cast<score_t>(2.0f * p * (1.0f - p));
90 std::vector<double> rec;
91 #pragma omp parallel for schedule(static) private(rec)
93 rec.resize(num_class_);
94 for (
int k = 0; k < num_class_; ++k) {
95 size_t idx =
static_cast<size_t>(num_data_) * k + i;
96 rec[k] =
static_cast<double>(score[idx]);
98 Common::Softmax(&rec);
99 for (
int k = 0; k < num_class_; ++k) {
101 size_t idx =
static_cast<size_t>(num_data_) * k + i;
102 if (label_int_[i] == k) {
103 gradients[idx] =
static_cast<score_t>((p - 1.0f) * weights_[i]);
105 gradients[idx] =
static_cast<score_t>((p) * weights_[i]);
107 hessians[idx] =
static_cast<score_t>((2.0f * p * (1.0f - p))* weights_[i]);
113 void ConvertOutput(
const double* input,
double* output)
const override {
114 Common::Softmax(input, output, num_class_);
117 const char* GetName()
const override {
121 std::string ToString()
const override {
122 std::stringstream str_buf;
123 str_buf << GetName() <<
" ";
124 str_buf <<
"num_class:" << num_class_;
125 return str_buf.str();
128 bool SkipEmptyClass()
const override {
return true; }
130 int NumModelPerIteration()
const override {
return num_class_; }
132 int NumPredictOneRow()
const override {
return num_class_; }
136 double BoostFromScore(
int class_id)
const override {
137 return std::log(std::max<double>(kEpsilon, class_init_probs_[class_id]));
140 bool ClassNeedTrain(
int class_id)
const override {
141 if (std::fabs(class_init_probs_[class_id]) <= kEpsilon
142 || std::fabs(class_init_probs_[class_id]) >= 1.0 - kEpsilon) {
157 std::vector<int> label_int_;
160 std::vector<double> class_init_probs_;
169 num_class_ = config.num_class;
170 for (
int i = 0; i < num_class_; ++i) {
171 binary_loss_.emplace_back(
174 sigmoid_ = config.sigmoid;
177 explicit MulticlassOVA(
const std::vector<std::string>& strs) {
180 for (
auto str : strs) {
181 auto tokens = Common::Split(str.c_str(),
':');
182 if (tokens.size() == 2) {
183 if (tokens[0] == std::string(
"num_class")) {
184 Common::Atoi(tokens[1].c_str(), &num_class_);
185 }
else if (tokens[0] == std::string(
"sigmoid")) {
186 Common::Atof(tokens[1].c_str(), &sigmoid_);
190 if (num_class_ < 0) {
191 Log::Fatal(
"Objective should contain num_class field");
193 if (sigmoid_ <= 0.0) {
194 Log::Fatal(
"Sigmoid parameter %f should be greater than zero", sigmoid_);
202 num_data_ = num_data;
203 for (
int i = 0; i < num_class_; ++i) {
204 binary_loss_[i]->Init(metadata, num_data);
209 for (
int i = 0; i < num_class_; ++i) {
210 int64_t bias =
static_cast<int64_t
>(num_data_) * i;
211 binary_loss_[i]->GetGradients(score + bias, gradients + bias, hessians + bias);
215 const char* GetName()
const override {
216 return "multiclassova";
219 void ConvertOutput(
const double* input,
double* output)
const override {
220 for (
int i = 0; i < num_class_; ++i) {
221 output[i] = 1.0f / (1.0f + std::exp(-sigmoid_ * input[i]));
225 std::string ToString()
const override {
226 std::stringstream str_buf;
227 str_buf << GetName() <<
" ";
228 str_buf <<
"num_class:" << num_class_ <<
" ";
229 str_buf <<
"sigmoid:" << sigmoid_;
230 return str_buf.str();
233 bool SkipEmptyClass()
const override {
return true; }
235 int NumModelPerIteration()
const override {
return num_class_; }
237 int NumPredictOneRow()
const override {
return num_class_; }
241 double BoostFromScore(
int class_id)
const override {
242 return binary_loss_[class_id]->BoostFromScore(0);
245 bool ClassNeedTrain(
int class_id)
const override {
246 return binary_loss_[class_id]->ClassNeedTrain(0);
254 std::vector<std::unique_ptr<BinaryLogloss>> binary_loss_;
Objective function for binary classification.
Definition binary_objective.hpp:13
Objective function for multiclass classification, use one-vs-all binary objective function.
Definition multiclass_objective.hpp:166
void Init(const Metadata &metadata, data_size_t num_data) override
Initialize.
Definition multiclass_objective.hpp:201
bool NeedAccuratePrediction() const override
The prediction should be accurate or not. True will disable early stopping for prediction.
Definition multiclass_objective.hpp:239
void GetGradients(const double *score, score_t *gradients, score_t *hessians) const override
calculating first order derivative of loss function
Definition multiclass_objective.hpp:208
Objective function for multiclass classification, use softmax as objective functions.
Definition multiclass_objective.hpp:16
void Init(const Metadata &metadata, data_size_t num_data) override
Initialize.
Definition multiclass_objective.hpp:40
void GetGradients(const double *score, score_t *gradients, score_t *hessians) const override
calculating first order derivative of loss function
Definition multiclass_objective.hpp:67
bool NeedAccuratePrediction() const override
The prediction should be accurate or not. True will disable early stopping for prediction.
Definition multiclass_objective.hpp:134
The interface of Objective Function.
Definition objective_function.h:13
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10
float score_t
Type of score, and gradients.
Definition meta.h:26
float label_t
Type of metadata, include weight and label.
Definition meta.h:33
int32_t data_size_t
Type of data size, it is better to use signed type.
Definition meta.h:14