@@ -168,10 +168,30 @@ def _partial_likelihood_cox(
168168 log_hz_sorted : torch .Tensor ,
169169 event_sorted : torch .Tensor ,
170170) -> torch .Tensor :
171- """Calculate the partial log likelihood for the Cox proportional hazards model
171+ r """Calculate the partial log likelihood for the Cox proportional hazards model
172172 in the absence of ties in event time.
173+
174+ Args:
175+ log_hz_sorted (torch.Tensor, float):
176+ Log relative hazard of length n_samples, ordered by time-to-event or censoring.
177+ event_sorted (torch.Tensor, bool):
178+ Event indicator of length n_samples (= True if event occured), ordered by time-to-event or censoring.
179+
180+ Returns:
181+ (torch.tensor, float):
182+ Vector of the partial log likelihoods.
183+
184+ Note:
185+ Let :math:`\tau_1 < \tau_2 < \cdots < \tau_N`
186+ be the ordered times and let :math:`R(\tau_i) = \{ j: \tau_j \geq \tau_i\}`
187+ be the risk set at :math:`\tau_i`. The partial log likelihood is defined as:
188+
189+ .. math::
190+
191+ pll = \sum_{i: \: \delta_i = 1} \left(\log \theta_i - \log\left(\sum_{j \in R(\tau_i)} \theta_j \right) \right)
173192 """
174193 log_denominator = torch .logcumsumexp (log_hz_sorted .flip (0 ), dim = 0 ).flip (0 )
194+
175195 return (log_hz_sorted - log_denominator )[event_sorted ]
176196
177197
@@ -183,6 +203,19 @@ def _partial_likelihood_efron(
183203) -> torch .Tensor :
184204 """Calculate the partial log likelihood for the Cox proportional hazards model
185205 using Efron's method to handle ties in event time.
206+
207+ Args:
208+ log_hz_sorted (torch.Tensor, float):
209+ Log relative hazard of length n_samples, ordered by time-to-event or censoring.
210+ event_sorted (torch.Tensor, bool):
211+ Event indicator of length n_samples (= True if event occured), ordered by time-to-event or censoring.
212+ time_sorted (torch.Tensor):
213+ Time-to-event values sorted in order.
214+ time_unique (torch.Tensor):
215+ Set of unique time-to-event values.
216+ Returns:
217+ (torch.tensor, float):
218+ Vector of partial log likelihood estimated using Efron's method.
186219 """
187220 J = len (time_unique )
188221
@@ -206,6 +239,7 @@ def _partial_likelihood_efron(
206239 log_denominator_efron [j ] += torch .log (
207240 denominator_naive [j ] - (l - 1 ) / m [j ] * denominator_ties [j ]
208241 )
242+
209243 return (log_nominator - log_denominator_efron )[include ]
210244
211245
@@ -216,6 +250,18 @@ def _partial_likelihood_breslow(
216250):
217251 """Calculate the partial log likelihood for the Cox proportional hazards model
218252 using Breslow's method to handle ties in event time.
253+
254+ Args:
255+ log_hz_sorted (torch.Tensor, float):
256+ Log relative hazard of length n_samples, ordered by time-to-event or censoring.
257+ event_sorted (torch.Tensor, bool):
258+ Event indicator of length n_samples (= True if event occured), ordered by time-to-event or censoring.
259+ time_sorted (torch.Tensor):
260+ Time-to-event values sorted in order.
261+
262+ Returns:
263+ (torch.tensor, float):
264+ Vector containing partial log likelihood estimated using Breslow's method.
219265 """
220266 N = len (time_sorted )
221267
0 commit comments