Medial Code Documentation
Loading...
Searching...
No Matches
rank_metric.hpp
1#ifndef LIGHTGBM_METRIC_RANK_METRIC_HPP_
2#define LIGHTGBM_METRIC_RANK_METRIC_HPP_
3
4#include <LightGBM/metric.h>
5
6#include <LightGBM/utils/common.h>
7#include <LightGBM/utils/log.h>
8#include <LightGBM/utils/openmp_wrapper.h>
9
10#include <sstream>
11#include <vector>
12
13namespace LightGBM {
14
15class NDCGMetric:public Metric {
16public:
17 explicit NDCGMetric(const Config& config) {
18 // get eval position
19 eval_at_ = config.eval_at;
20 auto label_gain = config.label_gain;
21 DCGCalculator::DefaultEvalAt(&eval_at_);
22 DCGCalculator::DefaultLabelGain(&label_gain);
23 // initialize DCG calculator
24 DCGCalculator::Init(label_gain);
25 // get number of threads
26 #pragma omp parallel
27 #pragma omp master
28 {
29 num_threads_ = omp_get_num_threads();
30 }
31 }
32
33 ~NDCGMetric() {
34 }
35 void Init(const Metadata& metadata, data_size_t num_data) override {
36 for (auto k : eval_at_) {
37 name_.emplace_back(std::string("ndcg@") + std::to_string(k));
38 }
39 num_data_ = num_data;
40 // get label
41 label_ = metadata.label();
42 DCGCalculator::CheckLabel(label_, num_data_);
43 // get query boundaries
44 query_boundaries_ = metadata.query_boundaries();
45 if (query_boundaries_ == nullptr) {
46 Log::Fatal("The NDCG metric requires query information");
47 }
48 num_queries_ = metadata.num_queries();
49 // get query weights
50 query_weights_ = metadata.query_weights();
51 if (query_weights_ == nullptr) {
52 sum_query_weights_ = static_cast<double>(num_queries_);
53 } else {
54 sum_query_weights_ = 0.0f;
55 for (data_size_t i = 0; i < num_queries_; ++i) {
56 sum_query_weights_ += query_weights_[i];
57 }
58 }
59 inverse_max_dcgs_.resize(num_queries_);
60 // cache the inverse max DCG for all querys, used to calculate NDCG
61 #pragma omp parallel for schedule(static)
62 for (data_size_t i = 0; i < num_queries_; ++i) {
63 inverse_max_dcgs_[i].resize(eval_at_.size(), 0.0f);
64 DCGCalculator::CalMaxDCG(eval_at_, label_ + query_boundaries_[i],
65 query_boundaries_[i + 1] - query_boundaries_[i],
66 &inverse_max_dcgs_[i]);
67 for (size_t j = 0; j < inverse_max_dcgs_[i].size(); ++j) {
68 if (inverse_max_dcgs_[i][j] > 0.0f) {
69 inverse_max_dcgs_[i][j] = 1.0f / inverse_max_dcgs_[i][j];
70 } else {
71 // marking negative for all negative querys.
72 // if one meet this query, it's ndcg will be set as -1.
73 inverse_max_dcgs_[i][j] = -1.0f;
74 }
75 }
76 }
77 }
78
79 const std::vector<std::string>& GetName() const override {
80 return name_;
81 }
82
83 double factor_to_bigger_better() const override {
84 return 1.0f;
85 }
86
87 std::vector<double> Eval(const double* score, const ObjectiveFunction*) const override {
88 // some buffers for multi-threading sum up
89 std::vector<std::vector<double>> result_buffer_;
90 for (int i = 0; i < num_threads_; ++i) {
91 result_buffer_.emplace_back(eval_at_.size(), 0.0f);
92 }
93 std::vector<double> tmp_dcg(eval_at_.size(), 0.0f);
94 if (query_weights_ == nullptr) {
95 #pragma omp parallel for schedule(static) firstprivate(tmp_dcg)
96 for (data_size_t i = 0; i < num_queries_; ++i) {
97 const int tid = omp_get_thread_num();
98 // if all doc in this query are all negative, let its NDCG=1
99 if (inverse_max_dcgs_[i][0] <= 0.0f) {
100 for (size_t j = 0; j < eval_at_.size(); ++j) {
101 result_buffer_[tid][j] += 1.0f;
102 }
103 } else {
104 // calculate DCG
105 DCGCalculator::CalDCG(eval_at_, label_ + query_boundaries_[i],
106 score + query_boundaries_[i],
107 query_boundaries_[i + 1] - query_boundaries_[i], &tmp_dcg);
108 // calculate NDCG
109 for (size_t j = 0; j < eval_at_.size(); ++j) {
110 result_buffer_[tid][j] += tmp_dcg[j] * inverse_max_dcgs_[i][j];
111 }
112 }
113 }
114 } else {
115 #pragma omp parallel for schedule(static) firstprivate(tmp_dcg)
116 for (data_size_t i = 0; i < num_queries_; ++i) {
117 const int tid = omp_get_thread_num();
118 // if all doc in this query are all negative, let its NDCG=1
119 if (inverse_max_dcgs_[i][0] <= 0.0f) {
120 for (size_t j = 0; j < eval_at_.size(); ++j) {
121 result_buffer_[tid][j] += 1.0f;
122 }
123 } else {
124 // calculate DCG
125 DCGCalculator::CalDCG(eval_at_, label_ + query_boundaries_[i],
126 score + query_boundaries_[i],
127 query_boundaries_[i + 1] - query_boundaries_[i], &tmp_dcg);
128 // calculate NDCG
129 for (size_t j = 0; j < eval_at_.size(); ++j) {
130 result_buffer_[tid][j] += tmp_dcg[j] * inverse_max_dcgs_[i][j] * query_weights_[i];
131 }
132 }
133 }
134 }
135 // Get final average NDCG
136 std::vector<double> result(eval_at_.size(), 0.0f);
137 for (size_t j = 0; j < result.size(); ++j) {
138 for (int i = 0; i < num_threads_; ++i) {
139 result[j] += result_buffer_[i][j];
140 }
141 result[j] /= sum_query_weights_;
142 }
143 return result;
144 }
145
146private:
148 data_size_t num_data_;
150 const label_t* label_;
152 std::vector<std::string> name_;
154 const data_size_t* query_boundaries_;
156 data_size_t num_queries_;
158 const label_t* query_weights_;
160 double sum_query_weights_;
162 std::vector<data_size_t> eval_at_;
164 std::vector<std::vector<double>> inverse_max_dcgs_;
166 int num_threads_;
167};
168
169} // namespace LightGBM
170
171#endif // LIGHTGBM_METRIC_RANK_METRIC_HPP_
static void CheckLabel(const label_t *label, data_size_t num_data)
Check the label range for NDCG and lambdarank.
Definition dcg_calculator.cpp:152
static void Init(const std::vector< double > &label_gain)
Initial logic.
Definition dcg_calculator.cpp:40
static void CalMaxDCG(const std::vector< data_size_t > &ks, const label_t *label, data_size_t num_data, std::vector< double > *out)
Calculate the Max DCG score at multi position.
Definition dcg_calculator.cpp:75
static void CalDCG(const std::vector< data_size_t > &ks, const label_t *label, const double *score, data_size_t num_data, std::vector< double > *out)
Calculate the DCG score at multi position.
Definition dcg_calculator.cpp:127
This class is used to store some meta(non-feature) data for training data, e.g. labels,...
Definition dataset.h:36
const label_t * label() const
Get pointer of label.
Definition dataset.h:113
const label_t * query_weights() const
Get weights for queries; returns nullptr if they do not exist.
Definition dataset.h:179
const data_size_t * query_boundaries() const
Get data boundaries on queries; returns nullptr if they do not exist. We assume data is ordered by query...
Definition dataset.h:161
data_size_t num_queries() const
Get Number of queries.
Definition dataset.h:173
The interface of metric. Metric is used to calculate metric result.
Definition metric.h:20
Definition rank_metric.hpp:15
std::vector< double > Eval(const double *score, const ObjectiveFunction *) const override
Calculating and printing metric result.
Definition rank_metric.hpp:87
void Init(const Metadata &metadata, data_size_t num_data) override
Initialize.
Definition rank_metric.hpp:35
The interface of Objective Function.
Definition objective_function.h:13
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10
float label_t
Type of metadata, include weight and label.
Definition meta.h:33
int32_t data_size_t
Type of data size, it is better to use signed type.
Definition meta.h:14
Definition config.h:27