Medial Code Documentation
Loading...
Searching...
No Matches
map_metric.hpp
1#ifndef LIGHTGBM_METRIC_MAP_METRIC_HPP_
2#define LIGHTGBM_METRIC_MAP_METRIC_HPP_
3#include <LightGBM/metric.h>
4
5#include <LightGBM/utils/common.h>
6#include <LightGBM/utils/log.h>
7
8#include <LightGBM/utils/openmp_wrapper.h>
9
10#include <sstream>
11#include <vector>
12
13namespace LightGBM {
14
15class MapMetric:public Metric {
16public:
17 explicit MapMetric(const Config& config) {
18 // get eval position
19 eval_at_ = config.eval_at;
20 DCGCalculator::DefaultEvalAt(&eval_at_);
21 // get number of threads
22 #pragma omp parallel
23 #pragma omp master
24 {
25 num_threads_ = omp_get_num_threads();
26 }
27 }
28
29 ~MapMetric() {
30 }
31
32 void Init(const Metadata& metadata, data_size_t num_data) override {
33 for (auto k : eval_at_) {
34 name_.emplace_back(std::string("map@") + std::to_string(k));
35 }
36 num_data_ = num_data;
37 // get label
38 label_ = metadata.label();
39 // get query boundaries
40 query_boundaries_ = metadata.query_boundaries();
41 if (query_boundaries_ == nullptr) {
42 Log::Fatal("For MAP metric, there should be query information");
43 }
44 num_queries_ = metadata.num_queries();
45 Log::Info("Total groups: %d, total data: %d", num_queries_, num_data_);
46 // get query weights
47 query_weights_ = metadata.query_weights();
48 if (query_weights_ == nullptr) {
49 sum_query_weights_ = static_cast<double>(num_queries_);
50 } else {
51 sum_query_weights_ = 0.0f;
52 for (data_size_t i = 0; i < num_queries_; ++i) {
53 sum_query_weights_ += query_weights_[i];
54 }
55 }
56
57 npos_per_query_.resize(num_queries_, 0);
58 for (data_size_t i = 0; i < num_queries_; ++i) {
59 for (data_size_t j = query_boundaries_[i]; j < query_boundaries_[i + 1]; ++j) {
60 if (label_[j] > 0.5f) {
61 ++npos_per_query_[i];
62 }
63 }
64 }
65 }
66
67 const std::vector<std::string>& GetName() const override {
68 return name_;
69 }
70
71 double factor_to_bigger_better() const override {
72 return 1.0f;
73 }
74
75 void CalMapAtK(std::vector<int> ks, data_size_t npos, const label_t* label,
76 const double* score, data_size_t num_data, std::vector<double>* out) const {
77 // get sorted indices by score
78 std::vector<data_size_t> sorted_idx;
79 for (data_size_t i = 0; i < num_data; ++i) {
80 sorted_idx.emplace_back(i);
81 }
82 std::stable_sort(sorted_idx.begin(), sorted_idx.end(),
83 [score](data_size_t a, data_size_t b) {return score[a] > score[b]; });
84
85 int num_hit = 0;
86 double sum_ap = 0.0f;
87 data_size_t cur_left = 0;
88 for (size_t i = 0; i < ks.size(); ++i) {
89 data_size_t cur_k = static_cast<data_size_t>(ks[i]);
90 if (cur_k > num_data) { cur_k = num_data; }
91 for (data_size_t j = cur_left; j < cur_k; ++j) {
92 data_size_t idx = sorted_idx[j];
93 if (label[idx] > 0.5f) {
94 ++num_hit;
95 sum_ap += static_cast<double>(num_hit) / (j + 1.0f);
96 }
97 }
98 if (npos > 0) {
99 (*out)[i] = sum_ap / std::min(npos, cur_k);
100 } else {
101 (*out)[i] = 1.0f;
102 }
103 cur_left = cur_k;
104 }
105 }
106 std::vector<double> Eval(const double* score, const ObjectiveFunction*) const override {
107 // some buffers for multi-threading sum up
108 std::vector<std::vector<double>> result_buffer_;
109 for (int i = 0; i < num_threads_; ++i) {
110 result_buffer_.emplace_back(eval_at_.size(), 0.0f);
111 }
112 std::vector<double> tmp_map(eval_at_.size(), 0.0f);
113 if (query_weights_ == nullptr) {
114 #pragma omp parallel for schedule(guided) firstprivate(tmp_map)
115 for (data_size_t i = 0; i < num_queries_; ++i) {
116 const int tid = omp_get_thread_num();
117 CalMapAtK(eval_at_, npos_per_query_[i], label_ + query_boundaries_[i],
118 score + query_boundaries_[i], query_boundaries_[i + 1] - query_boundaries_[i], &tmp_map);
119 for (size_t j = 0; j < eval_at_.size(); ++j) {
120 result_buffer_[tid][j] += tmp_map[j];
121 }
122 }
123 } else {
124 #pragma omp parallel for schedule(guided) firstprivate(tmp_map)
125 for (data_size_t i = 0; i < num_queries_; ++i) {
126 const int tid = omp_get_thread_num();
127 CalMapAtK(eval_at_, npos_per_query_[i], label_ + query_boundaries_[i],
128 score + query_boundaries_[i], query_boundaries_[i + 1] - query_boundaries_[i], &tmp_map);
129 for (size_t j = 0; j < eval_at_.size(); ++j) {
130 result_buffer_[tid][j] += tmp_map[j] * query_weights_[i];
131 }
132 }
133 }
134 // Get final average MAP
135 std::vector<double> result(eval_at_.size(), 0.0f);
136 for (size_t j = 0; j < result.size(); ++j) {
137 for (int i = 0; i < num_threads_; ++i) {
138 result[j] += result_buffer_[i][j];
139 }
140 result[j] /= sum_query_weights_;
141 }
142 return result;
143 }
144
145private:
147 data_size_t num_data_;
149 const label_t* label_;
151 const data_size_t* query_boundaries_;
153 data_size_t num_queries_;
155 const label_t* query_weights_;
157 double sum_query_weights_;
159 std::vector<data_size_t> eval_at_;
161 int num_threads_;
162 std::vector<std::string> name_;
163 std::vector<data_size_t> npos_per_query_;
164};
165
166} // namespace LightGBM
167
168#endif // LIGHTGBM_METRIC_MAP_METRIC_HPP_
Definition map_metric.hpp:15
void Init(const Metadata &metadata, data_size_t num_data) override
Initialize.
Definition map_metric.hpp:32
std::vector< double > Eval(const double *score, const ObjectiveFunction *) const override
Calculating and printing metric result.
Definition map_metric.hpp:106
This class is used to store some meta(non-feature) data for training data, e.g. labels,...
Definition dataset.h:36
const label_t * label() const
Get pointer of label.
Definition dataset.h:113
const label_t * query_weights() const
Get weights for queries; returns nullptr if they do not exist.
Definition dataset.h:179
const data_size_t * query_boundaries() const
Get data boundaries on queries; returns nullptr if they do not exist. We assume data is ordered by query...
Definition dataset.h:161
data_size_t num_queries() const
Get Number of queries.
Definition dataset.h:173
The interface of metric. Metric is used to calculate metric result.
Definition metric.h:20
The interface of Objective Function.
Definition objective_function.h:13
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10
float label_t
Type of metadata, include weight and label.
Definition meta.h:33
int32_t data_size_t
Type of data size, it is better to use signed type.
Definition meta.h:14
Definition config.h:27