Skip to content

Commit ffd6920

Browse files
SoniaDemtcoroller
andauthored
added comments to functions in cox.py file (#76)
* added comments to functions in cox.py file * linting * missing escape --------- Co-authored-by: corolth1 <[email protected]>
1 parent 08049b1 commit ffd6920

File tree

1 file changed

+47
-1
lines changed

1 file changed

+47
-1
lines changed

src/torchsurv/loss/cox.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,10 +168,30 @@ def _partial_likelihood_cox(
168168
log_hz_sorted: torch.Tensor,
169169
event_sorted: torch.Tensor,
170170
) -> torch.Tensor:
171-
"""Calculate the partial log likelihood for the Cox proportional hazards model
171+
r"""Calculate the partial log likelihood for the Cox proportional hazards model
172172
in the absence of ties in event time.
173+
174+
Args:
175+
log_hz_sorted (torch.Tensor, float):
176+
Log relative hazard of length n_samples, ordered by time-to-event or censoring.
177+
event_sorted (torch.Tensor, bool):
178+
Event indicator of length n_samples (= True if event occured), ordered by time-to-event or censoring.
179+
180+
Returns:
181+
(torch.tensor, float):
182+
Vector of the partial log likelihoods.
183+
184+
Note:
185+
Let :math:`\tau_1 < \tau_2 < \cdots < \tau_N`
186+
be the ordered times and let :math:`R(\tau_i) = \{ j: \tau_j \geq \tau_i\}`
187+
be the risk set at :math:`\tau_i`. The partial log likelihood is defined as:
188+
189+
.. math::
190+
191+
pll = \sum_{i: \: \delta_i = 1} \left(\log \theta_i - \log\left(\sum_{j \in R(\tau_i)} \theta_j \right) \right)
173192
"""
174193
log_denominator = torch.logcumsumexp(log_hz_sorted.flip(0), dim=0).flip(0)
194+
175195
return (log_hz_sorted - log_denominator)[event_sorted]
176196

177197

@@ -183,6 +203,19 @@ def _partial_likelihood_efron(
183203
) -> torch.Tensor:
184204
"""Calculate the partial log likelihood for the Cox proportional hazards model
185205
using Efron's method to handle ties in event time.
206+
207+
Args:
208+
log_hz_sorted (torch.Tensor, float):
209+
Log relative hazard of length n_samples, ordered by time-to-event or censoring.
210+
event_sorted (torch.Tensor, bool):
211+
Event indicator of length n_samples (= True if event occured), ordered by time-to-event or censoring.
212+
time_sorted (torch.Tensor):
213+
Time-to-event values sorted in order.
214+
time_unique (torch.Tensor):
215+
Set of unique time-to-event values.
216+
Returns:
217+
(torch.tensor, float):
218+
Vector of partial log likelihood estimated using Efron's method.
186219
"""
187220
J = len(time_unique)
188221

@@ -206,6 +239,7 @@ def _partial_likelihood_efron(
206239
log_denominator_efron[j] += torch.log(
207240
denominator_naive[j] - (l - 1) / m[j] * denominator_ties[j]
208241
)
242+
209243
return (log_nominator - log_denominator_efron)[include]
210244

211245

@@ -216,6 +250,18 @@ def _partial_likelihood_breslow(
216250
):
217251
"""Calculate the partial log likelihood for the Cox proportional hazards model
218252
using Breslow's method to handle ties in event time.
253+
254+
Args:
255+
log_hz_sorted (torch.Tensor, float):
256+
Log relative hazard of length n_samples, ordered by time-to-event or censoring.
257+
event_sorted (torch.Tensor, bool):
258+
Event indicator of length n_samples (= True if event occured), ordered by time-to-event or censoring.
259+
time_sorted (torch.Tensor):
260+
Time-to-event values sorted in order.
261+
262+
Returns:
263+
(torch.tensor, float):
264+
Vector containing partial log likelihood estimated using Breslow's method.
219265
"""
220266
N = len(time_sorted)
221267

0 commit comments

Comments
 (0)