-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_interpolation_task.py
More file actions
164 lines (127 loc) · 6.27 KB
/
_interpolation_task.py
File metadata and controls
164 lines (127 loc) · 6.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# -- import packages: ---------------------------------------------------------
import ABCParse
import autodevice
import torch
# -- import local dependencies: -----------------------------------------------
from ._interpolation_data import InterpolationData
# -- set typing: --------------------------------------------------------------
from typing import Tuple
# -- Operational class: -------------------------------------------------------
class InterpolationTask(ABCParse.ABCParse):
    """Score how well a DiffEq model interpolates later timepoints.

    Initial states ``data.X0`` are forward-integrated over ``data.t`` (both
    provided by ``InterpolationData``). The predictions at time index 1 and 2,
    reported as "Day 4" and "Day 6", are scored with a Sinkhorn divergence
    against ``data.X_test_d4`` and ``data.X_train_d6`` respectively.

    Constructor arguments are recorded by ``ABCParse`` and routed to
    ``InterpolationData`` and ``SinkhornDivergence`` by matching parameter
    names (``ABCParse.function_kwargs``).

    Notable arguments:
        PCA: optional object with a ``transform`` method; when given,
            predicted states are projected with it before scoring.
        batch_size: when set and smaller than the number of initial states,
            forward integration is done in chunks of this many states.
        silent: suppress the per-call loss printout.
    """

    def __init__(
        self,
        adata,
        time_key="Time point",
        use_key="X_pca",
        t0=2,
        n_samples=10_000,
        lineage_key="clone_idx",
        # NOTE: this default is evaluated once, when the class is defined.
        device=autodevice.AutoDevice(),
        backend="auto",
        silent=False,
        PCA=None,
        batch_size=None,
        *args,
        **kwargs,
    ):
        # Capture locals() before defining anything else so that only the
        # constructor arguments are registered as parameters.
        self.__parse__(locals())
        self.data = InterpolationData(**self._DATA_KWARGS)
        # Local import; presumably to avoid an import cycle with scdiffeq.core
        # (TODO confirm).
        from scdiffeq.core.lightning_models.base import SinkhornDivergence

        self._sinkhorn_fn = SinkhornDivergence
        self.SinkhornDivergence = self._sinkhorn_fn(**self._SINKHORN_KWARGS)

    @property
    def _DATA_KWARGS(self):
        """Stored parameters accepted by ``InterpolationData``."""
        return ABCParse.function_kwargs(func=InterpolationData, kwargs=self._PARAMS)

    @property
    def _SINKHORN_KWARGS(self):
        """Stored parameters accepted by the Sinkhorn divergence constructor."""
        return ABCParse.function_kwargs(func=self._sinkhorn_fn, kwargs=self._PARAMS)

    def forward_without_grad(self, DiffEq):
        """Forward-integrate ``DiffEq`` from ``data.X0`` with autograd disabled."""
        with torch.no_grad():
            X_hat = DiffEq.forward(self.data.X0, self.data.t)
        return self._parse_forward_out(X_hat)

    def forward_with_grad(self, DiffEq):
        """Forward-integrate ``DiffEq`` from ``data.X0`` with autograd enabled.

        Gradient tracking is enabled only for the duration of the call; the
        caller's grad mode is restored afterwards, even if ``forward`` raises.
        """
        with torch.enable_grad():
            X_hat = DiffEq.forward(self.data.X0, self.data.t)
        return self._parse_forward_out(X_hat)

    @property
    def potential(self):
        """True if the stored model's string repr contains "Potential".

        Name-based heuristic; reads ``self._DiffEq``, which ``__call__`` sets
        through ``__update__``.
        """
        return "Potential" in str(self._DiffEq)

    def _parse_forward_out(self, X_hat):
        """Return the trajectory from a ``forward`` output.

        Some models return a tuple (the original note says this accounts for
        KLDiv); only its first element is kept.
        """
        if isinstance(X_hat, tuple):
            return X_hat[0]
        return X_hat

    def _apply_pca_to_feature_slice(self, data_slice: torch.Tensor) -> torch.Tensor:
        """Project a (n_samples, n_features) slice with ``PCA``, if one is set.

        A 1-D slice is treated as a single sample. The result is moved back to
        the device of ``data_slice``.
        """
        if self._PCA is None:
            return data_slice
        data_slice_np = data_slice.detach().cpu().numpy()
        if data_slice_np.ndim == 1:
            data_slice_np = data_slice_np.reshape(1, -1)
        pca_transformed_np = self._PCA.transform(data_slice_np)
        return torch.Tensor(pca_transformed_np).to(data_slice.device)

    def _integrate_chunk(self, X0_chunk, use_grad: bool):
        """Forward-integrate one chunk of initial states; return a detached trajectory.

        ``use_grad`` selects ``enable_grad`` vs ``no_grad`` for the forward
        pass. Either context restores the caller's grad mode on exit. The
        result is detached so that chunks kept for scoring do not hold their
        autograd graphs alive (only ``.item()`` values leave ``__call__``).
        """
        grad_ctx = torch.enable_grad() if use_grad else torch.no_grad()
        with grad_ctx:
            X_hat = self._DiffEq.forward(X0_chunk, self.data.t)
        return self._parse_forward_out(X_hat).detach()

    def __call__(self, trainer, DiffEq, *args, **kwargs):
        """Evaluate ``DiffEq`` on the interpolation task and return its losses.

        Args:
            trainer: recorded via ``__update__``; not used directly here.
            DiffEq: model exposing ``forward(X0, t)`` and ``current_epoch``.

        Returns:
            ``(d4_loss, d6_loss)`` as Python floats.

        Raises:
            ValueError: if ``batch_size`` is set but is not a positive integer.
        """
        self.__update__(locals())

        X0 = self.data.X0
        n_total = X0.shape[0]
        batch_size = self._batch_size
        if batch_size is not None and batch_size <= 0:
            raise ValueError(
                f"batch_size must be a positive integer or None, got {batch_size!r}"
            )
        if batch_size is None or batch_size >= n_total:
            X0_chunks = [X0]
        else:
            X0_chunks = [X0[i : i + batch_size] for i in range(0, n_total, batch_size)]

        # ``potential`` stringifies the whole model; evaluate it once, not per chunk.
        use_grad = self.potential

        d4_parts, d6_parts = [], []
        for X0_chunk in X0_chunks:
            # Per the original comments: (chunk_size, n_timepoints, n_features).
            X_hat = self._integrate_chunk(X0_chunk, use_grad)
            d4_parts.append(self._apply_pca_to_feature_slice(X_hat[:, 1, :]))
            d6_parts.append(self._apply_pca_to_feature_slice(X_hat[:, 2, :]))

        X_hat_d4 = torch.cat(d4_parts, dim=0)
        X_hat_d6 = torch.cat(d6_parts, dim=0)

        d4_loss = self.SinkhornDivergence(X_hat_d4, self.data.X_test_d4).item()
        d6_loss = self.SinkhornDivergence(X_hat_d6, self.data.X_train_d6).item()

        if not self._silent:
            print(
                "- Epoch: {:<5}| Day 4 loss: {:.2f} | Day 6 loss: {:.2f}".format(
                    self._DiffEq.current_epoch, d4_loss, d6_loss,
                ),
            )
        return d4_loss, d6_loss