1#ifndef LIGHTGBM_METRIC_RANK_METRIC_HPP_
2#define LIGHTGBM_METRIC_RANK_METRIC_HPP_
4#include <LightGBM/metric.h>
6#include <LightGBM/utils/common.h>
7#include <LightGBM/utils/log.h>
8#include <LightGBM/utils/openmp_wrapper.h>
19 eval_at_ = config.eval_at;
20 auto label_gain = config.label_gain;
21 DCGCalculator::DefaultEvalAt(&eval_at_);
22 DCGCalculator::DefaultLabelGain(&label_gain);
29 num_threads_ = omp_get_num_threads();
36 for (
auto k : eval_at_) {
37 name_.emplace_back(std::string(
"ndcg@") + std::to_string(k));
41 label_ = metadata.
label();
45 if (query_boundaries_ ==
nullptr) {
46 Log::Fatal(
"The NDCG metric requires query information");
51 if (query_weights_ ==
nullptr) {
52 sum_query_weights_ =
static_cast<double>(num_queries_);
54 sum_query_weights_ = 0.0f;
56 sum_query_weights_ += query_weights_[i];
59 inverse_max_dcgs_.resize(num_queries_);
61 #pragma omp parallel for schedule(static)
63 inverse_max_dcgs_[i].resize(eval_at_.size(), 0.0f);
65 query_boundaries_[i + 1] - query_boundaries_[i],
66 &inverse_max_dcgs_[i]);
67 for (
size_t j = 0; j < inverse_max_dcgs_[i].size(); ++j) {
68 if (inverse_max_dcgs_[i][j] > 0.0f) {
69 inverse_max_dcgs_[i][j] = 1.0f / inverse_max_dcgs_[i][j];
73 inverse_max_dcgs_[i][j] = -1.0f;
79 const std::vector<std::string>& GetName()
const override {
83 double factor_to_bigger_better()
const override {
89 std::vector<std::vector<double>> result_buffer_;
90 for (
int i = 0; i < num_threads_; ++i) {
91 result_buffer_.emplace_back(eval_at_.size(), 0.0f);
93 std::vector<double> tmp_dcg(eval_at_.size(), 0.0f);
94 if (query_weights_ ==
nullptr) {
95 #pragma omp parallel for schedule(static) firstprivate(tmp_dcg)
97 const int tid = omp_get_thread_num();
99 if (inverse_max_dcgs_[i][0] <= 0.0f) {
100 for (
size_t j = 0; j < eval_at_.size(); ++j) {
101 result_buffer_[tid][j] += 1.0f;
106 score + query_boundaries_[i],
107 query_boundaries_[i + 1] - query_boundaries_[i], &tmp_dcg);
109 for (
size_t j = 0; j < eval_at_.size(); ++j) {
110 result_buffer_[tid][j] += tmp_dcg[j] * inverse_max_dcgs_[i][j];
115 #pragma omp parallel for schedule(static) firstprivate(tmp_dcg)
117 const int tid = omp_get_thread_num();
119 if (inverse_max_dcgs_[i][0] <= 0.0f) {
120 for (
size_t j = 0; j < eval_at_.size(); ++j) {
121 result_buffer_[tid][j] += 1.0f;
126 score + query_boundaries_[i],
127 query_boundaries_[i + 1] - query_boundaries_[i], &tmp_dcg);
129 for (
size_t j = 0; j < eval_at_.size(); ++j) {
130 result_buffer_[tid][j] += tmp_dcg[j] * inverse_max_dcgs_[i][j] * query_weights_[i];
136 std::vector<double> result(eval_at_.size(), 0.0f);
137 for (
size_t j = 0; j < result.size(); ++j) {
138 for (
int i = 0; i < num_threads_; ++i) {
139 result[j] += result_buffer_[i][j];
141 result[j] /= sum_query_weights_;
152 std::vector<std::string> name_;
160 double sum_query_weights_;
162 std::vector<data_size_t> eval_at_;
164 std::vector<std::vector<double>> inverse_max_dcgs_;
static void CheckLabel(const label_t *label, data_size_t num_data)
Check the label range for NDCG and lambdarank.
Definition dcg_calculator.cpp:152
static void Init(const std::vector< double > &label_gain)
Initialization logic.
Definition dcg_calculator.cpp:40
static void CalMaxDCG(const std::vector< data_size_t > &ks, const label_t *label, data_size_t num_data, std::vector< double > *out)
Calculate the max DCG score at multiple positions.
Definition dcg_calculator.cpp:75
static void CalDCG(const std::vector< data_size_t > &ks, const label_t *label, const double *score, data_size_t num_data, std::vector< double > *out)
Calculate the DCG score at multiple positions.
Definition dcg_calculator.cpp:127
The interface of metric. Metric is used to calculate metric result.
Definition metric.h:20
Definition rank_metric.hpp:15
std::vector< double > Eval(const double *score, const ObjectiveFunction *) const override
Calculating and printing metric result.
Definition rank_metric.hpp:87
void Init(const Metadata &metadata, data_size_t num_data) override
Initialize.
Definition rank_metric.hpp:35
The interface of Objective Function.
Definition objective_function.h:13
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10
float label_t
Type of metadata, including weights and labels.
Definition meta.h:33
int32_t data_size_t
Type of data size; a signed type is preferred.
Definition meta.h:14