Revisit Cox Proportional Hazards with JAX

Jiarui Xu | @May 6, 2023

This blog post provides an introduction to proportional hazards and partial likelihood in survival analysis. Survival analysis is a statistical method used to analyze the time it takes for an event of interest (such as death or failure of a product) to occur. The post also covers Breslow's method for handling tied event times and includes JAX code for computing the negative log likelihood, its gradient, and its Hessian.

My favorite resources for this topic are included in the References section. The full code with JAX is available at https://github.com/jxucoder/survival_jax. JAX is an open-source Python library for high-performance numerical computing, particularly for machine learning and scientific computing. It was developed by Google Brain and released in 2018.

Credit to ChatGPT for assisting with the writing and implementation. Next up: parametric methods for survival analysis with JAX.

Proportional Hazards

Proportional hazards is a core concept in survival analysis.

It means that the hazard rate (the instantaneous rate at which the event occurs at a particular time, given that it has not already occurred) for the groups being compared is proportional over time. In other words, the ratio of the hazard rates for any two groups is constant over time.

My favorite explanation of proportional hazards is by Prof. Mike Marin.

Imagine that a researcher is studying the effect of a particular treatment on the survival of patients with a certain disease. They divide the patients into two groups: one group receives the treatment, while the other group receives a placebo. The researcher wants to know whether there is a difference in survival between the two groups.

If the proportional hazards assumption holds, the hazard ratio between the two groups is constant over time. For example, a hazard ratio of 2 means that, at any given time, patients in the treatment group have twice the hazard of dying of patients in the placebo group. Because this ratio does not change over time, the researcher can use statistical methods such as Cox regression to analyze the data.

Example Dataset

Let’s create a super simple example dataset to illustrate proportional hazards and partial likelihood. We have 5 samples, each with an event occurrence indicator (1 if failure, 0 if censored) and the corresponding time.

| Covariates | Times | Events |
|------------|-------|--------|
| $X_1$      | 2     | 0      |
| $X_2$      | 2     | 1      |
| $X_3$      | 4     | 0      |
| $X_4$      | 5     | 1      |
| $X_5$      | 7     | 0      |

Cox proportional hazards (PH) model

In the Cox proportional hazards (PH) model, the hazard function is modeled as a product of two parts: a baseline hazard function and a set of covariate effects. The baseline hazard function represents the hazard for an individual with all covariate values set to zero, while the covariate effects represent the change in hazard associated with changes in covariate values.

Mathematically, the hazard function in the Cox PH model is defined as:

$$h(t \mid X) = h_0(t) \cdot e^{\beta_1 x_1 + \beta_2 x_2 + \cdots + \beta_p x_p}$$

Assuming proportional hazards, we use $h(t \mid X_i) = h_0(t) \cdot e^{X_i'\beta}$ to write the hazard rate of each sample at its observed time:

| Covariates | Times | Events | Hazard Rates                   |
|------------|-------|--------|--------------------------------|
| $X_1$      | 2     | 0      | $h_0(2) \cdot e^{X_1'\beta}$   |
| $X_2$      | 2     | 1      | $h_0(2) \cdot e^{X_2'\beta}$   |
| $X_3$      | 4     | 0      | $h_0(4) \cdot e^{X_3'\beta}$   |
| $X_4$      | 5     | 1      | $h_0(5) \cdot e^{X_4'\beta}$   |
| $X_5$      | 7     | 0      | $h_0(7) \cdot e^{X_5'\beta}$   |

Recall that if the proportional hazards assumption holds, the hazard ratio between two groups is constant over time. We can verify this with $X_1$ and $X_2$, whose hazard rates at $t=2$ are $h_0(2) \cdot e^{X_1'\beta}$ and $h_0(2) \cdot e^{X_2'\beta}$ respectively. Their ratio at $t=2$ does not involve the baseline hazard:

$$\frac{h_0(2) \cdot e^{X_1'\beta}}{h_0(2) \cdot e^{X_2'\beta}} = \frac{e^{X_1'\beta}}{e^{X_2'\beta}} = e^{X_1'\beta - X_2'\beta}$$

In fact, the ratio takes the same value for any $t \neq 2$, because $h_0(t)$ always cancels out.
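The cancellation can also be checked numerically. The sketch below assumes a hypothetical Weibull-shaped baseline hazard and made-up values for $X_1'\beta$ and $X_2'\beta$ (neither comes from the dataset above), and evaluates the hazard ratio at several times.

```python
import math

def baseline_hazard(t, shape=1.5, scale=4.0):
    # Hypothetical Weibull baseline hazard h_0(t); any positive function of t would do.
    return (shape / scale) * (t / scale) ** (shape - 1)

def hazard(t, linear_predictor):
    # Cox PH hazard: h(t | X) = h_0(t) * exp(X'beta)
    return baseline_hazard(t) * math.exp(linear_predictor)

# Made-up linear predictors X_1'beta and X_2'beta, for illustration only.
xb_1, xb_2 = 0.3, -0.4

check_times = (1.0, 2.0, 5.0, 7.0)
ratios = [hazard(t, xb_1) / hazard(t, xb_2) for t in check_times]
expected = math.exp(xb_1 - xb_2)

# The ratio equals exp(X_1'beta - X_2'beta) at every time point.
for t, r in zip(check_times, ratios):
    print(f"t={t}: ratio={r:.6f}, exp(xb_1 - xb_2)={expected:.6f}")
```

The baseline hazard changes with $t$, but it appears in both the numerator and the denominator, so the printed ratio is the same at every time.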

Partial Likelihood

For now, let’s assume there are no ties in failure times. A tied event time occurs when two or more individuals experience the event at exactly the same time. Ties make the ordering of the individuals ambiguous, which complicates the estimation of the regression coefficients.

The partial likelihood function is a fundamental concept in the Cox proportional hazards (PH) model that is used to estimate the regression coefficients of the covariates while accounting for censoring in survival data.

The partial likelihood is a product with one factor per uncensored event: the hazard of the individual who failed, divided by the sum of the hazards of all individuals at risk at that event time. Because the baseline hazard cancels in each ratio, only the covariate terms remain. Mathematically, the partial likelihood function for the Cox PH model is:

$$L(\beta) = \prod_{i=1}^n \left[ \frac{e^{X_i'\beta}}{\sum_{j \in R(t_i)} e^{X_j'\beta}} \right]^{\delta_i}$$

In this equation, $L(\beta)$ is the partial likelihood function for the Cox PH model, $\beta$ is the vector of regression coefficients for the covariates, $X_i$ is the vector of covariate values for the $i$th individual, $n$ is the total number of individuals, $R(t_i)$ is the set of individuals at risk at time $t_i$ (the observed time of the $i$th individual), and $\delta_i$ indicates whether the $i$th individual's observation is censored ($\delta_i = 0$) or is an observed event ($\delta_i = 1$). Censored individuals contribute a factor of 1, so only observed events appear in the product.

The partial likelihood function is maximized with respect to $\beta$ to obtain the estimate of the regression coefficients. The estimates are obtained by solving the partial likelihood score equations, typically with an iterative method such as the Newton-Raphson algorithm.
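As a rough illustration of the Newton-Raphson route, here is a minimal sketch for a single covariate, with no ties and no regularization. The times and events are those of the example dataset; the covariate values `x` are hypothetical. A sample counts as at risk at time $t$ if its observed time is at least $t$. For one covariate, the score is the sum over events of $x_i$ minus the risk-weighted mean of $x$ in the risk set, and the information (minus the second derivative) is the sum of the corresponding risk-weighted variances.

```python
import math

# Times and events from the example dataset; the covariate values are hypothetical.
times = [2, 2, 4, 5, 7]
events = [0, 1, 0, 1, 0]
x = [0.5, 1.0, -0.3, 0.2, 1.5]

def score_and_information(beta):
    """Derivative of the log partial likelihood and minus its second derivative."""
    score, info = 0.0, 0.0
    for i, (t_i, delta_i) in enumerate(zip(times, events)):
        if delta_i == 0:
            continue  # censored samples only appear inside risk sets
        risk = [j for j, t_j in enumerate(times) if t_j >= t_i]
        weights = [math.exp(beta * x[j]) for j in risk]
        total = sum(weights)
        mean = sum(w * x[j] for w, j in zip(weights, risk)) / total
        mean_sq = sum(w * x[j] ** 2 for w, j in zip(weights, risk)) / total
        score += x[i] - mean            # observed minus expected covariate
        info += mean_sq - mean ** 2     # weighted variance within the risk set
    return score, info

beta = 0.0
for _ in range(50):
    score, info = score_and_information(beta)
    if abs(score) < 1e-10:
        break
    beta += score / info  # Newton-Raphson update

print(f"estimated beta: {beta:.4f}")
```

With more than one covariate the same update uses the gradient vector and the Hessian matrix, which is where automatic differentiation becomes convenient.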

In this dataset, there are two failure times: 2 and 5. At $t=2$, samples $1, 2, 3, 4, 5$ are at risk (meaning they could still fail). At $t=5$, samples $4, 5$ are at risk.

| Failure time | Risk set      | Failed sample | Probability of happening |
|--------------|---------------|---------------|--------------------------|
| 2            | 1, 2, 3, 4, 5 | 2             | $\frac{h_2(t=2)}{h_1(t=2)+h_2(t=2)+h_3(t=2)+h_4(t=2)+h_5(t=2)}$ |
| 5            | 4, 5          | 4             | $\frac{h_4(t=5)}{h_4(t=5)+h_5(t=5)}$ |
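The risk sets can be derived mechanically from the times and event indicators. The following is a small sketch using the example dataset, with samples numbered from 1 as in the text; the helper name `risk_set` is made up for this illustration.

```python
# Times and event indicators from the example dataset (samples are numbered 1..5).
times = [2, 2, 4, 5, 7]
events = [0, 1, 0, 1, 0]

def risk_set(t):
    # A sample is at risk at time t if its observed time is at least t.
    return [i + 1 for i, t_i in enumerate(times) if t_i >= t]

# Only times with an observed event (event indicator 1) are failure times.
failure_times = sorted({t for t, e in zip(times, events) if e == 1})
risk_sets = {t: risk_set(t) for t in failure_times}

for t, members in risk_sets.items():
    print(f"t={t}: risk set {members}")
```

Note that sample 1, which is censored at $t=2$, is still counted as at risk at $t=2$, matching the convention used in the text.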

The likelihood of the events above happening is calculated by multiplying the failure probabilities together:

$$
\begin{aligned}
L(\beta) &= \frac{h_2(t=2)}{h_1(t=2)+h_2(t=2)+h_3(t=2)+h_4(t=2)+h_5(t=2)} \times \frac{h_4(t=5)}{h_4(t=5)+h_5(t=5)} \\
&= \frac{\exp(X_2'\beta)}{\exp(X_1'\beta)+\exp(X_2'\beta)+\exp(X_3'\beta)+\exp(X_4'\beta)+\exp(X_5'\beta)} \times \frac{\exp(X_4'\beta)}{\exp(X_4'\beta)+\exp(X_5'\beta)}
\end{aligned}
$$

This is exactly what the formula for $L(\beta)$ expresses.

Knowing the likelihood function, we can now estimate $\beta$ by maximizing it. As in most ML problems, instead of maximizing $L(\beta)$ directly, we minimize the negative log likelihood $-\ell(\beta)$, where:

$$\ell(\beta) = \sum_{i=1}^n \delta_i \left[ X_i'\beta - \log \sum_{j \in R(t_i)} e^{X_j'\beta} \right]$$
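To make the two forms concrete, the sketch below evaluates both the product $L(\beta)$ and the negative log likelihood $-\ell(\beta)$ on the example dataset. The covariate values in `X` and the coefficients in `beta` are hypothetical and chosen only for illustration.

```python
import math

# Times and events come from the example dataset.
times = [2, 2, 4, 5, 7]
events = [0, 1, 0, 1, 0]
# Hypothetical two-feature covariates, for illustration only.
X = [[0.5, 1.0], [1.0, 0.0], [-0.3, 0.4], [0.2, -1.2], [1.5, 0.3]]

def linear_predictor(x_i, beta):
    # X_i' beta
    return sum(x * b for x, b in zip(x_i, beta))

def risk_sum(t_i, beta):
    # Sum of exp(X_j' beta) over the samples still at risk at time t_i.
    return sum(math.exp(linear_predictor(X[j], beta))
               for j, t_j in enumerate(times) if t_j >= t_i)

def partial_likelihood(beta):
    value = 1.0
    for i, (t_i, delta_i) in enumerate(zip(times, events)):
        if delta_i == 1:  # censored samples contribute a factor of 1
            value *= math.exp(linear_predictor(X[i], beta)) / risk_sum(t_i, beta)
    return value

def negative_log_partial_likelihood(beta):
    total = 0.0
    for i, (t_i, delta_i) in enumerate(zip(times, events)):
        if delta_i == 1:
            total += linear_predictor(X[i], beta) - math.log(risk_sum(t_i, beta))
    return -total

beta = [0.1, -0.2]  # hypothetical coefficients
print(partial_likelihood(beta), negative_log_partial_likelihood(beta))
```

With $\beta = 0$ every sample has the same relative risk, so the partial likelihood reduces to $\frac{1}{5} \times \frac{1}{2} = 0.1$, which is a handy sanity check.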

Before solving for $\beta$, let’s revisit the “no tied events” assumption. What if there are tied events?

What if there are tied events? Breslow's method.

In reality, tied events are common. Breslow's method is an approach for handling tied event times in the Cox PH model. If several individuals fail at the same time, each of them is treated as if it failed while the full risk set at that time was still in place, so the risk-set sum in the denominator is simply raised to the power of the number of tied failures. This is in contrast to the Efron approximation, which shrinks the denominator for each successive tied failure by subtracting a growing fraction of the tied individuals' contribution.

$$L(\beta) = \prod_{k=1}^f \frac{e^{\sum_{s \in I(k)} X_s'\beta}}{\left(\sum_{j \in R_k} e^{X_j'\beta}\right)^{d_k}}$$
  • $f$: the number of unique failure times
  • $k$: index of a unique failure time $t_k$, with $k = 1, \dots, f$
  • $R_k$: the risk set at $t_k$
  • $I(k)$: the set of samples that failed at $t_k$
  • $X_j$: covariates of a sample that is at risk at $t_k$
  • $X_s$: covariates of a sample that failed at $t_k$
  • $d_k$: the number of failures at $t_k$

Taking the log, we get the log partial likelihood:

$$\ell(\beta) = \sum_{k=1}^f \left[ \left( \sum_{s \in I(k)} X_s'\beta \right) - d_k \log \sum_{j \in R_k} e^{X_j'\beta} \right]$$

We typically add a regularization term too. Let’s formally define the minimization problem with $L_1$ regularization:

$$\min_\beta \sum_{k=1}^f -\left[ \left( \sum_{s \in I(k)} X_s'\beta \right) - d_k \log \sum_{j \in R_k} e^{X_j'\beta} \right] + \lambda \|\beta\|_1$$
  • $\lambda$: the regularization strength, a hyperparameter that controls how much regularization is applied to the regression coefficients (as in ridge and lasso regression).
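The objective above can be written out directly as loops over the unique failure times. The sketch below does so on a small hypothetical dataset with ties (two failures at $t=2$ and two at $t=5$); the data, the function name `breslow_objective`, and the values of `beta` and `lam` are all made up for illustration.

```python
import math

# Hypothetical data with tied failure times (two failures at t=2 and two at t=5).
times = [2, 2, 4, 5, 5, 7]
events = [1, 1, 0, 1, 1, 0]
X = [[0.5, 1.0], [1.0, 0.0], [-0.3, 0.4], [0.2, -1.2], [1.5, 0.3], [0.0, 0.8]]

def breslow_objective(beta, lam):
    """Negative Breslow log partial likelihood plus an L1 penalty."""
    xb = [sum(x * b for x, b in zip(x_i, beta)) for x_i in X]
    unique_failure_times = sorted({t for t, e in zip(times, events) if e == 1})
    log_lik = 0.0
    for t_k in unique_failure_times:
        # I(k): samples that failed at t_k; R_k: samples at risk at t_k.
        failed = [s for s, (t, e) in enumerate(zip(times, events))
                  if t == t_k and e == 1]
        at_risk = [j for j, t in enumerate(times) if t >= t_k]
        d_k = len(failed)
        risk_sum = sum(math.exp(xb[j]) for j in at_risk)
        log_lik += sum(xb[s] for s in failed) - d_k * math.log(risk_sum)
    l1_penalty = lam * sum(abs(b) for b in beta)
    return -log_lik + l1_penalty

beta = [0.1, -0.2]  # hypothetical coefficients
print(breslow_objective(beta, lam=0.1))
```

At $\beta = 0$ the objective reduces to $\sum_k d_k \log |R_k|$, here $2 \log 6 + 2 \log 3$, because every sample has the same relative risk and the penalty vanishes.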

Breslow's method has the advantage of being computationally simpler than the Efron approximation, but it may be less accurate when the number of tied event times is high. In practice, the choice of method for handling tied event times can depend on the specific research question, the sample size, and the number of tied event times in the data.

It is worth noting that both Breslow's method and the Efron approximation assume that tied event times occur at random, and not due to any underlying structure or relationship among the individuals in the study. If tied event times are not random and are related to unobserved covariates, then alternative approaches may be necessary to address this issue.

Show me the code, in JAX

Let’s define a few key variables:

  • $T_{sorted}$: sorted list of unique failure times, with $t_k \in T_{sorted}$
  • $d$: a list of failure counts, one per unique failure time in $T_{sorted}$; $d_k$ is the number of samples that failed at time $t_k$
  • $R$: a $|T_{sorted}| \times n$ matrix of risk set indicators; $R_{k,i}$ indicates whether sample $i$ belongs to the risk set at unique failure time $t_k$
  • $I$: a $|T_{sorted}| \times n$ matrix of failure indicators; $I_{k,i}$ indicates whether sample $i$ failed at unique failure time $t_k$
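One possible way to build these variables from raw times and events is sketched below with plain Python lists, using the example dataset. The names `I_mat` and `R_mat` are chosen for this illustration; in the JAX code they appear to correspond to the `indices` and `riskset` arguments, and the lists could be converted with `jnp.asarray` before being passed in.

```python
# Times and event indicators from the example dataset.
times = [2, 2, 4, 5, 7]
events = [0, 1, 0, 1, 0]

# T_sorted: sorted unique times at which at least one failure was observed.
T_sorted = sorted({t for t, e in zip(times, events) if e == 1})

# I_mat[k][i] = 1 if sample i failed at unique failure time t_k, else 0.
I_mat = [[int(t == t_k and e == 1) for t, e in zip(times, events)]
         for t_k in T_sorted]

# R_mat[k][i] = 1 if sample i is still at risk at t_k (observed time >= t_k).
R_mat = [[int(t >= t_k) for t in times] for t_k in T_sorted]

# d[k]: number of failures at t_k, i.e. the row sums of I_mat.
d = [sum(row) for row in I_mat]

print(T_sorted, d)
print(R_mat)
print(I_mat)
```

Each row of `R_mat` selects the terms of one risk-set sum, and each row of `I_mat` selects the failed samples, which is what lets the log partial likelihood be written with matrix products.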

With these variables, we can express the cost function in matrix multiplication form:

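import jax
import jax.numpy as jnp


@jax.jit
def negative_log_cox_partial_likelihood(w, X, indices, riskset, d, lmbd):
    # w: (p,) coefficients; X: (n, p) covariates
    # indices: (|T_sorted|, n) failure indicators; riskset: (|T_sorted|, n) risk set indicators
    # d: (|T_sorted|,) failure counts; lmbd: L1 regularization strength
    wx = w @ X.T  # linear predictors X_i'beta, shape (n,)
    exp_wx = jnp.exp(wx)
    reg_term = lmbd * jnp.linalg.norm(w, ord=1)
    # negative Breslow log partial likelihood plus the L1 penalty
    return -jnp.sum(indices @ wx - d * jnp.log(riskset @ exp_wx)) + reg_term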

Now, thanks to JAX, we get the gradient and Hessian for free:

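import numpy as np
from functools import partial
from jax import grad, hessian
from scipy.optimize import minimize


# indices, riskset, d, covariates and lmbd are assumed to be built from the dataset beforehand
cost_func = partial(negative_log_cox_partial_likelihood, indices=indices, riskset=riskset, d=d, X=covariates,
                    lmbd=lmbd)
g = grad(cost_func)     # gradient with respect to the coefficients
h = hessian(cost_func)  # Hessian with respect to the coefficients

# initial guess: one coefficient per covariate (14 in the dataset used here)
result = minimize(cost_func, np.zeros(14), method='Newton-CG', jac=g, hess=h)
betas = result.x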

We can check the correctness of the estimated betas against the lifelines package.

However, estimating $\beta$ is not the end of the story: we still need to estimate $h_0(t)$ to compute the hazard function. Once the hazard function is estimated, we can get the survival function $S(t)$ and the PDF $f(t)$. We will discuss them in the next post.

References

Ko, Jessica. Solving the Cox proportional hazards model and its applications. Diss. Master’s thesis, EECS Department, University of California, Berkeley, 2017

Kleinbaum, D. G., & Klein, M. (1996). Survival Analysis: A Self-Learning Text. Springer.

“Survival Analysis Part 9 | Cox Proportional Hazards Model.” www.youtube.com, youtu.be/aETMUW_TWV0. Accessed 8 May 2023.

Survival Analysis: Optimize the Partial Likelihood of the Cox Model. (2022). Retrieved 8 May 2023, from https://towardsdatascience.com/survival-analysis-optimize-the-partial-likelihood-of-the-cox-model-b56b8f112401

GraphPad Software, L. (2023). GraphPad Prism 9 Statistics Guide - The mathematics of the cumulative hazard function. Retrieved 8 May 2023, from https://www.graphpad.com/guides/prism/latest/statistics/stat_cox_math_cumulative_hazard.htm

Ravinutala, S. (2021). [Survival models] Cox proportional hazard model. Retrieved 8 May 2023, from https://sidravi1.github.io/blog/2021/10/11/cox-proportional-hazard-model