Jiarui Xu | May 6, 2023
This blog post provides an introduction to proportional hazards and partial likelihood in survival analysis. Survival analysis is a statistical method used to analyze the time it takes for an event of interest (such as death or failure of a product) to occur. The blog post also covers Breslow's method for handling tied event times and includes code in JAX for computing the negative log likelihood, gradient, and hessian.
My favorite resources for this topic are included in the References section. The full code with JAX is available at https://github.com/jxucoder/survival_jax. JAX is an open-source Python library for high-performance numerical computing, particularly for machine learning and scientific computing. It was developed by Google Brain and released in 2018.
Credit to ChatGPT for assisting with the writing and implementation. Next up: parametric methods for survival analysis with JAX.
Proportional Hazards
Proportional hazards is a concept used in survival analysis, which is a statistical method used to analyze the time it takes for an event of interest (such as death or failure of a product) to occur.
In survival analysis, proportional hazards means that the hazard rates (the instantaneous rate at which the event occurs at a particular time, given that it has not already occurred) of the groups being compared are proportional over time. In other words, the ratio of the hazard rates for any two groups is constant over time.
My favorite explanation of proportional hazards is by Prof. Mike Marin.
Imagine that a researcher is studying the effect of a particular treatment on the survival of patients with a certain disease. They divide the patients into two groups: one group receives the treatment, while the other group receives a placebo. The researcher wants to know whether there is a difference in survival between the two groups.
If the proportional hazards assumption holds, it means that the hazard ratio between the two groups is constant over time. For example, if the hazard ratio is 2, it means that the patients in the treatment group have twice the risk of dying at any given time compared to the placebo group. This hazard ratio remains constant over time, which allows the researcher to use statistical methods such as Cox regression to analyze the data.
Example Dataset
Let’s create a super simple example dataset to illustrate proportional hazards and partial likelihood. We have 5 samples, each with a covariate vector $x_i$, an observed time, and an event occurrence indicator (1 for a failure, 0 for censoring).
Covariates | Times | Events |
$x_1$ | 2 | 0 |
$x_2$ | 2 | 1 |
$x_3$ | 4 | 0 |
$x_4$ | 5 | 1 |
$x_5$ | 7 | 0 |
Cox proportional hazards (PH) model
In the Cox proportional hazards (PH) model, the hazard function is modeled as a product of two parts: a baseline hazard function and a set of covariate effects. The baseline hazard function represents the hazard for an individual with all covariate values set to zero, while the covariate effects represent the change in hazard associated with changes in covariate values.
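Mathematically, the hazard function in the Cox PH model is defined as:
$$h(t \mid x) = h_0(t) \exp(\beta^\top x)$$
where $h_0(t)$ is the baseline hazard, $x$ is the vector of covariate values, and $\beta$ is the vector of regression coefficients.
Assuming proportional hazards, we use $h(t \mid x_i) = h_0(t) \exp(\beta^\top x_i)$ to calculate the hazard rate of each sample below:
Covariates | Times | Events | Hazard Rates |
$x_1$ | 2 | 0 | $h_0(t) \exp(\beta^\top x_1)$ |
$x_2$ | 2 | 1 | $h_0(t) \exp(\beta^\top x_2)$ |
$x_3$ | 4 | 0 | $h_0(t) \exp(\beta^\top x_3)$ |
$x_4$ | 5 | 1 | $h_0(t) \exp(\beta^\top x_4)$ |
$x_5$ | 7 | 0 | $h_0(t) \exp(\beta^\top x_5)$ |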
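Recall that “if the proportional hazards assumption holds, it means that the hazard ratio between the two groups is constant over time”. We can verify this with any two samples, for example $x_1$ and $x_2$: their hazard rates are $h_0(t) \exp(\beta^\top x_1)$ and $h_0(t) \exp(\beta^\top x_2)$, respectively. The ratio at time $t$ is constant:
$$\frac{h(t \mid x_1)}{h(t \mid x_2)} = \frac{h_0(t) \exp(\beta^\top x_1)}{h_0(t) \exp(\beta^\top x_2)} = \exp\left(\beta^\top (x_1 - x_2)\right)$$
In fact, this ratio is the same for every value of $t$, because the baseline hazard $h_0(t)$ cancels out.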
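The cancellation can also be checked numerically. The sketch below assumes a made-up baseline hazard `h0(t) = 0.1 * t`, a made-up coefficient vector `beta`, and two made-up covariate vectors `x_a` and `x_b`; none of these values come from the example dataset.

```python
import math

def hazard(t, x, beta, h0):
    """Cox PH hazard: baseline hazard times exp(beta . x)."""
    linear = sum(b * xi for b, xi in zip(beta, x))
    return h0(t) * math.exp(linear)

def h0(t):
    # Assumed baseline hazard; any positive function of t would do.
    return 0.1 * t

# Hypothetical values chosen only for illustration.
beta = [0.5, -0.2]
x_a = [1.0, 2.0]
x_b = [0.0, 1.0]

# The hazard ratio is the same at every time point because h0(t) cancels.
ratios = [hazard(t, x_a, beta, h0) / hazard(t, x_b, beta, h0) for t in (1.0, 2.0, 5.0)]
expected = math.exp(sum(b * (a - c) for b, a, c in zip(beta, x_a, x_b)))
print(ratios, expected)
```

Swapping in a different `h0` changes the individual hazards but leaves every entry of `ratios` equal to `expected`.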
Partial Likelihood
For now, let’s assume there is no tie in failure times. A tied event time occurs when two or more individuals experience the event at exactly the same time, and can lead to ambiguity in the ordering of the individuals and the estimation of the regression coefficients.
The partial likelihood function is a fundamental concept in the Cox proportional hazards (PH) model that is used to estimate the regression coefficients of the covariates while accounting for censoring in survival data.
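The partial likelihood function is a product with one factor per uncensored event: the hazard of the individual who failed, divided by the sum of the hazards of all individuals at risk at that event time. Because every hazard shares the same baseline hazard, the baseline cancels and only the $\exp(\beta^\top x)$ terms remain. Mathematically, the partial likelihood function for the Cox PH model is:
$$L(\beta) = \prod_{i=1}^{n} \left[\frac{\exp(\beta^\top x_i)}{\sum_{j \in R_i} \exp(\beta^\top x_j)}\right]^{\delta_i}$$
In this equation, $L(\beta)$ is the partial likelihood function for the Cox PH model, $\beta$ is the vector of regression coefficients for the covariates, $x_i$ is the vector of covariate values for the $i$th individual, $n$ is the total number of individuals, $R_i$ is the set of individuals at risk at the $i$th individual's observed time, and $\delta_i$ is the indicator variable for whether the $i$th individual's event was observed ($\delta_i = 1$) or censored ($\delta_i = 0$).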
The partial likelihood function is maximized with respect to to obtain the estimate of the regression coefficients. The estimates are obtained by solving the partial likelihood score equations or by using an iterative method such as the Newton-Raphson algorithm.
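As a minimal sketch of the Newton-Raphson route, the snippet below fits a single coefficient on the five-sample example dataset. The scalar covariate values in `x` are hypothetical (chosen so that a finite maximizer exists), and the sketch assumes no tied failure times.

```python
import math

times = [2, 2, 4, 5, 7]
events = [0, 1, 0, 1, 0]
x = [1.0, 3.0, 2.0, 0.5, 1.5]   # hypothetical scalar covariates

def score_and_information(beta):
    """Derivative of the log partial likelihood and minus its second derivative."""
    score, information = 0.0, 0.0
    for i, (t_i, failed) in enumerate(zip(times, events)):
        if not failed:
            continue  # censored samples only appear inside risk sets
        risk = [j for j, t_j in enumerate(times) if t_j >= t_i]
        weights = [math.exp(beta * x[j]) for j in risk]
        total = sum(weights)
        mean = sum(w * x[j] for w, j in zip(weights, risk)) / total
        mean_sq = sum(w * x[j] ** 2 for w, j in zip(weights, risk)) / total
        score += x[i] - mean                 # observed minus expected covariate
        information += mean_sq - mean ** 2   # weighted variance within the risk set
    return score, information

beta = 0.0
for _ in range(25):
    score, information = score_and_information(beta)
    step = score / information
    beta += step                             # Newton-Raphson update
    if abs(step) < 1e-10:
        break
print(beta)
```

With more than one covariate, the score becomes a vector and the information a matrix, and the update solves a linear system instead of dividing.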
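In this dataset, there are two failure times: 2 and 5. At $t = 2$, samples $\{1, 2, 3, 4, 5\}$ are at risk (meaning it is still possible for them to fail); sample 1, which is censored at time 2, is counted as at risk at $t = 2$ by the usual convention. At $t = 5$, samples $\{4, 5\}$ are at risk.
failure time | risk set | failed sample | probability of happening |
2 | $\{1, 2, 3, 4, 5\}$ | 2 | $\frac{\exp(\beta^\top x_2)}{\sum_{j=1}^{5} \exp(\beta^\top x_j)}$ |
5 | $\{4, 5\}$ | 4 | $\frac{\exp(\beta^\top x_4)}{\exp(\beta^\top x_4) + \exp(\beta^\top x_5)}$ |
The likelihood of the events above happening is calculated by multiplying the failure probabilities together:
$$L(\beta) = \frac{\exp(\beta^\top x_2)}{\sum_{j=1}^{5} \exp(\beta^\top x_j)} \cdot \frac{\exp(\beta^\top x_4)}{\exp(\beta^\top x_4) + \exp(\beta^\top x_5)}$$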
This is exactly what the partial likelihood formula $L(\beta)$ expresses.
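To make the two failure probabilities concrete, here is a small sketch that computes them and their product for the toy dataset. The covariate values in `x` and the coefficient `beta` are hypothetical (one scalar covariate per sample, picked for illustration); only `times` and `events` come from the example dataset.

```python
import math

times = [2, 2, 4, 5, 7]
events = [0, 1, 0, 1, 0]
x = [1.0, 3.0, 2.0, 0.5, 1.5]   # hypothetical scalar covariates
beta = 0.3                       # hypothetical coefficient

def partial_likelihood(beta, x, times, events):
    """Cox partial likelihood, assuming no tied failure times."""
    likelihood = 1.0
    for i, (t_i, failed) in enumerate(zip(times, events)):
        if not failed:
            continue  # censored samples contribute only through risk sets
        # Risk set: everyone whose observed time is at least t_i.
        risk_set = [j for j, t_j in enumerate(times) if t_j >= t_i]
        denom = sum(math.exp(beta * x[j]) for j in risk_set)
        likelihood *= math.exp(beta * x[i]) / denom
    return likelihood

# The same quantity written out by hand
# (0-based indices 1 and 3 are samples 2 and 4).
w = [math.exp(beta * xi) for xi in x]
by_hand = (w[1] / sum(w)) * (w[3] / (w[3] + w[4]))
print(partial_likelihood(beta, x, times, events), by_hand)
```

At `beta = 0` every sample has the same hazard, so the likelihood reduces to (1/5) × (1/2) = 0.1.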
Knowing the likelihood function, we can now estimate $\beta$ by maximizing the likelihood. As in most ML problems, instead of maximizing the likelihood function $L(\beta)$ directly, we minimize the negative log likelihood $-\log L(\beta)$.
Before solving for $\beta$, let’s revisit the “no tied events” assumption. What if there are tied events?
What if there are tied events? Breslow's method.
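In reality, tied events are common. Breslow's method is an approach for handling tied event times in the Cox PH model. In Breslow's method, all failures tied at the same time are treated as if each of them failed from the same full risk set: every tied failure contributes its own numerator, while the risk-set denominator is reused once per tied failure, that is, raised to the power of the number of failures at that time. This is in contrast to the Efron approximation, which subtracts a growing fraction of the tied samples' contribution from the denominator for each successive tied failure.
$$L(\beta) = \prod_{t_i \in T} \frac{\prod_{k \in D_i} \exp(\beta^\top x_k)}{\left[\sum_{j \in R_i} \exp(\beta^\top x_j)\right]^{d_i}}$$
- $T$: the set of unique failure times
- $t_i$: a unique failure time belonging to $T$
- $R_i$: the risk set at $t_i$
- $j$: one of the samples that are at risk at $t_i$
- $k$: one of the samples that failed at $t_i$ (the set of these samples is written $D_i$)
- $d_i$: number of failures at $t_i$
Taking the log, we get the log partial likelihood:
$$\log L(\beta) = \sum_{t_i \in T} \left[\sum_{k \in D_i} \beta^\top x_k - d_i \log \sum_{j \in R_i} \exp(\beta^\top x_j)\right]$$
We typically add a regularization term too. Let’s formally define the minimization problem with regularization:
$$\min_{\beta} \; -\log L(\beta) + \lambda \lVert \beta \rVert_1$$
- $\lambda$: regularization strength, a hyperparameter in regularization methods, such as ridge regression and lasso regression, that controls the amount of regularization applied to the regression coefficients. The code below uses the L1 (lasso) penalty.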
Breslow's method has the advantage of being computationally simpler than the Efron approximation, but it may be less accurate when many event times are tied. In practice, the choice of method for handling tied event times can depend on the specific research question, the sample size, and the number of tied event times in the data.
It is worth noting that both Breslow's method and the Efron approximation assume that tied event times occur at random, and not due to any underlying structure or relationship among the individuals in the study. If tied event times are not random and are related to unobserved covariates, then alternative approaches may be necessary to address this issue.
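The Breslow log partial likelihood can be written as a short loop over the unique failure times. The following is a plain-Python sketch under stated assumptions: a single scalar covariate, a hypothetical dataset containing one tied failure time, and an L1 penalty `lmbd * abs(beta)`. It is meant to mirror the formula, not to be fast.

```python
import math

def breslow_neg_log_partial_likelihood(beta, x, times, events, lmbd=0.0):
    """Negative Breslow log partial likelihood (one scalar covariate) plus an L1 penalty."""
    unique_failure_times = sorted({t for t, e in zip(times, events) if e == 1})
    log_lik = 0.0
    for t_i in unique_failure_times:
        failed = [k for k, (t, e) in enumerate(zip(times, events)) if t == t_i and e == 1]
        at_risk = [j for j, t in enumerate(times) if t >= t_i]
        d_i = len(failed)
        log_denom = math.log(sum(math.exp(beta * x[j]) for j in at_risk))
        # All d_i tied failures share the same risk-set denominator.
        log_lik += sum(beta * x[k] for k in failed) - d_i * log_denom
    return -log_lik + lmbd * abs(beta)

# Hypothetical data: samples 0 and 1 both fail at time 2 (a tie).
times = [2, 2, 4, 5, 7]
events = [1, 1, 0, 1, 0]
x = [1.0, 3.0, 2.0, 0.5, 1.5]
print(breslow_neg_log_partial_likelihood(0.0, x, times, events))
```

At `beta = 0` the value is 2·log(5) + log(2) = log(50): the tied time contributes the five-sample denominator twice, and the failure at time 5 contributes a two-sample denominator once.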
Show me the code, in JAX
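Let’s define a few key variables:
- $T$: sorted list of unique failure times.
- $d$: a list of failure counts, one for each unique failure time in $T$. $d_i$ equals the number of samples that failed at time $t_i$.
- $R$ (`riskset` in the code): a matrix of risk set indicators. $R_{ij}$ indicates whether sample $j$ belongs to the risk set at unique failure time $t_i$.
- $I$ (`indices` in the code): a matrix of failure indicators. $I_{ik}$ indicates whether sample $k$ failed at unique failure time $t_i$.
With these variables, and writing $X$ for the matrix whose rows are the covariate vectors, we can express the cost function in matrix multiplication form:
$$J(\beta) = -\mathbf{1}^\top \left[ I X \beta - d \odot \log\left(R \exp(X \beta)\right) \right] + \lambda \lVert \beta \rVert_1$$
Here $\exp$ and $\log$ are applied element-wise, $\odot$ is the element-wise product, $\mathbf{1}^\top$ sums over the unique failure times, and $\beta$ corresponds to `w` in the code.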
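import jax
import jax.numpy as jnp

@jax.jit
def negative_log_cox_partial_likelihood(w, X, indices, riskset, d, lmbd):
    # w: (p,) coefficients; X: (n, p) covariates
    # indices, riskset: (m, n) indicator matrices; d: (m,) failure counts
    wx = w @ X.T                                 # linear predictor per sample, shape (n,)
    exp_wx = jnp.exp(wx)
    reg_term = lmbd * jnp.linalg.norm(w, ord=1)  # L1 penalty
    return -jnp.sum(indices @ wx - d * jnp.log(riskset @ exp_wx)) + reg_term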
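The function above expects `d`, `riskset`, and `indices` to be precomputed. The post does not show how they are built, so the helper below (`build_breslow_inputs`, a hypothetical name) is one possible NumPy sketch. The resulting arrays can be wrapped with `jnp.asarray` before being passed to the JAX function.

```python
import numpy as np

def build_breslow_inputs(times, events):
    """Return unique failure times, failure counts, and the two indicator matrices."""
    times = np.asarray(times)
    events = np.asarray(events).astype(bool)
    failure_times = np.unique(times[events])   # sorted unique failure times
    # riskset[i, j] = 1 if sample j is still at risk at failure time i
    riskset = (times[None, :] >= failure_times[:, None]).astype(float)
    # indices[i, k] = 1 if sample k failed exactly at failure time i
    indices = ((times[None, :] == failure_times[:, None]) & events[None, :]).astype(float)
    d = indices.sum(axis=1)                    # number of failures per unique time
    return failure_times, d, riskset, indices

failure_times, d, riskset, indices = build_breslow_inputs(
    times=[2, 2, 4, 5, 7], events=[0, 1, 0, 1, 0]
)
print(failure_times)
print(riskset)
print(indices)
```

For the toy dataset this gives two rows per matrix, one for each failure time (2 and 5); the first row of `riskset` is all ones because every sample is still at risk at time 2.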
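Now, thanks to JAX, we get the gradient and Hessian for free:
import numpy as np
from jax import grad, hessian
from scipy.optimize import minimize
from functools import partial

# covariates, indices, riskset, d, and lmbd are assumed to be defined already
cost_func = partial(negative_log_cox_partial_likelihood, indices=indices, riskset=riskset, d=d, X=covariates,
                    lmbd=lmbd)
g = grad(cost_func)     # gradient with respect to the coefficients
h = hessian(cost_func)  # Hessian with respect to the coefficients
# One starting coefficient per covariate column (14 in the original dataset)
result = minimize(cost_func, np.zeros(covariates.shape[1]), method='Newton-CG', jac=g, hess=h)
betas = result.x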
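For reference, what `grad` and `hessian` compute can also be written in closed form. The expressions below ignore the regularization term and use notation introduced here for brevity: $T$ is the set of unique failure times, $R_i$ the risk set at $t_i$, $D_i$ the set of samples that failed at $t_i$, $d_i$ the number of failures at $t_i$, and $\bar{x}_i(\beta)$ the risk-set weighted mean of the covariates. Differentiating the negative Breslow log partial likelihood gives:

$$\bar{x}_i(\beta) = \frac{\sum_{j \in R_i} x_j \exp(\beta^\top x_j)}{\sum_{j \in R_i} \exp(\beta^\top x_j)}$$

$$\nabla_\beta \left[-\log L(\beta)\right] = -\sum_{t_i \in T} \left[\sum_{k \in D_i} x_k - d_i \, \bar{x}_i(\beta)\right]$$

$$\nabla^2_\beta \left[-\log L(\beta)\right] = \sum_{t_i \in T} d_i \left[\frac{\sum_{j \in R_i} x_j x_j^\top \exp(\beta^\top x_j)}{\sum_{j \in R_i} \exp(\beta^\top x_j)} - \bar{x}_i(\beta) \, \bar{x}_i(\beta)^\top\right]$$

Each Hessian term is a weighted covariance matrix, so the unregularized objective is convex. Away from zero, the L1 penalty adds $\lambda \, \mathrm{sign}(\beta)$ to the gradient and nothing to the Hessian.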
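We can check the correctness of the estimated betas against the lifelines package. Note that lifelines' `CoxPHFitter` handles ties with Efron's method, so small differences from the Breslow estimates are expected when the data contain ties.
However, estimating $\beta$ is not the end of it: we still need to estimate the baseline hazard $h_0(t)$ to calculate the hazard function $h(t \mid x)$. Once the hazard function is estimated, we can get the survival function $S(t)$ and the PDF $f(t)$. We will discuss them in the next post.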
References
Ko, Jessica. (2017). Solving the Cox proportional hazards model and its applications. Master’s thesis, EECS Department, University of California, Berkeley.
Kleinbaum, D. G., & Klein, M. (1996). Survival analysis: a self-learning text. Springer.
“Survival Analysis Part 9 | Cox Proportional Hazards Model.” www.youtube.com, youtu.be/aETMUW_TWV0. Accessed 8 May 2023.
Survival Analysis: Optimize the Partial Likelihood of the Cox Model. (2022). Retrieved 8 May 2023, from https://towardsdatascience.com/survival-analysis-optimize-the-partial-likelihood-of-the-cox-model-b56b8f112401
GraphPad Software, L. (2023). GraphPad Prism 9 Statistics Guide - The mathematics of the cumulative hazard function. Retrieved 8 May 2023, from https://www.graphpad.com/guides/prism/latest/statistics/stat_cox_math_cumulative_hazard.htm
Ravinutala, S. (2021). [Survival models] Cox proportional hazard model. Retrieved 8 May 2023, from https://sidravi1.github.io/blog/2021/10/11/cox-proportional-hazard-model