5#include <gtest/gtest.h>
9#include "../../../src/common/survival_util.h"
10#include "../helpers.h"
15inline void CheckDeterministicMetricElementWise(StringView name, int32_t device) {
17 std::unique_ptr<Metric> metric{
Metric::Create(name.c_str(), &ctx)};
20 HostDeviceVector<float> predts;
21 auto p_fmat = EmptyDMatrix();
22 MetaInfo& info = p_fmat->Info();
23 auto &h_predts = predts.HostVector();
26 SimpleRealUniformDistribution<float> dist{0.0f, 1.0f};
28 size_t n_samples = 2048;
29 h_predts.resize(n_samples);
31 for (
size_t i = 0; i < n_samples; ++i) {
32 h_predts[i] = dist(&lcg);
35 auto &h_upper = info.labels_upper_bound_.HostVector();
36 auto &h_lower = info.labels_lower_bound_.HostVector();
37 h_lower.resize(n_samples);
38 h_upper.resize(n_samples);
39 for (
size_t i = 0; i < n_samples; ++i) {
44 auto result = metric->Evaluate(predts, p_fmat);
45 for (
size_t i = 0; i < 8; ++i) {
46 ASSERT_EQ(metric->Evaluate(predts, p_fmat), result);
57 auto p_fmat = EmptyDMatrix();
61 = { 100.0f, 0.0f, 60.0f, 16.0f };
63 = { 100.0f, 20.0f, std::numeric_limits<bst_float>::infinity(), 200.0f };
64 info.
weights_.HostVector() = std::vector<bst_float>();
69 std::string dist_type;
72 for (
const auto& test_case : std::vector<TestCase>{ {
"normal", 2.1508f}, {
"logistic", 2.1804f},
73 {
"extreme", 2.0706f} }) {
74 std::unique_ptr<Metric> metric(
Metric::Create(
"aft-nloglik", &ctx));
75 metric->Configure({ {
"aft_loss_distribution", test_case.dist_type},
76 {
"aft_loss_distribution_scale",
"1.0"} });
77 EXPECT_NEAR(metric->Evaluate(preds, p_fmat), test_case.reference_value, 1e-4);
81inline void VerifyIntervalRegressionAccuracy(DataSplitMode data_split_mode = DataSplitMode::kRow) {
84 auto p_fmat = EmptyDMatrix();
89 info.
weights_.HostVector() = std::vector<bst_float>();
91 HostDeviceVector<bst_float> preds(4, std::log(60.0f));
93 std::unique_ptr<Metric> metric(
Metric::Create(
"interval-regression-accuracy", &ctx));
94 EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.75f);
96 EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
98 EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
100 EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
102 EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.25f);
104 CheckDeterministicMetricElementWise(StringView{
"interval-regression-accuracy"}, GPUIDX);
Definition host_device_vector.h:87
static Metric * Create(const std::string &name, Context const *ctx)
Create a metric according to its name.
Definition metric.cc:46
virtual void Configure(const std::vector< std::pair< std::string, std::string > > &)
Configure the Metric with the specified parameters.
Definition metric.h:38
void VerifyAFTNegLogLik(DataSplitMode data_split_mode=DataSplitMode::kRow)
Definition test_survival_metric.h:50
namespace of xgboost
Definition base.h:90
Context MakeCUDACtx(std::int32_t device)
Make a context that uses CUDA if device >= 0.
Definition helpers.h:410
float bst_float
Float type used for storing statistics.
Definition base.h:97
Copyright 2014-2023 by XGBoost Contributors.