Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion configs/vae_ecg_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ train:
epochs: 10
log_interval: 50
lr: 1e-4
loss_function: "VAE_MSE"
kl_weight: 1e-5
recon_weight: 0.1

model:
type: "ECG_VAE"
Expand Down
70 changes: 70 additions & 0 deletions pipeline/vae_ecg/evals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import torch
import torch.nn as nn
from matplotlib import pyplot as plt

from utils.visualizations import interpolate_embeddings


def draw_ecg_reconstructions(
    signals: torch.Tensor,
    reconstructions: torch.Tensor,
) -> plt.Figure:
    """Plot original ECG signals next to their model reconstructions.

    Each sample gets one row with two panels: channel 0 on the left,
    channel 1 on the right, with signal and reconstruction overlaid.

    Args:
        signals: batch of ECG signals; assumes shape (batch, 2, length)
            — TODO confirm against the dataset loader.
        reconstructions: model outputs with the same shape as ``signals``.

    Returns:
        Matplotlib figure holding the comparison grid.
    """
    reconstructions = reconstructions.detach().cpu()
    # detach() as well, in case callers pass grad-carrying tensors
    signals = signals.detach().cpu()
    n_samples = signals.shape[0]

    # squeeze=False keeps ``axes`` 2-D even when n_samples == 1;
    # without it, axes[it][0] raises for a single-sample batch.
    fig, axes = plt.subplots(
        ncols=2,
        nrows=n_samples,
        figsize=[10, 2 * n_samples],
        squeeze=False,
    )
    for it in range(n_samples):
        left_ax = axes[it][0]
        right_ax = axes[it][1]
        left_ax.plot(signals[it][0], label="signal")
        left_ax.plot(reconstructions[it][0], label="reconstruction")
        left_ax.legend()

        right_ax.plot(signals[it][1], label="signal")
        right_ax.plot(reconstructions[it][1], label="reconstruction")
        right_ax.legend()

    plt.tight_layout()
    return fig


def draw_interpolation_tower(model: nn.Module, signals: torch.Tensor, num_interps: int = 10) -> plt.Figure:
    """Visualize latent-space interpolation between the first two signals.

    Encodes the batch, linearly interpolates between the embeddings of
    ``signals[0]`` and ``signals[1]``, decodes the intermediate latents,
    and plots the resulting tower (endpoints included) channel by channel.

    Args:
        model: VAE exposing ``encode``, ``reparameterize`` and ``decode``.
        signals: batch of ECG signals; only the first two are used.
        num_interps: number of intermediate latent points to decode.

    Returns:
        Matplotlib figure with ``num_interps + 2`` rows of two channels each.
    """
    mu, logvar = model.encode(signals)
    embeddings = model.reparameterize(mu, logvar)

    # This only works for the first two signals, if there's more we ignore
    left = embeddings[0]
    right = embeddings[1]
    interpolated_embeddings = interpolate_embeddings(
        left=left,
        right=right,
        num_interps=num_interps,
    )

    decoded_interpolations = model.decode(interpolated_embeddings)
    # Sandwich the decoded latents between the two real endpoint signals
    interpolated_signals = [signals[0].unsqueeze(0), decoded_interpolations, signals[1].unsqueeze(0)]
    interpolated_signals = torch.cat(interpolated_signals)

    # Include target signals
    n_samples = num_interps + 2
    fig, axes = plt.subplots(
        ncols=2,
        nrows=n_samples,
        figsize=[10, 1 * n_samples],
        gridspec_kw={"hspace": 0.0},
    )

    for it in range(n_samples):
        signal = interpolated_signals[it].detach().cpu()
        left_ax = axes[it][0]
        left_ax.plot(signal[0])
        left_ax.set_xlim(0, 1000)
        left_ax.xaxis.set_visible(False)

        right_ax = axes[it][1]
        right_ax.plot(signal[1])
        right_ax.set_xlim(0, 1000)
        right_ax.xaxis.set_visible(False)

    plt.tight_layout()
    # Bug fix: the signature promises a Figure but the original fell off
    # the end and returned None.
    return fig
89 changes: 77 additions & 12 deletions pipeline/vae_ecg/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,29 @@
import numpy as np
import torch.nn as nn
from tqdm import tqdm
from omegaconf import DictConfig
import torch.nn.functional as F
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from omegaconf import OmegaConf, DictConfig

import wandb
from utils.data_loader import prepare_dataset
from utils.train_utils import prepare_loss_function
from pipeline.vae_ecg import evals as vae_ecg_evals
from models.autoencoder_ecg import VariationalAutoencoderECG


def train(cfg: DictConfig) -> nn.Module:
train_dataset, test_dataset = prepare_dataset(cfg)
train_loader = DataLoader(train_dataset, batch_size=cfg.train.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=cfg.train.batch_size, shuffle=False)
train_loader = DataLoader(train_dataset, batch_size=cfg.train.batch_size, shuffle=True, num_workers=8)
test_loader = DataLoader(test_dataset, batch_size=cfg.train.batch_size, shuffle=False, num_workers=8)

device = cfg.system.device

# Fixed samples for visualization
n_samples = 5
idxs = np.random.randint(len(test_dataset), size=n_samples)
fixed_validation_signals = test_dataset[idxs]["signal"].to(device)

input_size = train_dataset.input_size
model = VariationalAutoencoderECG(
encoder_output_size=cfg.model.encoder_output_size,
Expand All @@ -25,7 +33,6 @@ def train(cfg: DictConfig) -> nn.Module:
)
model = model.to(device)

