@@ -36,7 +36,7 @@ def neg_partial_time_log_likelihood(
3636 tensor(0.9452, grad_fn=<DivBackward0>)
3737 >>> neg_partial_time_log_likelihood(estimates.squeeze(), time, event) # Also works with 2D tensor
3838 tensor(0.9452, grad_fn=<DivBackward0>)
39- >>> neg_partial_time_log_likelihood(estimates, time, event, reduction=' sum' )
39+ >>> neg_partial_time_log_likelihood(estimates, time, event, reduction=" sum" )
4040 tensor(37.8082, grad_fn=<SumBackward0>)
4141 >>> from torchsurv.metrics.cindex import ConcordanceIndex
4242 >>> cindex = ConcordanceIndex()
@@ -47,7 +47,7 @@ def neg_partial_time_log_likelihood(
4747 tensor(0.4545)
4848 """
4949
50- # only consider theta at tiem of
50+ # only consider theta at time of
5151 pll = _partial_likelihood_time_cox (log_hz , time , events )
5252
5353 # Negative partial log likelihood
@@ -57,11 +57,7 @@ def neg_partial_time_log_likelihood(
5757 elif reduction .lower () == "sum" :
5858 loss = pll .sum ()
5959 else :
60- raise (
61- ValueError (
62- f"Reduction { reduction } is not implemented yet, should be one of ['mean', 'sum']."
63- )
64- )
60+ raise (ValueError (f"Reduction { reduction } is not implemented yet, should be one of ['mean', 'sum']." ))
6561 return loss
6662
6763
@@ -79,7 +75,7 @@ def _partial_likelihood_time_cox(
7975 Log relative hazard of dimension T x n_samples x P.
8076 T is the time series dimension, P is the number of parameters observed over time.
8177 event (torch.Tensor, bool):
82- Event indicator of length n_samples (= True if event occured ).
78+ Event indicator of length n_samples (= True if event occurred ).
8379 time (torch.Tensor):
8480 Time-to-event or censoring of length n_samples.
8581
@@ -88,32 +84,32 @@ def _partial_likelihood_time_cox(
8884 Vector of the partial log likelihood, length n_samples.
8985
9086 Note:
91- For each subject :math:`i \in \{1, \cdots, N\}`, denote :math:`\t au^*_i` as the survival time and :math:`C_i` as the
92- censoring time. Survival data consist of the event indicator, :math:`\delta_i=1(\t au^*_i\leq C_i)`
93- (argument ``event``) and the time-to-event or censoring, :math:`\t au_i = \min(\{ \t au^*_i,D_i \})`
87+ For each subject :math:`i \\ in \\ {1, \\ cdots, N\ \ }`, denote :math:`\t au^*_i` as the survival time and :math:`C_i` as the
88+ censoring time. Survival data consist of the event indicator, :math:`\\ delta_i=1(\t au^*_i\ \ leq C_i)`
89+ (argument ``event``) and the time-to-event or censoring, :math:`\t au_i = \\ min(\\ { \t au^*_i,D_i \ \ })`
9490 (argument ``time``).
9591
9692 Consider some covariate :math:`Z(t)` with covariate history denoted as :math:`H_Z` and a general form of the cox proportional hazards model:
9793 .. math::
9894
99- \log \lambda_i (t|H_Z) = lambda_0(t) \t heta(Z(t))
95+ \\ log \ \ lambda_i (t|H_Z) = lambda_0(t) \t heta(Z(t))
10096
101- A network that maps the input covariates $Z(t)$ to the log relative hazards: :math:`\log \t heta(Z(t))`.
102- The partial likelihood with repsect to :math:`\log \t heta(Z(t))` is written as:
97+ A network that maps the input covariates $Z(t)$ to the log relative hazards: :math:`\\ log \t heta(Z(t))`.
98+ The partial likelihood with respect to :math:`\ \ log \t heta(Z(t))` is written as:
10399
104100 .. math::
105101
106- \log L(\t heta) = \sum_j \Big( \log \t heta(Z_i(\t au_j)) - \log [\sum_{j \in R_i} \t heta (Z_i(\t au_j))] \Big)
102+ \\ log L(\t heta) = \\ sum_j \\ Big( \\ log \t heta(Z_i(\t au_j)) - \\ log [\\ sum_{j \\ in R_i} \t heta (Z_i(\t au_j))] \ \ Big)
107103
108104 and it only considers the values of te covariate :math:`Z` at event time :math:`\t au_i`
109105
110106 Remarks:
111107 - values inside the time vector must be strictly zero or positive as they are used to identify values of
112108 covariates at event time
113- - the maximum value inside the vector time cannt exceed T-1 for indexing reasons
114- - this function was not tested for P>1 but it should be possile for an extension
109+ - the maximum value inside the vector time cannot exceed T-1 for indexing reasons
110+ - this function was not tested for P>1 but it should be possible for an extension
115111 - the values of Z at event time should not be null, a reasonable imputation method should be used,
116- unless the network fullfills that role
112+ unless the network fulfills that role
117113 - future formatting: time vector must somehow correspond to the T dimension in the log_hz tensor, i.e. for those who experience an event,
118114 we want to identify the index of the covariate upon failure. We could either consider the last covariate before a series of zeros
119115 (requires special data formatting but could reduce issues as it automatically contains event time information).
@@ -170,32 +166,17 @@ def _time_varying_covariance(
170166 # sort data by time-to-event or censoring
171167 time_sorted , idx = torch .sort (time )
172168 log_hz_sorted = log_hz [idx ]
173- event_sorted = event [idx ]
174169
175170 # keep log if we can
176171 exp_log_hz = torch .exp (log_hz_sorted )
177172 # remove mean over time from covariates
178173 # sort covariates so that the rows match the ordering
179174 covariates_sorted = covariates [idx , :] - covariates .mean (dim = 0 )
180175
181- # the left hand side (HS) of the equation
182- # below is Z_k Z_k^T - i think it should be a vector matrix dim nxn
183- covariate_inner_product = torch .matmul (covariates_sorted , covariates_sorted .T )
184-
185- # pointwise multiplication of vectors to get the nominator of left HS
186- # outcome in a vector of length n
187- # Ends up being (1, n)
188- log_nominator_left = torch .matmul (exp_log_hz .T , covariate_inner_product )
189-
190176 # right hand size of the equation
191177 # formulate the brackets \sum exp(theta)Z_k
192178 bracket = torch .mul (exp_log_hz , covariates_sorted )
193179 covariance_matrix = torch .matmul (bracket , bracket .T ) # nxn matrix
194- # ###nbelow is commented out as it does not apply but I wanted to keep it for the functions
195- # #log_nominator_right = torch.sum(nominator_right, dim=0).unsqueeze(0)
196- # log_nominator_right = nominator_right[0,].unsqueeze(0)
197- # log_denominator = torch.logcumsumexp(log_hz_sorted.flip(0), dim=0).flip(0) #dim=0 sums over the oth dimension
198- # partial_log_likelihood = torch.div(log_nominator_left - log_nominator_right, log_denominator) # (n, n)
199180
200181 return covariance_matrix
201182
0 commit comments