Medial Code Documentation
Loading...
Searching...
No Matches
test_survival_metric.h
1
4#pragma once
5#include <gtest/gtest.h>
6
7#include <cmath>
8
9#include "../../../src/common/survival_util.h"
10#include "../helpers.h"
11#include "xgboost/metric.h"
12
13namespace xgboost {
14namespace common {
15inline void CheckDeterministicMetricElementWise(StringView name, int32_t device) {
16 auto ctx = MakeCUDACtx(device);
17 std::unique_ptr<Metric> metric{Metric::Create(name.c_str(), &ctx)};
18 metric->Configure(Args{});
19
20 HostDeviceVector<float> predts;
21 auto p_fmat = EmptyDMatrix();
22 MetaInfo& info = p_fmat->Info();
23 auto &h_predts = predts.HostVector();
24
25 SimpleLCG lcg;
26 SimpleRealUniformDistribution<float> dist{0.0f, 1.0f};
27
28 size_t n_samples = 2048;
29 h_predts.resize(n_samples);
30
31 for (size_t i = 0; i < n_samples; ++i) {
32 h_predts[i] = dist(&lcg);
33 }
34
35 auto &h_upper = info.labels_upper_bound_.HostVector();
36 auto &h_lower = info.labels_lower_bound_.HostVector();
37 h_lower.resize(n_samples);
38 h_upper.resize(n_samples);
39 for (size_t i = 0; i < n_samples; ++i) {
40 h_lower[i] = 1;
41 h_upper[i] = 10;
42 }
43
44 auto result = metric->Evaluate(predts, p_fmat);
45 for (size_t i = 0; i < 8; ++i) {
46 ASSERT_EQ(metric->Evaluate(predts, p_fmat), result);
47 }
48}
49
50inline void VerifyAFTNegLogLik(DataSplitMode data_split_mode = DataSplitMode::kRow) {
51 auto ctx = MakeCUDACtx(GPUIDX);
52
57 auto p_fmat = EmptyDMatrix();
58 MetaInfo& info = p_fmat->Info();
59 info.num_row_ = 4;
60 info.labels_lower_bound_.HostVector()
61 = { 100.0f, 0.0f, 60.0f, 16.0f };
62 info.labels_upper_bound_.HostVector()
63 = { 100.0f, 20.0f, std::numeric_limits<bst_float>::infinity(), 200.0f };
64 info.weights_.HostVector() = std::vector<bst_float>();
65 info.data_split_mode = data_split_mode;
66 HostDeviceVector<bst_float> preds(4, std::log(64));
67
68 struct TestCase {
69 std::string dist_type;
70 bst_float reference_value;
71 };
72 for (const auto& test_case : std::vector<TestCase>{ {"normal", 2.1508f}, {"logistic", 2.1804f},
73 {"extreme", 2.0706f} }) {
74 std::unique_ptr<Metric> metric(Metric::Create("aft-nloglik", &ctx));
75 metric->Configure({ {"aft_loss_distribution", test_case.dist_type},
76 {"aft_loss_distribution_scale", "1.0"} });
77 EXPECT_NEAR(metric->Evaluate(preds, p_fmat), test_case.reference_value, 1e-4);
78 }
79}
80
81inline void VerifyIntervalRegressionAccuracy(DataSplitMode data_split_mode = DataSplitMode::kRow) {
82 auto ctx = MakeCUDACtx(GPUIDX);
83
84 auto p_fmat = EmptyDMatrix();
85 MetaInfo& info = p_fmat->Info();
86 info.num_row_ = 4;
87 info.labels_lower_bound_.HostVector() = { 20.0f, 0.0f, 60.0f, 16.0f };
88 info.labels_upper_bound_.HostVector() = { 80.0f, 20.0f, 80.0f, 200.0f };
89 info.weights_.HostVector() = std::vector<bst_float>();
90 info.data_split_mode = data_split_mode;
91 HostDeviceVector<bst_float> preds(4, std::log(60.0f));
92
93 std::unique_ptr<Metric> metric(Metric::Create("interval-regression-accuracy", &ctx));
94 EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.75f);
95 info.labels_lower_bound_.HostVector()[2] = 70.0f;
96 EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
97 info.labels_upper_bound_.HostVector()[2] = std::numeric_limits<bst_float>::infinity();
98 EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
99 info.labels_upper_bound_.HostVector()[3] = std::numeric_limits<bst_float>::infinity();
100 EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
101 info.labels_lower_bound_.HostVector()[0] = 70.0f;
102 EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.25f);
103
104 CheckDeterministicMetricElementWise(StringView{"interval-regression-accuracy"}, GPUIDX);
105}
106} // namespace common
107} // namespace xgboost
Definition host_device_vector.h:87
Meta information about the dataset; it always resides in memory.
Definition data.h:48
HostDeviceVector< bst_float > labels_upper_bound_
upper bound of the label, to be used for survival analysis (censored regression)
Definition data.h:83
HostDeviceVector< bst_float > weights_
weights of each instance, optional
Definition data.h:69
DataSplitMode data_split_mode
data split mode
Definition data.h:62
uint64_t num_row_
number of rows in the data
Definition data.h:54
HostDeviceVector< bst_float > labels_lower_bound_
lower bound of the label, to be used for survival analysis (censored regression)
Definition data.h:79
static Metric * Create(const std::string &name, Context const *ctx)
create a metric according to name.
Definition metric.cc:46
virtual void Configure(const std::vector< std::pair< std::string, std::string > > &)
Configure the Metric with the specified parameters.
Definition metric.h:38
void VerifyAFTNegLogLik(DataSplitMode data_split_mode=DataSplitMode::kRow)
Definition test_survival_metric.h:50
namespace of xgboost
Definition base.h:90
Context MakeCUDACtx(std::int32_t device)
Make a context that uses CUDA if device >= 0.
Definition helpers.h:410
float bst_float
float type, used for storing statistics
Definition base.h:97
Copyright 2014-2023 by XGBoost Contributors.