Medial Code Documentation
Loading...
Searching...
No Matches
multiclass_metric.hpp
1#ifndef LIGHTGBM_METRIC_MULTICLASS_METRIC_HPP_
2#define LIGHTGBM_METRIC_MULTICLASS_METRIC_HPP_
3
4#include <LightGBM/metric.h>
5
6#include <LightGBM/utils/log.h>
7
8#include <cmath>
9
10namespace LightGBM {
15template<typename PointWiseLossCalculator>
16class MulticlassMetric: public Metric {
17public:
18 explicit MulticlassMetric(const Config& config) {
19 num_class_ = config.num_class;
20 }
21
22 virtual ~MulticlassMetric() {
23 }
24
25 void Init(const Metadata& metadata, data_size_t num_data) override {
26 name_.emplace_back(PointWiseLossCalculator::Name());
27 num_data_ = num_data;
28 // get label
29 label_ = metadata.label();
30 // get weights
31 weights_ = metadata.weights();
32 if (weights_ == nullptr) {
33 sum_weights_ = static_cast<double>(num_data_);
34 } else {
35 sum_weights_ = 0.0f;
36 for (data_size_t i = 0; i < num_data_; ++i) {
37 sum_weights_ += weights_[i];
38 }
39 }
40 }
41
42 const std::vector<std::string>& GetName() const override {
43 return name_;
44 }
45
46 double factor_to_bigger_better() const override {
47 return -1.0f;
48 }
49
50 std::vector<double> Eval(const double* score, const ObjectiveFunction* objective) const override {
51 double sum_loss = 0.0;
52 int num_tree_per_iteration = num_class_;
53 int num_pred_per_row = num_class_;
54 if (objective != nullptr) {
55 num_tree_per_iteration = objective->NumModelPerIteration();
56 num_pred_per_row = objective->NumPredictOneRow();
57 }
58 if (objective != nullptr) {
59 if (weights_ == nullptr) {
60 #pragma omp parallel for schedule(static) reduction(+:sum_loss)
61 for (data_size_t i = 0; i < num_data_; ++i) {
62 std::vector<double> raw_score(num_tree_per_iteration);
63 for (int k = 0; k < num_tree_per_iteration; ++k) {
64 size_t idx = static_cast<size_t>(num_data_) * k + i;
65 raw_score[k] = static_cast<double>(score[idx]);
66 }
67 std::vector<double> rec(num_pred_per_row);
68 objective->ConvertOutput(raw_score.data(), rec.data());
69 // add loss
70 sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec);
71 }
72 } else {
73 #pragma omp parallel for schedule(static) reduction(+:sum_loss)
74 for (data_size_t i = 0; i < num_data_; ++i) {
75 std::vector<double> raw_score(num_tree_per_iteration);
76 for (int k = 0; k < num_tree_per_iteration; ++k) {
77 size_t idx = static_cast<size_t>(num_data_) * k + i;
78 raw_score[k] = static_cast<double>(score[idx]);
79 }
80 std::vector<double> rec(num_pred_per_row);
81 objective->ConvertOutput(raw_score.data(), rec.data());
82 // add loss
83 sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec) * weights_[i];
84 }
85 }
86 } else {
87 if (weights_ == nullptr) {
88 #pragma omp parallel for schedule(static) reduction(+:sum_loss)
89 for (data_size_t i = 0; i < num_data_; ++i) {
90 std::vector<double> rec(num_tree_per_iteration);
91 for (int k = 0; k < num_tree_per_iteration; ++k) {
92 size_t idx = static_cast<size_t>(num_data_) * k + i;
93 rec[k] = static_cast<double>(score[idx]);
94 }
95 // add loss
96 sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec);
97 }
98 } else {
99 #pragma omp parallel for schedule(static) reduction(+:sum_loss)
100 for (data_size_t i = 0; i < num_data_; ++i) {
101 std::vector<double> rec(num_tree_per_iteration);
102 for (int k = 0; k < num_tree_per_iteration; ++k) {
103 size_t idx = static_cast<size_t>(num_data_) * k + i;
104 rec[k] = static_cast<double>(score[idx]);
105 }
106 // add loss
107 sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec) * weights_[i];
108 }
109 }
110 }
111 double loss = sum_loss / sum_weights_;
112 return std::vector<double>(1, loss);
113 }
114
115private:
117 data_size_t num_data_;
119 const label_t* label_;
121 const label_t* weights_;
123 double sum_weights_;
125 std::vector<std::string> name_;
126 int num_class_;
127};
128
130class MultiErrorMetric: public MulticlassMetric<MultiErrorMetric> {
131public:
132 explicit MultiErrorMetric(const Config& config) :MulticlassMetric<MultiErrorMetric>(config) {}
133
134 inline static double LossOnPoint(label_t label, std::vector<double>& score) {
135 size_t k = static_cast<size_t>(label);
136 for (size_t i = 0; i < score.size(); ++i) {
137 if (i != k && score[i] >= score[k]) {
138 return 1.0f;
139 }
140 }
141 return 0.0f;
142 }
143
144 inline static const char* Name() {
145 return "multi_error";
146 }
147};
148
150class MultiSoftmaxLoglossMetric: public MulticlassMetric<MultiSoftmaxLoglossMetric> {
151public:
153
154 inline static double LossOnPoint(label_t label, std::vector<double>& score) {
155 size_t k = static_cast<size_t>(label);
156 if (score[k] > kEpsilon) {
157 return static_cast<double>(-std::log(score[k]));
158 } else {
159 return -std::log(kEpsilon);
160 }
161 }
162
163 inline static const char* Name() {
164 return "multi_logloss";
165 }
166};
167
168} // namespace LightGBM
169#endif // LightGBM_METRIC_MULTICLASS_METRIC_HPP_
This class is used to store metadata (non-feature data) for training data, e.g. labels, ...
Definition dataset.h:36
const label_t * label() const
Get pointer of label.
Definition dataset.h:113
const label_t * weights() const
Get weights, if not exists, will return nullptr.
Definition dataset.h:146
The interface of metric. Metric is used to calculate metric result.
Definition metric.h:20
Error rate (multi_error) for multiclass task.
Definition multiclass_metric.hpp:130
Logloss for multiclass task.
Definition multiclass_metric.hpp:150
Metric for multiclass task. Use static class "PointWiseLossCalculator" to calculate loss point-wise.
Definition multiclass_metric.hpp:16
void Init(const Metadata &metadata, data_size_t num_data) override
Initialize.
Definition multiclass_metric.hpp:25
std::vector< double > Eval(const double *score, const ObjectiveFunction *objective) const override
Calculates and returns the metric result.
Definition multiclass_metric.hpp:50
The interface of Objective Function.
Definition objective_function.h:13
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10
float label_t
Type of metadata, include weight and label.
Definition meta.h:33
int32_t data_size_t
Type of data size, it is better to use signed type.
Definition meta.h:14
Definition config.h:27