99 changes: 80 additions & 19 deletions larry/tasks/interpolation/_interpolation_task.py
@@ -1,18 +1,15 @@

# -- import local dependencies: -----------------------------------------------
from ._interpolation_data import InterpolationData


# -- import packages: ---------------------------------------------------------
import ABCParse
import autodevice
import torch

# -- import local dependencies: -----------------------------------------------
from ._interpolation_data import InterpolationData


# -- set typing: --------------------------------------------------------------
from typing import Tuple


# -- Operational class: -------------------------------------------------------
class InterpolationTask(ABCParse.ABCParse):
def __init__(
@@ -27,6 +24,7 @@ def __init__(
backend = "auto",
silent = False,
PCA = None,
batch_size = None,
*args,
**kwargs,
):
@@ -72,31 +70,94 @@ def _parse_forward_out(self, X_hat):
return X_hat[0]
return X_hat

def _dimension_reduce_pca(self, X_hat):
return torch.stack(
[torch.Tensor(self._PCA.transform(x)) for x in X_hat.detach().cpu().numpy()]
).to(self.device)
# def _dimension_reduce_pca(self, X_hat):
# return torch.stack(
# [torch.Tensor(self._PCA.transform(x)) for x in X_hat.detach().cpu().numpy()]
# ).to(self.device)

def _apply_pca_to_feature_slice(self, data_slice: torch.Tensor) -> torch.Tensor:
"""Applies PCA to a data slice of shape (n_samples, n_features)."""
if self._PCA is None:
return data_slice

data_slice_np = data_slice.detach().cpu().numpy()
if data_slice_np.ndim == 1:
data_slice_np = data_slice_np.reshape(1, -1)

pca_transformed_np = self._PCA.transform(data_slice_np)
return torch.Tensor(pca_transformed_np).to(data_slice.device)
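
A minimal standalone sketch of what _apply_pca_to_feature_slice does, assuming a fitted sklearn.decomposition.PCA object is supplied (the free-function form and the toy shapes here are illustrative, not part of the PR):

# Sketch only: slice-wise PCA projection with a fitted sklearn PCA.
import numpy as np
import torch
from sklearn.decomposition import PCA

def apply_pca_to_feature_slice(pca: PCA, data_slice: torch.Tensor) -> torch.Tensor:
    """Project a (n_samples, n_features) tensor into PCA space."""
    if pca is None:
        return data_slice
    data_slice_np = data_slice.detach().cpu().numpy()
    if data_slice_np.ndim == 1:
        # Promote a single sample to a (1, n_features) batch.
        data_slice_np = data_slice_np.reshape(1, -1)
    return torch.Tensor(pca.transform(data_slice_np)).to(data_slice.device)

# Usage: fit on reference features, then project a predicted slice.
pca = PCA(n_components=5).fit(np.random.rand(100, 50))
X_hat_d4 = torch.rand(32, 50)  # (n_samples, n_features)
print(apply_pca_to_feature_slice(pca, X_hat_d4).shape)  # torch.Size([32, 5])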

def __call__(self, trainer, DiffEq, *args, **kwargs):

self.__update__(locals())

if self.potential:
X_hat = self.forward_with_grad(DiffEq)
actual_n_samples = self.data.X0.shape[0]

if self._batch_size is not None and self._batch_size < actual_n_samples:
# Batch processing
all_X_hat_d4_processed_list = []
all_X_hat_d6_processed_list = []

original_grad_enabled = torch.is_grad_enabled()

for i in range(0, actual_n_samples, self._batch_size):
batch_X0 = self.data.X0[i : i + self._batch_size]

# Use self._DiffEq which was set by self.__update__
if self.potential:
torch.set_grad_enabled(True)
# Pass batch_X0 directly, not self.data.X0
X_hat_batch_raw = self._DiffEq.forward(batch_X0, self.data.t)
else:
with torch.no_grad():
X_hat_batch_raw = self._DiffEq.forward(batch_X0, self.data.t)

X_hat_batch = self._parse_forward_out(X_hat_batch_raw)
# X_hat_batch shape: (current_batch_size, n_timepoints, n_features)

X_hat_d4_batch = X_hat_batch[:, 1, :] # (current_batch_size, n_features)
X_hat_d6_batch = X_hat_batch[:, 2, :] # (current_batch_size, n_features)

if self._PCA is not None:
X_hat_d4_batch = self._apply_pca_to_feature_slice(X_hat_d4_batch)
X_hat_d6_batch = self._apply_pca_to_feature_slice(X_hat_d6_batch)

all_X_hat_d4_processed_list.append(X_hat_d4_batch)
all_X_hat_d6_processed_list.append(X_hat_d6_batch)

torch.set_grad_enabled(original_grad_enabled)

final_X_hat_d4 = torch.cat(all_X_hat_d4_processed_list, dim=0)
final_X_hat_d6 = torch.cat(all_X_hat_d6_processed_list, dim=0)

d4_loss = self.SinkhornDivergence(final_X_hat_d4, self.data.X_test_d4).item()
d6_loss = self.SinkhornDivergence(final_X_hat_d6, self.data.X_train_d6).item()

else:
X_hat = self.forward_without_grad(DiffEq)
# Original non-batched logic (or batch_size >= n_samples)
# The forward_with_grad/without_grad methods use self.data.X0 and self.data.t internally
if self.potential:
# These methods call DiffEq.forward(self.data.X0, self.data.t)
X_hat_raw = self.forward_with_grad(self._DiffEq)
else:
X_hat_raw = self.forward_without_grad(self._DiffEq)

if self._PCA is not None:
X_hat = self._dimension_reduce_pca(X_hat)

d4_loss = self.SinkhornDivergence(X_hat[1], self.data.X_test_d4).item()
d6_loss = self.SinkhornDivergence(X_hat[2], self.data.X_train_d6).item()
# X_hat_raw shape: (actual_n_samples, n_timepoints, n_features)

X_hat_d4 = X_hat_raw[:, 1, :] # (actual_n_samples, n_features)
X_hat_d6 = X_hat_raw[:, 2, :] # (actual_n_samples, n_features)

if self._PCA is not None:
X_hat_d4 = self._apply_pca_to_feature_slice(X_hat_d4)
X_hat_d6 = self._apply_pca_to_feature_slice(X_hat_d6)

d4_loss = self.SinkhornDivergence(X_hat_d4, self.data.X_test_d4).item()
d6_loss = self.SinkhornDivergence(X_hat_d6, self.data.X_train_d6).item()
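
If SinkhornDivergence wraps geomloss.SamplesLoss (an assumption; the attribute is defined outside this diff), each loss line reduces to a divergence between two point clouds, roughly:

# Assumption: SinkhornDivergence behaves like geomloss.SamplesLoss("sinkhorn").
import torch
from geomloss import SamplesLoss

sinkhorn = SamplesLoss(loss="sinkhorn", p=2, blur=0.05)
x = torch.rand(64, 5)  # e.g. predicted day-4 cells in PCA space
y = torch.rand(80, 5)  # e.g. held-out day-4 cells; cloud sizes may differ
d4_loss = sinkhorn(x, y).item()  # scalar divergence between the two clouds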

if not self._silent:
print(
"- Epoch: {:<5}| Day 4 loss: {:.2f} | Day 6 loss: {:.2f}".format(
DiffEq.current_epoch, d4_loss, d6_loss,
self._DiffEq.current_epoch, d4_loss, d6_loss,
),
)
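
The batching branch above follows a standard chunk-and-concatenate pattern. A minimal sketch of the idea, with a toy forward function standing in for DiffEq.forward (all names and shapes illustrative):

# Sketch only: chunked forward passes concatenate to the full-batch result.
import torch

def toy_forward(X0: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # (n_samples, n_features) -> (n_samples, n_timepoints, n_features)
    return X0.unsqueeze(1) * t.view(1, -1, 1)

X0 = torch.rand(1000, 50)
t = torch.tensor([0.0, 4.0, 6.0])
batch_size = 256

chunks = [
    toy_forward(X0[i : i + batch_size], t)
    for i in range(0, X0.shape[0], batch_size)
]
X_hat = torch.cat(chunks, dim=0)

# Same values as a single full-size pass, but peak memory is bounded by
# batch_size rather than the full sample count.
assert torch.allclose(X_hat, toy_forward(X0, t))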
