multiclass_objective.hpp
#ifndef LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_

#include <LightGBM/objective_function.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>

#include <algorithm>
#include <cmath>
#include <cstring>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include "binary_objective.hpp"

namespace LightGBM {
/*!
* \brief Objective function for multiclass classification, using softmax as the objective function
*/
class MulticlassSoftmax: public ObjectiveFunction {
 public:
  explicit MulticlassSoftmax(const Config& config) {
    num_class_ = config.num_class;
  }

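  /*! \brief Rebuilds the objective from "key:value" tokens, i.e. the format that
  *         ToString() below emits (presumably used when a saved model is loaded). */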
  explicit MulticlassSoftmax(const std::vector<std::string>& strs) {
    num_class_ = -1;
    for (auto str : strs) {
      auto tokens = Common::Split(str.c_str(), ':');
      if (tokens.size() == 2) {
        if (tokens[0] == std::string("num_class")) {
          Common::Atoi(tokens[1].c_str(), &num_class_);
        }
      }
    }
    if (num_class_ < 0) {
      Log::Fatal("Objective should contain num_class field");
    }
  }

  ~MulticlassSoftmax() {
  }

  void Init(const Metadata& metadata, data_size_t num_data) override {
    num_data_ = num_data;
    label_ = metadata.label();
    weights_ = metadata.weights();
    label_int_.resize(num_data_);
    class_init_probs_.resize(num_class_, 0.0);
    double sum_weight = 0.0;
    for (int i = 0; i < num_data_; ++i) {
      label_int_[i] = static_cast<int>(label_[i]);
      if (label_int_[i] < 0 || label_int_[i] >= num_class_) {
        Log::Fatal("Label must be in [0, %d), but found %d in label", num_class_, label_int_[i]);
      }
      if (weights_ == nullptr) {
        class_init_probs_[label_int_[i]] += 1.0;
      } else {
        class_init_probs_[label_int_[i]] += weights_[i];
        sum_weight += weights_[i];
      }
    }
    if (weights_ == nullptr) {
      sum_weight = num_data_;
    }
    for (int i = 0; i < num_class_; ++i) {
      class_init_probs_[i] /= sum_weight;
    }
  }

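  /*!
  * \brief With p = Softmax(raw scores) and true class y, the gradient computed
  *        below is p_k - 1{k == y}, and the hessian is the scaled diagonal
  *        approximation 2 * p_k * (1 - p_k). Scores are stored class-major:
  *        element (i, k) lives at index num_data_ * k + i.
  */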
  void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
    if (weights_ == nullptr) {
      std::vector<double> rec;
      #pragma omp parallel for schedule(static) private(rec)
      for (data_size_t i = 0; i < num_data_; ++i) {
        rec.resize(num_class_);
        for (int k = 0; k < num_class_; ++k) {
          size_t idx = static_cast<size_t>(num_data_) * k + i;
          rec[k] = static_cast<double>(score[idx]);
        }
        Common::Softmax(&rec);
        for (int k = 0; k < num_class_; ++k) {
          auto p = rec[k];
          size_t idx = static_cast<size_t>(num_data_) * k + i;
          if (label_int_[i] == k) {
            gradients[idx] = static_cast<score_t>(p - 1.0f);
          } else {
            gradients[idx] = static_cast<score_t>(p);
          }
          hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p));
        }
      }
    } else {
      std::vector<double> rec;
      #pragma omp parallel for schedule(static) private(rec)
      for (data_size_t i = 0; i < num_data_; ++i) {
        rec.resize(num_class_);
        for (int k = 0; k < num_class_; ++k) {
          size_t idx = static_cast<size_t>(num_data_) * k + i;
          rec[k] = static_cast<double>(score[idx]);
        }
        Common::Softmax(&rec);
        for (int k = 0; k < num_class_; ++k) {
          auto p = rec[k];
          size_t idx = static_cast<size_t>(num_data_) * k + i;
          if (label_int_[i] == k) {
            gradients[idx] = static_cast<score_t>((p - 1.0f) * weights_[i]);
          } else {
            gradients[idx] = static_cast<score_t>(p * weights_[i]);
          }
          hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p) * weights_[i]);
        }
      }
    }
  }

  void ConvertOutput(const double* input, double* output) const override {
    Common::Softmax(input, output, num_class_);
  }

  const char* GetName() const override {
    return "multiclass";
  }

  std::string ToString() const override {
    std::stringstream str_buf;
    str_buf << GetName() << " ";
    str_buf << "num_class:" << num_class_;
    return str_buf.str();
  }

  bool SkipEmptyClass() const override { return true; }

  int NumModelPerIteration() const override { return num_class_; }

  int NumPredictOneRow() const override { return num_class_; }

  bool NeedAccuratePrediction() const override { return false; }

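  /*! \brief The initial raw score of a class is the log of its prior probability
  *         (clipped to at least kEpsilon), so the softmax of the init scores
  *         reproduces the empirical class distribution. */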
  double BoostFromScore(int class_id) const override {
    return std::log(std::max<double>(kEpsilon, class_init_probs_[class_id]));
  }

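  /*! \brief A class whose prior is numerically 0 or 1 is constant in the
  *         training data, so no trees need to be trained for it. */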
  bool ClassNeedTrain(int class_id) const override {
    if (std::fabs(class_init_probs_[class_id]) <= kEpsilon
        || std::fabs(class_init_probs_[class_id]) >= 1.0 - kEpsilon) {
      return false;
    } else {
      return true;
    }
  }

 private:
  /*! \brief Number of data points */
  data_size_t num_data_;
  /*! \brief Number of classes */
  int num_class_;
  /*! \brief Pointer to the labels */
  const label_t* label_;
  /*! \brief Labels converted to integers */
  std::vector<int> label_int_;
  /*! \brief Per-datum weights; nullptr if the dataset is unweighted */
  const label_t* weights_;
  /*! \brief Prior probability of each class, estimated from the (weighted) label counts */
  std::vector<double> class_init_probs_;
};

/*!
* \brief Objective function for multiclass classification, using one-vs-all binary objective functions
*/
class MulticlassOVA: public ObjectiveFunction {
 public:
  explicit MulticlassOVA(const Config& config) {
    num_class_ = config.num_class;
    for (int i = 0; i < num_class_; ++i) {
      binary_loss_.emplace_back(
        new BinaryLogloss(config, [i](label_t label) { return static_cast<int>(label) == i; }));
    }
    sigmoid_ = config.sigmoid;
  }

  explicit MulticlassOVA(const std::vector<std::string>& strs) {
    num_class_ = -1;
    sigmoid_ = -1;
    for (auto str : strs) {
      auto tokens = Common::Split(str.c_str(), ':');
      if (tokens.size() == 2) {
        if (tokens[0] == std::string("num_class")) {
          Common::Atoi(tokens[1].c_str(), &num_class_);
        } else if (tokens[0] == std::string("sigmoid")) {
          Common::Atof(tokens[1].c_str(), &sigmoid_);
        }
      }
    }
    if (num_class_ < 0) {
      Log::Fatal("Objective should contain num_class field");
    }
    if (sigmoid_ <= 0.0) {
      Log::Fatal("Sigmoid parameter %f should be greater than zero", sigmoid_);
    }
  }

  ~MulticlassOVA() {
  }

  void Init(const Metadata& metadata, data_size_t num_data) override {
    num_data_ = num_data;
    for (int i = 0; i < num_class_; ++i) {
      binary_loss_[i]->Init(metadata, num_data);
    }
  }

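  /*! \brief Delegates to one BinaryLogloss per class; class i owns the contiguous
  *         block of scores/gradients/hessians starting at offset num_data_ * i. */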
  void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
    for (int i = 0; i < num_class_; ++i) {
      int64_t bias = static_cast<int64_t>(num_data_) * i;
      binary_loss_[i]->GetGradients(score + bias, gradients + bias, hessians + bias);
    }
  }

  const char* GetName() const override {
    return "multiclassova";
  }

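  /*! \brief Per-class sigmoid probabilities; unlike the softmax objective,
  *         these are not normalized to sum to one across classes. */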
  void ConvertOutput(const double* input, double* output) const override {
    for (int i = 0; i < num_class_; ++i) {
      output[i] = 1.0f / (1.0f + std::exp(-sigmoid_ * input[i]));
    }
  }

  std::string ToString() const override {
    std::stringstream str_buf;
    str_buf << GetName() << " ";
    str_buf << "num_class:" << num_class_ << " ";
    str_buf << "sigmoid:" << sigmoid_;
    return str_buf.str();
  }

  bool SkipEmptyClass() const override { return true; }

  int NumModelPerIteration() const override { return num_class_; }

  int NumPredictOneRow() const override { return num_class_; }

  bool NeedAccuratePrediction() const override { return false; }

  double BoostFromScore(int class_id) const override {
    return binary_loss_[class_id]->BoostFromScore(0);
  }

  bool ClassNeedTrain(int class_id) const override {
    return binary_loss_[class_id]->ClassNeedTrain(0);
  }

 private:
  /*! \brief Number of data points */
  data_size_t num_data_;
  /*! \brief Number of classes */
  int num_class_;
  /*! \brief One binary logloss objective per class */
  std::vector<std::unique_ptr<BinaryLogloss>> binary_loss_;
  /*! \brief Sigmoid scaling parameter */
  double sigmoid_;
};

}  // namespace LightGBM

#endif  // LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_
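
For reference, the gradient and hessian computed by MulticlassSoftmax::GetGradients can be reproduced outside LightGBM. The standalone sketch below is an illustration only: SoftmaxInPlace and every other name in it are hypothetical and not part of the LightGBM API. It applies a numerically stable softmax to one row of raw scores and prints gradient = p_k - 1{k == y} and hessian = 2 * p_k * (1 - p_k) for an unweighted example.

// Standalone illustration (hypothetical names; not LightGBM code).
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable softmax over one row of raw scores,
// analogous in spirit to Common::Softmax.
static void SoftmaxInPlace(std::vector<double>* rec) {
  double wmax = (*rec)[0];
  for (double v : *rec) wmax = std::max(wmax, v);
  double wsum = 0.0;
  for (double& v : *rec) {
    v = std::exp(v - wmax);  // subtract the max to avoid overflow
    wsum += v;
  }
  for (double& v : *rec) v /= wsum;
}

int main() {
  const int num_class = 3;
  const int label = 1;                         // true class y for this row
  std::vector<double> rec = {0.2, 1.5, -0.3};  // raw scores for one row
  SoftmaxInPlace(&rec);
  for (int k = 0; k < num_class; ++k) {
    const double p = rec[k];
    const double grad = (k == label) ? (p - 1.0) : p;  // p_k - 1{k == y}
    const double hess = 2.0 * p * (1.0 - p);           // scaled diagonal hessian
    std::printf("class %d: p=%.4f grad=%+.4f hess=%.4f\n", k, p, grad, hess);
  }
  return 0;
}

Note that the gradients printed for one row sum to zero across classes, as expected for softmax cross-entropy, and that weighted training would simply scale each row's gradient and hessian by that row's weight, as in the weighted branch of GetGradients above.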