loss_fn = prepare_loss_function(loss_function_name=cfg.train.loss_function)
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.train.lr)

step = 0
Expand All @@ -41,16 +48,22 @@ def train(cfg: DictConfig) -> nn.Module:
data = batch["signal"].to(device)
optimizer.zero_grad()

recon_batch, mu, logvar = model(data)
loss = loss_fn(recon_batch, data, mu, logvar)
losses = forward_step(
model=model,
data=data,
kl_weight=cfg.train.kl_weight,
recon_weight=cfg.train.recon_weight,
)
loss = losses["loss"]
loss.backward()
optimizer.step()

train_loss.append(loss.item())

if step % cfg.train.log_interval == 0:
train_progress.set_postfix(loss=loss.item())
wandb.log({"train/loss": loss.item()}, step=step)
metrics = {f"train/{key}": value.item() for key, value in losses.items()}
wandb.log(metrics, step=step)

step += 1

Expand All @@ -62,10 +75,20 @@ def train(cfg: DictConfig) -> nn.Module:
test_progress = tqdm(enumerate(test_loader), total=len(test_loader), leave=False)
for it, batch in test_progress:
data = batch["signal"].to(device)
recon_batch, mu, logvar = model(data)
loss = loss_fn(recon_batch, data, mu, logvar)
losses = forward_step(
model=model,
data=data,
kl_weight=cfg.train.kl_weight,
recon_weight=cfg.train.recon_weight,
)
loss = losses["loss"]
test_loss.append(loss.item())

# Review reconstructions
reconstructions, mu, logvar = model(fixed_validation_signals)
fig = vae_ecg_evals.draw_ecg_reconstructions(fixed_validation_signals, reconstructions)
wandb.log({"test/reconstruction": wandb.Image(fig)}, step=step)

# Epoch summary
test_loss = np.mean(test_loss)
train_loss = np.mean(train_loss)
Expand All @@ -76,7 +99,8 @@ def train(cfg: DictConfig) -> nn.Module:
checkpoint = {
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"config": cfg,
"config": OmegaConf.to_object(cfg),
"test_loss": test_loss,
}
checkpoint_path = "{}/{}.pt".format(cfg.checkpoint_path, cfg.run_name)
torch.save(checkpoint, checkpoint_path)
Expand All @@ -85,5 +109,46 @@ def train(cfg: DictConfig) -> nn.Module:
return model


def forward_step(model: nn.Module, data: torch.Tensor, kl_weight: float, recon_weight: float) -> dict[str, torch.Tensor]:
    """Run one VAE forward pass and assemble the weighted loss terms.

    Args:
        model: VAE whose forward call returns (reconstruction, mu, logvar).
        data: input batch.
        kl_weight: multiplier applied to the KL-divergence term.
        recon_weight: multiplier applied to the reconstruction term.

    Returns:
        Dict with the total "loss" plus its unweighted "KLD" and "recon" parts.
    """
    reconstruction, mu, logvar = model(data)

    # Mean squared error between reconstruction and input.
    reconstruction_loss = F.mse_loss(reconstruction, data, reduction="mean")

    # Closed-form KL divergence of N(mu, sigma^2) against the unit Gaussian prior.
    kl_divergence = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

    # Balance between reconstruction loss and KLD is a hyperparameter
    # See: https://github.com/Nospoko/autoencoders-demo/pull/7
    total_loss = recon_weight * reconstruction_loss + kl_weight * kl_divergence

    return {
        "loss": total_loss,
        "KLD": kl_divergence,
        "recon": reconstruction_loss,
    }


def main(cfg: DictConfig):
    """Train the ECG VAE, then save reconstruction and interpolation previews.

    Args:
        cfg: experiment config with ``train``, ``model`` and ``system`` sections.
    """
    # Bug fix: the original called train(cfg) twice, discarding the first
    # result and doubling the training cost for no benefit.
    model = train(cfg)
    model.eval()

    # Data prep
    _, test_dataset = prepare_dataset(cfg)
    device = cfg.system.device
    n_samples = 16
    idxs = np.random.randint(len(test_dataset), size=n_samples)
    signals = test_dataset[idxs]["signal"].to(device)

    # Process — inference only, no gradients needed
    with torch.no_grad():
        reconstructions, _, _ = model(signals)

    # Review reconstructions
    vae_ecg_evals.draw_ecg_reconstructions(signals, reconstructions)
    savepath = "tmp/vae-ecg-reconstruction.png"
    plt.savefig(savepath)
    print("Saved an image!", savepath)

    # Review embedding based interpolations
    vae_ecg_evals.draw_interpolation_tower(model, signals, 16)
    savepath = "tmp/vae-ecg-interpolation.png"
    plt.savefig(savepath)
    print("Saved an image!", savepath)
6 changes: 6 additions & 0 deletions utils/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ def __init__(self, dataset: HFDataset, input_size: tuple):
# I don't like this but what can you do
self.data_key = "image" if len(input_size) == 3 else "signal"

def __rich_repr__(self):
    """Rich repr protocol hook.

    Lets ``rich`` render this dataset as
    ``SizedDataset(data_key=..., input_size=..., n_samples=...)``.
    """
    # NOTE(review): the label is hardcoded rather than type(self).__name__,
    # so subclasses will still display as "SizedDataset".
    yield "SizedDataset"
    yield "data_key", self.data_key
    yield "input_size", self.input_size
    yield "n_samples", len(self)

@property
def n_channels(self) -> int:
return self.input_size[0]
Expand Down