Medial Code Documentation
Loading...
Searching...
No Matches
survival_util.h
Go to the documentation of this file.
1
8#ifndef XGBOOST_COMMON_SURVIVAL_UTIL_H_
9#define XGBOOST_COMMON_SURVIVAL_UTIL_H_
10
11/*
12 * For the derivation of the loss, gradient, and hessian for the Accelerated Failure Time model,
13 * refer to the paper "Survival regression with accelerated failure time model in XGBoost"
14 * at https://arxiv.org/abs/2006.04920.
15 */
16
17#include <xgboost/parameter.h>
18#include <memory>
19#include <algorithm>
20#include <limits>
22
24
25namespace xgboost {
26namespace common {
27
#if !defined(__CUDACC__)
// Host-only build: pull the std math overloads into this namespace so that the
// unqualified fmax()/log() calls made by the AFT code resolve to them.
// NOTE(review): under nvcc these declarations are skipped; presumably the CUDA
// math library supplies the unqualified names there — confirm.
using std::fmax;
using std::log;
#endif  // !defined(__CUDACC__)
34
/*!
 * \brief Kind of censoring of a survival label given as the interval [y_lower, y_upper].
 *
 * AFTLoss picks the kind from the label bounds: equal bounds are uncensored,
 * a non-positive lower bound is left-censored, an infinite upper bound is
 * right-censored, and anything else is interval-censored.
 */
enum class CensoringType : uint8_t {
  kUncensored = 0,        // exact observation: y_lower == y_upper
  kRightCensored = 1,     // y_upper is infinite; only the lower bound is informative
  kLeftCensored = 2,      // y_lower <= 0; only the upper bound is informative
  kIntervalCensored = 3,  // both bounds are informative
};
38
39namespace aft {
40
41// Allowable range for gradient and hessian. Used for regularization
42constexpr double kMinGradient = -15.0;
43constexpr double kMaxGradient = 15.0;
44constexpr double kMinHessian = 1e-16; // Ensure that no data point gets zero hessian
45constexpr double kMaxHessian = 15.0;
46
47constexpr double kEps = 1e-12; // A denominator in a fraction should not be too small
48
49// Clip (limit) x to fit range [x_min, x_max].
50// If x < x_min, return x_min; if x > x_max, return x_max; if x_min <= x <= x_max, return x.
51// This function assumes x_min < x_max; behavior is undefined if this assumption does not hold.
53inline double Clip(double x, double x_min, double x_max) {
54 if (x < x_min) {
55 return x_min;
56 }
57 if (x > x_max) {
58 return x_max;
59 }
60 return x;
61}
62
63template<typename Distribution>
64XGBOOST_DEVICE inline double
65GetLimitGradAtInfPred(CensoringType censor_type, bool sign, double sigma);
66
67template<typename Distribution>
68XGBOOST_DEVICE inline double
69GetLimitHessAtInfPred(CensoringType censor_type, bool sign, double sigma);
70
71} // namespace aft
72
74struct AFTParam : public XGBoostParameter<AFTParam> {
79 DMLC_DECLARE_PARAMETER(AFTParam) {
80 DMLC_DECLARE_FIELD(aft_loss_distribution)
81 .set_default(ProbabilityDistributionType::kNormal)
82 .add_enum("normal", ProbabilityDistributionType::kNormal)
83 .add_enum("logistic", ProbabilityDistributionType::kLogistic)
84 .add_enum("extreme", ProbabilityDistributionType::kExtreme)
85 .describe("Choice of distribution for the noise term in "
86 "Accelerated Failure Time model");
87 DMLC_DECLARE_FIELD(aft_loss_distribution_scale)
88 .set_default(1.0f)
89 .describe("Scaling factor used to scale the distribution in "
90 "Accelerated Failure Time model");
91 }
92};
93
95template<typename Distribution>
96struct AFTLoss {
97 XGBOOST_DEVICE inline static
98 double Loss(double y_lower, double y_upper, double y_pred, double sigma) {
99 const double log_y_lower = log(y_lower);
100 const double log_y_upper = log(y_upper);
101
102 double cost;
103
104 if (y_lower == y_upper) { // uncensored
105 const double z = (log_y_lower - y_pred) / sigma;
106 const double pdf = Distribution::PDF(z);
107 // Regularize the denominator with eps, to avoid INF or NAN
108 cost = -log(fmax(pdf / (sigma * y_lower), aft::kEps));
109 } else { // censored; now check what type of censorship we have
110 double z_u, z_l, cdf_u, cdf_l;
111 if (isinf(y_upper)) { // right-censored
112 cdf_u = 1;
113 } else { // left-censored or interval-censored
114 z_u = (log_y_upper - y_pred) / sigma;
115 cdf_u = Distribution::CDF(z_u);
116 }
117 if (y_lower <= 0.0) { // left-censored
118 cdf_l = 0;
119 } else { // right-censored or interval-censored
120 z_l = (log_y_lower - y_pred) / sigma;
121 cdf_l = Distribution::CDF(z_l);
122 }
123 // Regularize the denominator with eps, to avoid INF or NAN
124 cost = -log(fmax(cdf_u - cdf_l, aft::kEps));
125 }
126
127 return cost;
128 }
129
130 XGBOOST_DEVICE inline static
131 double Gradient(double y_lower, double y_upper, double y_pred, double sigma) {
132 const double log_y_lower = log(y_lower);
133 const double log_y_upper = log(y_upper);
134 double numerator, denominator, gradient; // numerator and denominator of gradient
135 CensoringType censor_type;
136 bool z_sign; // sign of z-score
137
138 if (y_lower == y_upper) { // uncensored
139 const double z = (log_y_lower - y_pred) / sigma;
140 const double pdf = Distribution::PDF(z);
141 const double grad_pdf = Distribution::GradPDF(z);
142 censor_type = CensoringType::kUncensored;
143 numerator = grad_pdf;
144 denominator = sigma * pdf;
145 z_sign = (z > 0);
146 } else { // censored; now check what type of censorship we have
147 double z_u = 0.0, z_l = 0.0, pdf_u, pdf_l, cdf_u, cdf_l;
148 censor_type = CensoringType::kIntervalCensored;
149 if (isinf(y_upper)) { // right-censored
150 pdf_u = 0;
151 cdf_u = 1;
152 censor_type = CensoringType::kRightCensored;
153 } else { // interval-censored or left-censored
154 z_u = (log_y_upper - y_pred) / sigma;
155 pdf_u = Distribution::PDF(z_u);
156 cdf_u = Distribution::CDF(z_u);
157 }
158 if (y_lower <= 0.0) { // left-censored
159 pdf_l = 0;
160 cdf_l = 0;
161 censor_type = CensoringType::kLeftCensored;
162 } else { // interval-censored or right-censored
163 z_l = (log_y_lower - y_pred) / sigma;
164 pdf_l = Distribution::PDF(z_l);
165 cdf_l = Distribution::CDF(z_l);
166 }
167 z_sign = (z_u > 0 || z_l > 0);
168 numerator = pdf_u - pdf_l;
169 denominator = sigma * (cdf_u - cdf_l);
170 }
171 gradient = numerator / denominator;
172 if (denominator < aft::kEps && (isnan(gradient) || isinf(gradient))) {
173 gradient = aft::GetLimitGradAtInfPred<Distribution>(censor_type, z_sign, sigma);
174 }
175
176 return aft::Clip(gradient, aft::kMinGradient, aft::kMaxGradient);
177 }
178
179 XGBOOST_DEVICE inline static
180 double Hessian(double y_lower, double y_upper, double y_pred, double sigma) {
181 const double log_y_lower = log(y_lower);
182 const double log_y_upper = log(y_upper);
183 double numerator, denominator, hessian; // numerator and denominator of hessian
184 CensoringType censor_type;
185 bool z_sign; // sign of z-score
186
187 if (y_lower == y_upper) { // uncensored
188 const double z = (log_y_lower - y_pred) / sigma;
189 const double pdf = Distribution::PDF(z);
190 const double grad_pdf = Distribution::GradPDF(z);
191 const double hess_pdf = Distribution::HessPDF(z);
192 censor_type = CensoringType::kUncensored;
193 numerator = -(pdf * hess_pdf - grad_pdf * grad_pdf);
194 denominator = sigma * sigma * pdf * pdf;
195 z_sign = (z > 0);
196 } else { // censored; now check what type of censorship we have
197 double z_u = 0.0, z_l = 0.0, grad_pdf_u, grad_pdf_l, pdf_u, pdf_l, cdf_u, cdf_l;
198 censor_type = CensoringType::kIntervalCensored;
199 if (isinf(y_upper)) { // right-censored
200 pdf_u = 0;
201 cdf_u = 1;
202 grad_pdf_u = 0;
203 censor_type = CensoringType::kRightCensored;
204 } else { // interval-censored or left-censored
205 z_u = (log_y_upper - y_pred) / sigma;
206 pdf_u = Distribution::PDF(z_u);
207 cdf_u = Distribution::CDF(z_u);
208 grad_pdf_u = Distribution::GradPDF(z_u);
209 }
210 if (y_lower <= 0.0) { // left-censored
211 pdf_l = 0;
212 cdf_l = 0;
213 grad_pdf_l = 0;
214 censor_type = CensoringType::kLeftCensored;
215 } else { // interval-censored or right-censored
216 z_l = (log_y_lower - y_pred) / sigma;
217 pdf_l = Distribution::PDF(z_l);
218 cdf_l = Distribution::CDF(z_l);
219 grad_pdf_l = Distribution::GradPDF(z_l);
220 }
221 const double cdf_diff = cdf_u - cdf_l;
222 const double pdf_diff = pdf_u - pdf_l;
223 const double grad_diff = grad_pdf_u - grad_pdf_l;
224 const double sqrt_denominator = sigma * cdf_diff;
225 z_sign = (z_u > 0 || z_l > 0);
226 numerator = -(cdf_diff * grad_diff - pdf_diff * pdf_diff);
227 denominator = sqrt_denominator * sqrt_denominator;
228 }
229 hessian = numerator / denominator;
230 if (denominator < aft::kEps && (isnan(hessian) || isinf(hessian))) {
231 hessian = aft::GetLimitHessAtInfPred<Distribution>(censor_type, z_sign, sigma);
232 }
233
234 return aft::Clip(hessian, aft::kMinHessian, aft::kMaxHessian);
235 }
236};
237
238namespace aft {
239
240template <>
241XGBOOST_DEVICE inline double
242GetLimitGradAtInfPred<NormalDistribution>(CensoringType censor_type, bool sign, double sigma) {
243 // Remove unused parameter compiler warning.
244 (void) sigma;
245
246 switch (censor_type) {
247 case CensoringType::kUncensored:
248 return sign ? kMinGradient : kMaxGradient;
249 case CensoringType::kRightCensored:
250 return sign ? kMinGradient : 0.0;
251 case CensoringType::kLeftCensored:
252 return sign ? 0.0 : kMaxGradient;
253 case CensoringType::kIntervalCensored:
254 return sign ? kMinGradient : kMaxGradient;
255 }
256 return std::numeric_limits<double>::quiet_NaN();
257}
258
259template <>
260XGBOOST_DEVICE inline double
261GetLimitHessAtInfPred<NormalDistribution>(CensoringType censor_type, bool sign, double sigma) {
262 switch (censor_type) {
263 case CensoringType::kUncensored:
264 return 1.0 / (sigma * sigma);
265 case CensoringType::kRightCensored:
266 return sign ? (1.0 / (sigma * sigma)) : kMinHessian;
267 case CensoringType::kLeftCensored:
268 return sign ? kMinHessian : (1.0 / (sigma * sigma));
269 case CensoringType::kIntervalCensored:
270 return 1.0 / (sigma * sigma);
271 }
272 return std::numeric_limits<double>::quiet_NaN();
273}
274
275template <>
276XGBOOST_DEVICE inline double
277GetLimitGradAtInfPred<LogisticDistribution>(CensoringType censor_type, bool sign, double sigma) {
278 switch (censor_type) {
279 case CensoringType::kUncensored:
280 return sign ? (-1.0 / sigma) : (1.0 / sigma);
281 case CensoringType::kRightCensored:
282 return sign ? (-1.0 / sigma) : 0.0;
283 case CensoringType::kLeftCensored:
284 return sign ? 0.0 : (1.0 / sigma);
285 case CensoringType::kIntervalCensored:
286 return sign ? (-1.0 / sigma) : (1.0 / sigma);
287 }
288 return std::numeric_limits<double>::quiet_NaN();
289}
290
291template <>
292XGBOOST_DEVICE inline double
293GetLimitHessAtInfPred<LogisticDistribution>(CensoringType censor_type, bool sign, double sigma) {
294 // Remove unused parameter compiler warning.
295 (void) sign;
296 (void) sigma;
297
298 switch (censor_type) {
299 case CensoringType::kUncensored:
300 case CensoringType::kRightCensored:
301 case CensoringType::kLeftCensored:
302 case CensoringType::kIntervalCensored:
303 return kMinHessian;
304 }
305 return std::numeric_limits<double>::quiet_NaN();
306}
307
308template <>
309XGBOOST_DEVICE inline double
310GetLimitGradAtInfPred<ExtremeDistribution>(CensoringType censor_type, bool sign, double sigma) {
311 switch (censor_type) {
312 case CensoringType::kUncensored:
313 return sign ? kMinGradient : (1.0 / sigma);
314 case CensoringType::kRightCensored:
315 return sign ? kMinGradient : 0.0;
316 case CensoringType::kLeftCensored:
317 return sign ? 0.0 : (1.0 / sigma);
318 case CensoringType::kIntervalCensored:
319 return sign ? kMinGradient : (1.0 / sigma);
320 }
321 return std::numeric_limits<double>::quiet_NaN();
322}
323
324template <>
325XGBOOST_DEVICE inline double
326GetLimitHessAtInfPred<ExtremeDistribution>(CensoringType censor_type, bool sign, double sigma) {
327 // Remove unused parameter compiler warning.
328 (void) sigma;
329
330 switch (censor_type) {
331 case CensoringType::kUncensored:
332 case CensoringType::kRightCensored:
333 return sign ? kMaxHessian : kMinHessian;
334 case CensoringType::kLeftCensored:
335 return kMinHessian;
336 case CensoringType::kIntervalCensored:
337 return sign ? kMaxHessian : kMinHessian;
338 }
339 return std::numeric_limits<double>::quiet_NaN();
340}
341
342} // namespace aft
343
344} // namespace common
345} // namespace xgboost
346
347#endif // XGBOOST_COMMON_SURVIVAL_UTIL_H_
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition base.h:64
macro for using C++11 enum class as DMLC parameter
#define DECLARE_FIELD_ENUM_CLASS(EnumClass)
Specialization of FieldEntry for enum class (backed by int)
Definition parameter.h:50
ProbabilityDistributionType
Enum encoding possible choices of probability distribution.
Definition probability_distribution.h:31
namespace of xgboost
Definition base.h:90
Implementation of a few useful probability distributions.
Definition parameter.h:84
The AFT loss function.
Definition survival_util.h:96
Parameter structure for AFT loss and metric.
Definition survival_util.h:74
ProbabilityDistributionType aft_loss_distribution
Choice of probability distribution for the noise term in AFT.
Definition survival_util.h:76
float aft_loss_distribution_scale
Scaling factor to be applied to the distribution.
Definition survival_util.h:78