1#ifndef LIGHTGBM_METRIC_XENTROPY_METRIC_HPP_
2#define LIGHTGBM_METRIC_XENTROPY_METRIC_HPP_
4#include <LightGBM/metric.h>
5#include <LightGBM/meta.h>
7#include <LightGBM/utils/log.h>
8#include <LightGBM/utils/common.h>
// Binary cross-entropy helper for one example.
// Both logarithm arguments are floored at log_arg_epsilon (1e-12), so a
// prediction of exactly 0 or 1 yields a large but finite penalty rather
// than -inf.
// NOTE(review): this view omits the initialisation of `a` (presumably
// `a = label`), the `a *= std::log(prob)` branch and the final return
// (presumably `-(a + b)`); confirm against the full source.
// @param label target value; the metrics in this file check in Init that
//              labels lie in [0, 1].
// @param prob  predicted probability compared against the label.
31 inline static double XentLoss(
label_t label,
double prob) {
32 const double log_arg_epsilon = 1.0e-12;
34 if (prob > log_arg_epsilon) {
// Fallback (else-branch): substitute log(epsilon) when prob is too small.
37 a *= std::log(log_arg_epsilon);
// Coefficient of the complementary term.
39 double b = 1.0f - label;
40 if (1.0f - prob > log_arg_epsilon) {
41 b *= std::log(1.0f - prob);
// Fallback (else-branch): substitute log(epsilon) when 1 - prob is too small.
43 b *= std::log(log_arg_epsilon);
// Cross-entropy for the "xentlambda" parameterisation.
// The score `hhat` is first mapped to a probability p = 1 - exp(-weight * hhat)
// and the result is delegated to XentLoss(label, p).
// `weight` therefore scales hhat inside the exponent rather than multiplying
// the per-example loss.
// @param label  target value in [0, 1].
// @param weight per-example weight (1.0f when the data is unweighted).
// @param hhat   transformed score (see XentropyLambdaMetric::Eval).
49 inline static double XentLambdaLoss(
label_t label,
label_t weight,
double hhat) {
50 return XentLoss(label, 1.0f - std::exp(-weight * hhat));
// Accumulates p*log(p) + q*log(q), skipping zero terms so that 0*log(0)
// contributes 0 instead of NaN.
// NOTE(review): the declarations of `hp` and `q` and the return are not
// visible here. If `hp` starts at 0 and `q` is 1 - p, this is the negative
// binary entropy of p, which the KL metric uses as its offset term.
// Confirm against the full source.
56 inline static double YentLoss(
double p) {
58 if (p > 0) hp += p * std::log(p);
60 if (q > 0) hp += q * std::log(q);
// Init for the "xentropy" metric.
// Registers the metric name, loads the labels and checks they lie in [0, 1].
// Validates the weights (when present) and caches sum_weights_, the
// normaliser used by Eval.
73 name_.emplace_back(
"xentropy");
75 label_ = metadata.
label();
// Cross-entropy is only defined for labels in [0, 1].
78 CHECK_NOTNULL(label_);
81 Common::CheckElementsIntervalClosed<label_t>(label_, 0.0f, 1.0f, num_data_, GetName()[0].c_str());
82 Log::Info(
"[%s:%s]: (metric) labels passed interval [0, 1] check", GetName()[0].c_str(), __func__);
// Unweighted data: every row counts once, so the normaliser is num_data_.
85 if (weights_ ==
nullptr) {
86 sum_weights_ =
static_cast<double>(num_data_);
// Weighted data: fetch the minimum weight and the weight sum.
// The maximum is not needed, hence the null output pointer.
89 Common::ObtainMinMaxSum(weights_, num_data_, &minw, (
label_t*)
nullptr, &sum_weights_);
// NOTE(review): the guarding condition is not visible here; presumably
// this fires when minw < 0.
91 Log::Fatal(
"[%s:%s]: (metric) weights not allowed to be negative", GetName()[0].c_str(), __func__);
// A non-positive total weight would make the Eval average undefined.
// The message prefix uses the same "[metric:function]" order as the other
// log lines in this method.
96 if (sum_weights_ <= 0.0f) {
97 Log::Fatal(
"[%s:%s]: sum-of-weights = %f is non-positive", GetName()[0].c_str(), __func__, sum_weights_);
99 Log::Info(
"[%s:%s]: sum-of-weights = %f", GetName()[0].c_str(), __func__, sum_weights_);
// Eval for "xentropy": the weighted mean cross-entropy
//   sum_i w_i * XentLoss(y_i, p_i) / sum_weights_,
// where w_i = 1 when there are no weights (sum_weights_ is then num_data_).
// Without an objective, the raw score is passed straight to XentLoss, i.e.
// it is assumed to already be a probability. With an objective,
// ConvertOutput first maps the score to p.
// Each of the four variants is an OpenMP parallel loop with a sum reduction
// into sum_loss.
103 double sum_loss = 0.0f;
104 if (objective ==
nullptr) {
105 if (weights_ ==
nullptr) {
// No objective, unweighted.
106 #pragma omp parallel for schedule(static) reduction(+:sum_loss)
108 sum_loss += XentLoss(label_[i], score[i]);
// No objective, weighted.
111 #pragma omp parallel for schedule(static) reduction(+:sum_loss)
113 sum_loss += XentLoss(label_[i], score[i]) * weights_[i];
117 if (weights_ ==
nullptr) {
// Objective present, unweighted: convert the score before scoring.
118 #pragma omp parallel for schedule(static) reduction(+:sum_loss)
121 objective->ConvertOutput(&score[i], &p);
122 sum_loss += XentLoss(label_[i], p);
// Objective present, weighted.
125 #pragma omp parallel for schedule(static) reduction(+:sum_loss)
128 objective->ConvertOutput(&score[i], &p);
129 sum_loss += XentLoss(label_[i], p) * weights_[i];
// Normalise by the total weight cached in Init.
133 double loss = sum_loss / sum_weights_;
134 return std::vector<double>(1, loss);
// Metric display name(s); name_ holds "xentropy" after Init.
// Body not visible here (presumably returns name_).
137 const std::vector<std::string>& GetName()
const override {
// Sign convention used when comparing metric values. Body not visible here;
// for a loss this is presumably negative (smaller is better). Confirm.
141 double factor_to_bigger_better()
const override {
// Name list populated in Init.
155 std::vector<std::string> name_;
// Init for the "xentlambda" metric.
// Stores the row count and labels and checks that labels lie in [0, 1].
// When weights exist, checks them via their minimum.
// Unlike the "xentropy" metric, no weight sum is cached: Eval averages over
// num_data_ because the weights enter the loss through the exponent in
// XentLambdaLoss rather than as sample weights.
168 name_.emplace_back(
"xentlambda");
169 num_data_ = num_data;
170 label_ = metadata.
label();
173 CHECK_NOTNULL(label_);
174 Common::CheckElementsIntervalClosed<label_t>(label_, 0.0f, 1.0f, num_data_, GetName()[0].c_str());
175 Log::Info(
"[%s:%s]: (metric) labels passed interval [0, 1] check", GetName()[0].c_str(), __func__);
178 if (weights_ !=
nullptr) {
// Only the minimum weight is requested; max and sum outputs are null.
180 Common::ObtainMinMaxSum(weights_, num_data_, &minw, (
label_t*)
nullptr, (
label_t*)
nullptr);
// NOTE(review): the guarding condition is not visible. The message requires
// strictly positive weights, so presumably minw <= 0 is fatal.
182 Log::Fatal(
"[%s:%s]: (metric) all weights must be positive", GetName()[0].c_str(), __func__);
// Eval for "xentlambda": the mean over rows of XentLambdaLoss.
// Without an objective, hhat is the softplus log(1 + exp(score)) of the raw
// score. With an objective, ConvertOutput produces hhat.
// Weights (1.0f when absent) scale hhat inside the loss, so the sum is
// divided by num_data_, not by a weight total.
188 double sum_loss = 0.0f;
189 if (objective ==
nullptr) {
190 if (weights_ ==
nullptr) {
// No objective, unweighted.
191 #pragma omp parallel for schedule(static) reduction(+:sum_loss)
193 double hhat = std::log(1.0f + std::exp(score[i]));
194 sum_loss += XentLambdaLoss(label_[i], 1.0f, hhat);
// No objective, weighted.
197 #pragma omp parallel for schedule(static) reduction(+:sum_loss)
199 double hhat = std::log(1.0f + std::exp(score[i]));
200 sum_loss += XentLambdaLoss(label_[i], weights_[i], hhat);
204 if (weights_ ==
nullptr) {
// Objective present, unweighted.
205 #pragma omp parallel for schedule(static) reduction(+:sum_loss)
208 objective->ConvertOutput(&score[i], &hhat);
209 sum_loss += XentLambdaLoss(label_[i], 1.0f, hhat);
// Objective present, weighted.
212 #pragma omp parallel for schedule(static) reduction(+:sum_loss)
215 objective->ConvertOutput(&score[i], &hhat);
216 sum_loss += XentLambdaLoss(label_[i], weights_[i], hhat);
// Plain per-row average (see the note above about weights).
220 return std::vector<double>(1, sum_loss /
static_cast<double>(num_data_));
// Metric display name(s); name_ holds "xentlambda" after Init.
// Body not visible here (presumably returns name_).
223 const std::vector<std::string>& GetName()
const override {
// Sign convention used when comparing metric values. Body not visible here;
// for a loss this is presumably negative (smaller is better). Confirm.
227 double factor_to_bigger_better()
const override {
// Name list populated in Init.
239 std::vector<std::string> name_;
// Init for the "kldiv" metric.
// Performs the same label and weight validation as the "xentropy" metric
// and caches sum_weights_.
// It also precomputes presum_label_entropy_, the weighted mean of
// YentLoss(label), which Eval adds as a constant offset to the mean
// cross-entropy.
251 name_.emplace_back(
"kldiv");
252 num_data_ = num_data;
253 label_ = metadata.
label();
256 CHECK_NOTNULL(label_);
257 Common::CheckElementsIntervalClosed<label_t>(label_, 0.0f, 1.0f, num_data_, GetName()[0].c_str());
258 Log::Info(
"[%s:%s]: (metric) labels passed interval [0, 1] check", GetName()[0].c_str(), __func__);
// Unweighted data: the normaliser is the row count.
260 if (weights_ ==
nullptr) {
261 sum_weights_ =
static_cast<double>(num_data_);
// Weighted data: fetch the minimum weight and the weight sum (max unused).
264 Common::ObtainMinMaxSum(weights_, num_data_, &minw, (
label_t*)
nullptr, &sum_weights_);
// NOTE(review): guarding condition not visible; presumably minw < 0.
266 Log::Fatal(
"[%s:%s]: (metric) at least one weight is negative", GetName()[0].c_str(), __func__);
// A non-positive total weight would make the averages below undefined.
271 if (sum_weights_ <= 0.0f) {
272 Log::Fatal(
"[%s:%s]: sum-of-weights = %f is non-positive", GetName()[0].c_str(), __func__, sum_weights_);
275 Log::Info(
"[%s:%s]: sum-of-weights = %f", GetName()[0].c_str(), __func__, sum_weights_);
// Offset term: the (weighted) mean of YentLoss over the labels.
// NOTE(review): the loop headers are not visible in this view.
278 presum_label_entropy_ = 0.0f;
279 if (weights_ ==
nullptr) {
282 presum_label_entropy_ += YentLoss(label_[i]);
287 presum_label_entropy_ += YentLoss(label_[i]) * weights_[i];
290 presum_label_entropy_ /= sum_weights_;
293 Log::Info(
"%s offset term = %f", GetName()[0].c_str(), presum_label_entropy_);
// Eval for "kldiv".
// The cross-entropy part is computed exactly as in the "xentropy" metric:
// four OpenMP sum-reduction loops, with the raw score used directly when
// there is no objective.
// Adding the offset from Init gives the weighted mean of
// XentLoss(y, p) + YentLoss(y).
// NOTE(review): if YentLoss returns y*log(y) + (1-y)*log(1-y) (not fully
// visible here), this equals CE(y, p) - H(y) = KL(y || p).
297 double sum_loss = 0.0f;
298 if (objective ==
nullptr) {
299 if (weights_ ==
nullptr) {
// No objective, unweighted.
300 #pragma omp parallel for schedule(static) reduction(+:sum_loss)
302 sum_loss += XentLoss(label_[i], score[i]);
// No objective, weighted.
305 #pragma omp parallel for schedule(static) reduction(+:sum_loss)
307 sum_loss += XentLoss(label_[i], score[i]) * weights_[i];
311 if (weights_ ==
nullptr) {
// Objective present, unweighted.
312 #pragma omp parallel for schedule(static) reduction(+:sum_loss)
315 objective->ConvertOutput(&score[i], &p);
316 sum_loss += XentLoss(label_[i], p);
// Objective present, weighted.
319 #pragma omp parallel for schedule(static) reduction(+:sum_loss)
322 objective->ConvertOutput(&score[i], &p);
323 sum_loss += XentLoss(label_[i], p) * weights_[i];
// Mean cross-entropy plus the label-only offset precomputed in Init.
327 double loss = presum_label_entropy_ + sum_loss / sum_weights_;
328 return std::vector<double>(1, loss);
// Metric display name(s); name_ holds "kldiv" after Init.
// Body not visible here (presumably returns name_).
331 const std::vector<std::string>& GetName()
const override {
// Sign convention used when comparing metric values. Body not visible here;
// for a divergence this is presumably negative (smaller is better). Confirm.
335 double factor_to_bigger_better()
const override {
// Weighted mean of YentLoss over the labels, computed once in Init and
// added to every Eval result.
349 double presum_label_entropy_;
// Name list populated in Init.
351 std::vector<std::string> name_;
Definition xentropy_metric.hpp:162
void Init(const Metadata &metadata, data_size_t num_data) override
Initialize.
Definition xentropy_metric.hpp:167
std::vector< double > Eval(const double *score, const ObjectiveFunction *objective) const override
Calculating and printing metric result.
Definition xentropy_metric.hpp:187
Definition xentropy_metric.hpp:67
std::vector< double > Eval(const double *score, const ObjectiveFunction *objective) const override
Calculating and printing metric result.
Definition xentropy_metric.hpp:102
void Init(const Metadata &metadata, data_size_t num_data) override
Initialize.
Definition xentropy_metric.hpp:72
Definition xentropy_metric.hpp:245
std::vector< double > Eval(const double *score, const ObjectiveFunction *objective) const override
Calculating and printing metric result.
Definition xentropy_metric.hpp:296
void Init(const Metadata &metadata, data_size_t num_data) override
Initialize.
Definition xentropy_metric.hpp:250
The interface of metric. Metric is used to calculate metric result.
Definition metric.h:20
The interface of Objective Function.
Definition objective_function.h:13
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10
float label_t
Type of metadata, include weight and label.
Definition meta.h:33
int32_t data_size_t
Type of data size, it is better to use signed type.
Definition meta.h:14