Medial Code Documentation
Loading...
Searching...
No Matches
binary_metric.hpp
1#ifndef LIGHTGBM_METRIC_BINARY_METRIC_HPP_
2#define LIGHTGBM_METRIC_BINARY_METRIC_HPP_
3
4#include <LightGBM/metric.h>
5
6#include <LightGBM/utils/log.h>
7#include <LightGBM/utils/common.h>
8
9#include <algorithm>
10#include <vector>
11#include <sstream>
12
13namespace LightGBM {
14
19template<typename PointWiseLossCalculator>
20class BinaryMetric: public Metric {
21public:
22 explicit BinaryMetric(const Config&) {
23 }
24
25 virtual ~BinaryMetric() {
26 }
27
28 void Init(const Metadata& metadata, data_size_t num_data) override {
29 name_.emplace_back(PointWiseLossCalculator::Name());
30
31 num_data_ = num_data;
32 // get label
33 label_ = metadata.label();
34
35 // get weights
36 weights_ = metadata.weights();
37
38 if (weights_ == nullptr) {
39 sum_weights_ = static_cast<double>(num_data_);
40 } else {
41 sum_weights_ = 0.0f;
42 for (data_size_t i = 0; i < num_data; ++i) {
43 sum_weights_ += weights_[i];
44 }
45 }
46 }
47
48 const std::vector<std::string>& GetName() const override {
49 return name_;
50 }
51
52 double factor_to_bigger_better() const override {
53 return -1.0f;
54 }
55
56 std::vector<double> Eval(const double* score, const ObjectiveFunction* objective) const override {
57 double sum_loss = 0.0f;
58 if (objective == nullptr) {
59 if (weights_ == nullptr) {
60 #pragma omp parallel for schedule(static) reduction(+:sum_loss)
61 for (data_size_t i = 0; i < num_data_; ++i) {
62 // add loss
63 sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], score[i]);
64 }
65 } else {
66 #pragma omp parallel for schedule(static) reduction(+:sum_loss)
67 for (data_size_t i = 0; i < num_data_; ++i) {
68 // add loss
69 sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], score[i]) * weights_[i];
70 }
71 }
72 } else {
73 if (weights_ == nullptr) {
74 #pragma omp parallel for schedule(static) reduction(+:sum_loss)
75 for (data_size_t i = 0; i < num_data_; ++i) {
76 double prob = 0;
77 objective->ConvertOutput(&score[i], &prob);
78 // add loss
79 sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], prob);
80 }
81 } else {
82 #pragma omp parallel for schedule(static) reduction(+:sum_loss)
83 for (data_size_t i = 0; i < num_data_; ++i) {
84 double prob = 0;
85 objective->ConvertOutput(&score[i], &prob);
86 // add loss
87 sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], prob) * weights_[i];
88 }
89 }
90 }
91 double loss = sum_loss / sum_weights_;
92 return std::vector<double>(1, loss);
93 }
94
95private:
97 data_size_t num_data_;
99 const label_t* label_;
101 const label_t* weights_;
103 double sum_weights_;
105 std::vector<std::string> name_;
106};
107
111class BinaryLoglossMetric: public BinaryMetric<BinaryLoglossMetric> {
112public:
113 explicit BinaryLoglossMetric(const Config& config) :BinaryMetric<BinaryLoglossMetric>(config) {}
114
115 inline static double LossOnPoint(label_t label, double prob) {
116 if (label <= 0) {
117 if (1.0f - prob > kEpsilon) {
118 return -std::log(1.0f - prob);
119 }
120 } else {
121 if (prob > kEpsilon) {
122 return -std::log(prob);
123 }
124 }
125 return -std::log(kEpsilon);
126 }
127
128 inline static const char* Name() {
129 return "binary_logloss";
130 }
131};
135class BinaryErrorMetric: public BinaryMetric<BinaryErrorMetric> {
136public:
137 explicit BinaryErrorMetric(const Config& config) :BinaryMetric<BinaryErrorMetric>(config) {}
138
139 inline static double LossOnPoint(label_t label, double prob) {
140 if (prob <= 0.5f) {
141 return label > 0;
142 } else {
143 return label <= 0;
144 }
145 }
146
147 inline static const char* Name() {
148 return "binary_error";
149 }
150};
151
155class AUCMetric: public Metric {
156public:
157 explicit AUCMetric(const Config&) {
158 }
159
160 virtual ~AUCMetric() {
161 }
162
163 const std::vector<std::string>& GetName() const override {
164 return name_;
165 }
166
167 double factor_to_bigger_better() const override {
168 return 1.0f;
169 }
170
171 void Init(const Metadata& metadata, data_size_t num_data) override {
172 name_.emplace_back("auc");
173
174 num_data_ = num_data;
175 // get label
176 label_ = metadata.label();
177 // get weights
178 weights_ = metadata.weights();
179
180 if (weights_ == nullptr) {
181 sum_weights_ = static_cast<double>(num_data_);
182 } else {
183 sum_weights_ = 0.0f;
184 for (data_size_t i = 0; i < num_data; ++i) {
185 sum_weights_ += weights_[i];
186 }
187 }
188 }
189
190 std::vector<double> Eval(const double* score, const ObjectiveFunction*) const override {
191 // get indices sorted by score, descent order
192 std::vector<data_size_t> sorted_idx;
193 for (data_size_t i = 0; i < num_data_; ++i) {
194 sorted_idx.emplace_back(i);
195 }
196 Common::ParallelSort(sorted_idx.begin(), sorted_idx.end(), [score](data_size_t a, data_size_t b) {return score[a] > score[b]; });
197 // temp sum of postive label
198 double cur_pos = 0.0f;
199 // total sum of postive label
200 double sum_pos = 0.0f;
201 // accumlate of auc
202 double accum = 0.0f;
203 // temp sum of negative label
204 double cur_neg = 0.0f;
205 double threshold = score[sorted_idx[0]];
206 if (weights_ == nullptr) { // no weights
207 for (data_size_t i = 0; i < num_data_; ++i) {
208 const label_t cur_label = label_[sorted_idx[i]];
209 const double cur_score = score[sorted_idx[i]];
210 // new threshold
211 if (cur_score != threshold) {
212 threshold = cur_score;
213 // accmulate
214 accum += cur_neg*(cur_pos * 0.5f + sum_pos);
215 sum_pos += cur_pos;
216 // reset
217 cur_neg = cur_pos = 0.0f;
218 }
219 cur_neg += (cur_label <= 0);
220 cur_pos += (cur_label > 0);
221 }
222 } else { // has weights
223 for (data_size_t i = 0; i < num_data_; ++i) {
224 const label_t cur_label = label_[sorted_idx[i]];
225 const double cur_score = score[sorted_idx[i]];
226 const label_t cur_weight = weights_[sorted_idx[i]];
227 // new threshold
228 if (cur_score != threshold) {
229 threshold = cur_score;
230 // accmulate
231 accum += cur_neg*(cur_pos * 0.5f + sum_pos);
232 sum_pos += cur_pos;
233 // reset
234 cur_neg = cur_pos = 0.0f;
235 }
236 cur_neg += (cur_label <= 0)*cur_weight;
237 cur_pos += (cur_label > 0)*cur_weight;
238 }
239 }
240 accum += cur_neg*(cur_pos * 0.5f + sum_pos);
241 sum_pos += cur_pos;
242 double auc = 1.0f;
243 if (sum_pos > 0.0f && sum_pos != sum_weights_) {
244 auc = accum / (sum_pos *(sum_weights_ - sum_pos));
245 }
246 return std::vector<double>(1, auc);
247 }
248
249private:
251 data_size_t num_data_;
253 const label_t* label_;
255 const label_t* weights_;
257 double sum_weights_;
259 std::vector<std::string> name_;
260};
261
262} // namespace LightGBM
263#endif // LIGHTGBM_METRIC_BINARY_METRIC_HPP_
Auc Metric for binary classification task.
Definition binary_metric.hpp:155
void Init(const Metadata &metadata, data_size_t num_data) override
Initialize.
Definition binary_metric.hpp:171
std::vector< double > Eval(const double *score, const ObjectiveFunction *) const override
Calculating and printing metric result.
Definition binary_metric.hpp:190
Error rate metric for binary classification task.
Definition binary_metric.hpp:135
Log loss metric for binary classification task.
Definition binary_metric.hpp:111
Metric for binary classification task. Use static class "PointWiseLossCalculator" to calculate loss p...
Definition binary_metric.hpp:20
std::vector< double > Eval(const double *score, const ObjectiveFunction *objective) const override
Calculating and printing metric result.
Definition binary_metric.hpp:56
void Init(const Metadata &metadata, data_size_t num_data) override
Initialize.
Definition binary_metric.hpp:28
This class is used to store some meta(non-feature) data for training data, e.g. labels,...
Definition dataset.h:36
const label_t * label() const
Get pointer of label.
Definition dataset.h:113
const label_t * weights() const
Get weights, if not exists, will return nullptr.
Definition dataset.h:146
The interface of metric. Metric is used to calculate metric result.
Definition metric.h:20
The interface of Objective Function.
Definition objective_function.h:13
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10
float label_t
Type of metadata, include weight and label.
Definition meta.h:33
int32_t data_size_t
Type of data size, it is better to use signed type.
Definition meta.h:14
Definition config.h:27