Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 23 additions & 32 deletions data/carracing.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,48 @@
"""
Generating data from the CarRacing gym environment.
!!! DOES NOT WORK ON TITANIC, DO IT AT HOME, THEN SCP !!!
"""
# data/carracing.py
import argparse
from os.path import join, exists
import gym
import numpy as np
import gymnasium as gym
from os.path import join, exists
from utils.misc import sample_continuous_policy

def generate_data(rollouts, data_dir, noise_type):  # pylint: disable=R0914
    """Generate random rollouts from CarRacing and save them as .npz archives.

    :args rollouts: number of rollouts to generate
    :args data_dir: existing directory where rollout_<i>.npz files are written
    :args noise_type: 'white' (i.i.d. action-space samples) or 'brown'
        (temporally correlated noise via sample_continuous_policy)

    Each archive holds `observations`, `rewards`, `actions` and `terminals`
    arrays of equal length.
    """
    assert exists(data_dir), "The data directory does not exist..."

    # render_mode=None keeps the environment headless (no window needed).
    env = gym.make("CarRacing-v3", render_mode=None)
    seq_len = 1000

    try:
        for i in range(rollouts):
            env.reset()  # Gymnasium reset returns (obs, info); first frame unused here
            if noise_type == 'white':
                a_rollout = [env.action_space.sample() for _ in range(seq_len)]
            elif noise_type == 'brown':
                a_rollout = sample_continuous_policy(env.action_space, seq_len, 1.0 / 50)
            else:
                # Without this, a_rollout would be unbound for unknown types.
                raise ValueError(f"Unknown noise type: {noise_type!r}")

            s_rollout, r_rollout, d_rollout = [], [], []
            t = 0
            while True:
                action = a_rollout[t]
                t += 1
                # Gymnasium step API: five return values.
                s, r, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                s_rollout.append(s)
                r_rollout.append(r)
                d_rollout.append(done)
                # Stop on episode end, or once every pre-sampled action is used.
                if done or t >= seq_len:
                    print(f"> End of rollout {i}, {len(s_rollout)} frames...")
                    np.savez(
                        join(data_dir, f"rollout_{i}"),
                        observations=np.array(s_rollout),
                        rewards=np.array(r_rollout),
                        # truncate actions so all saved arrays share one length
                        actions=np.array(a_rollout[:len(s_rollout)]),
                        terminals=np.array(d_rollout),
                    )
                    break
    finally:
        env.close()  # release Box2D / pygame resources even on error

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # required=True: previously a missing flag passed None and crashed later
    # with a confusing error inside generate_data.
    parser.add_argument('--rollouts', type=int, required=True,
                        help="Number of rollouts")
    parser.add_argument('--dir', type=str, required=True,
                        help="Where to place rollouts")
    parser.add_argument('--policy', type=str, choices=['white', 'brown'],
                        help='Noise type used for action sampling.', default='brown')
    args = parser.parse_args()
    generate_data(args.rollouts, args.dir, args.policy)
21 changes: 6 additions & 15 deletions models/mdrnn.py
Original file line number Diff line number Diff line change
def gmm_loss(batch, mus, sigmas, logpi, reduce=True):  # pylint: disable=too-many-arguments
    """Negative log-likelihood of `batch` under a diagonal Gaussian mixture.

    :args batch: (*, fs) tensor of targets
    :args mus: (*, gaussians, fs) mixture means
    :args sigmas: (*, gaussians, fs) mixture standard deviations
    :args logpi: (*, gaussians) log mixture weights (assumed log-normalized)
    :args reduce: if True return the mean NLL over the batch, otherwise the
        per-sample NLL tensor of shape (*)
    """
    # Broadcast targets against the mixture-component axis.
    batch = batch.unsqueeze(-2)
    # Guard against zero/negative sigmas that would make log_prob NaN/inf.
    sigmas = sigmas.clamp_min(1e-6)
    # Per-component log-likelihood, summed over features (diagonal covariance).
    comp_log_prob = Normal(mus, sigmas).log_prob(batch).sum(dim=-1)
    # Numerically stable mixture log-likelihood via logsumexp.
    log_prob = torch.logsumexp(logpi + comp_log_prob, dim=-1)
    return -log_prob.mean() if reduce else -log_prob

class _MDRNNBase(nn.Module):
def __init__(self, latents, actions, hiddens, gaussians):
Expand Down Expand Up @@ -93,7 +84,7 @@ def forward(self, actions, latents): # pylint: disable=arguments-differ

sigmas = gmm_outs[:, :, stride:2 * stride]
sigmas = sigmas.view(seq_len, bs, self.gaussians, self.latents)
sigmas = torch.exp(sigmas)
sigmas = f.softplus(sigmas) + 1e-6

pi = gmm_outs[:, :, 2 * stride: 2 * stride + self.gaussians]
pi = pi.view(seq_len, bs, self.gaussians)
Expand Down Expand Up @@ -141,7 +132,7 @@ def forward(self, action, latent, hidden): # pylint: disable=arguments-differ

sigmas = out_full[:, stride:2 * stride]
sigmas = sigmas.view(-1, self.gaussians, self.latents)
sigmas = torch.exp(sigmas)
sigmas = f.softplus(sigmas) + 1e-6

pi = out_full[:, 2 * stride:2 * stride + self.gaussians]
pi = pi.view(-1, self.gaussians)
Expand Down
2 changes: 1 addition & 1 deletion models/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def forward(self, x): # pylint: disable=arguments-differ
x = F.relu(self.deconv1(x))
x = F.relu(self.deconv2(x))
x = F.relu(self.deconv3(x))
reconstruction = F.sigmoid(self.deconv4(x))
reconstruction = torch.sigmoid(self.deconv4(x))
return reconstruction

class Encoder(nn.Module): # pylint: disable=too-many-instance-attributes
Expand Down
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
cma
argparse
gym[all]
box2d
gymnasium[box2d]>=0.29
tqdm
numpy
torch>=2.1
torchvision>=0.16
69 changes: 43 additions & 26 deletions trainmdrnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from utils.misc import save_checkpoint
from utils.misc import ASIZE, LSIZE, RSIZE, RED_SIZE, SIZE
from utils.learning import EarlyStopping
## WARNING : THIS SHOULD BE REPLACED WITH PYTORCH 0.5
from utils.learning import ReduceLROnPlateau
from torch.optim.lr_scheduler import ReduceLROnPlateau

from data.loaders import RolloutSequenceDataset
from models.vae import VAE
Expand Down Expand Up @@ -72,39 +71,56 @@


# Data Loading: HWC uint8 frames -> CHW float32 in [0, 1].
transform = transforms.Lambda(
    lambda x: np.transpose(x, (0, 3, 1, 2)).astype(np.float32) / 255.0)

# drop_last=True keeps every batch at exactly BSIZE, which the latent
# reshapes downstream rely on.
train_loader = DataLoader(
    RolloutSequenceDataset('datasets/carracing', SEQ_LEN, transform, buffer_size=30),
    batch_size=BSIZE, num_workers=8, shuffle=True, drop_last=True)
test_loader = DataLoader(
    RolloutSequenceDataset('datasets/carracing', SEQ_LEN, transform,
                           train=False, buffer_size=10),
    batch_size=BSIZE, num_workers=8, drop_last=True)

def to_latent(obs, next_obs):
    """ Transform observations to latent space.

    :args obs: 5D torch tensor (B, T, C, H, W), values assumed in [0, 1]
        (the data-loading transform already divides by 255)
    :args next_obs: 5D torch tensor (B, T, C, H, W)

    :returns: (latent_obs, latent_next_obs), each of shape (B, T, LSIZE)
    """
    with torch.no_grad():
        B, T, C, H, W = obs.shape

        # The VAE encoder expects 3-channel input; adapt other channel counts.
        if C == 3:
            obs3, next3 = obs, next_obs
        elif C == 1:
            obs3 = obs.repeat(1, 1, 3, 1, 1)
            next3 = next_obs.repeat(1, 1, 3, 1, 1)
        else:
            # NOTE(review): arbitrarily keeps the first three channels; confirm
            # this matches the dataset layout (a learned 1x1 conv could map instead).
            obs3 = obs[:, :, :3, :, :]
            next3 = next_obs[:, :, :3, :, :]

        # Flatten (B, T) into one batch dimension of B*T frames.
        x_obs = obs3.reshape(-1, 3, H, W).contiguous()
        x_nobs = next3.reshape(-1, 3, H, W).contiguous()

        # Resize to the resolution the VAE was trained on.
        x_obs = f.interpolate(x_obs, size=RED_SIZE, mode='bilinear', align_corners=True)
        x_nobs = f.interpolate(x_nobs, size=RED_SIZE, mode='bilinear', align_corners=True)

        # VAE forward returns (reconstruction, mu, logsigma); keep the latents.
        (obs_mu, obs_logsigma), (next_mu, next_logsigma) = [
            vae(x)[1:] for x in (x_obs, x_nobs)]

        # Sanity check: one latent per input frame. Explicit raise instead of
        # `assert` so the check survives `python -O`.
        if obs_mu.shape[0] != B * T:
            raise RuntimeError(
                f"VAE batch mismatch: got {obs_mu.shape[0]} vs expected {B * T}")

        # Reparameterize per frame, then fold back to (B, T, LSIZE).
        latent_obs = (obs_mu + obs_logsigma.exp()
                      * torch.randn_like(obs_mu)).reshape(B, T, LSIZE)
        latent_next = (next_mu + next_logsigma.exp()
                       * torch.randn_like(next_mu)).reshape(B, T, LSIZE)

    return latent_obs, latent_next

def get_loss(latent_obs, action, reward, terminal,
latent_next_obs, include_reward: bool):
Expand Down Expand Up @@ -163,7 +179,8 @@ def data_pass(epoch, train, include_reward): # pylint: disable=too-many-locals
pbar = tqdm(total=len(loader.dataset), desc="Epoch {}".format(epoch))
for i, data in enumerate(loader):
obs, action, reward, terminal, next_obs = [arr.to(device) for arr in data]

print("-------------------------------------",obs.shape)
print(next_obs.shape)
# transform obs
latent_obs, latent_next_obs = to_latent(obs, next_obs)

Expand Down
94 changes: 63 additions & 31 deletions trainvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,18 @@
import torch.utils.data
from torch import optim
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import v2
from torchvision.utils import save_image

from models.vae import VAE

from utils.misc import save_checkpoint
from utils.misc import LSIZE, RED_SIZE
## WARNING : THIS SHOULD BE REPLACE WITH PYTORCH 0.5
from utils.learning import EarlyStopping
from utils.learning import ReduceLROnPlateau
from torch.optim.lr_scheduler import ReduceLROnPlateau
from data.loaders import RolloutObservationDataset
from torch.cuda.amp import autocast, GradScaler


parser = argparse.ArgumentParser(description='VAE Trainer')
parser.add_argument('--batch-size', type=int, default=32, metavar='N',
Expand All @@ -33,26 +34,26 @@

args = parser.parse_args()
cuda = torch.cuda.is_available()

device = torch.device("cuda" if cuda else "cpu")

torch.manual_seed(123)
# Fix numeric divergence due to bug in Cudnn
torch.backends.cudnn.benchmark = True
# if args.consistent_input_sizes:
# torch.backends.cudnn.benchmark = True

device = torch.device("cuda" if cuda else "cpu")

scaler = GradScaler() if cuda else None

transform_train = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize((RED_SIZE, RED_SIZE)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transform_train = v2.Compose([
v2.ToPILImage(),
v2.Resize((RED_SIZE, RED_SIZE)),
v2.RandomHorizontalFlip(),
v2.ToTensor(),
])

transform_test = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize((RED_SIZE, RED_SIZE)),
transforms.ToTensor(),
transform_test = v2.Compose([
v2.ToPILImage(),
v2.Resize((RED_SIZE, RED_SIZE)),
v2.ToTensor(),
])

dataset_train = RolloutObservationDataset('datasets/carracing',
Expand All @@ -70,37 +71,68 @@
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
earlystopping = EarlyStopping('min', patience=30)

def loss_function(recon_x, x, mu, logvar, beta=1.0, return_parts=False):
    """beta-VAE loss: summed reconstruction MSE plus beta-weighted KL to N(0, I).

    :args recon_x: reconstruction, same shape as x
    :args x: target batch
    :args mu: posterior means
    :args logvar: posterior log-variances.
        NOTE(review): the original loss treated the encoder output as
        log(sigma) (its KL used 2 * logsigma) -- confirm the VAE now emits
        log-variance, otherwise `std` below should be exp(logvar), not
        exp(0.5 * logvar).
    :args beta: KL weight (see the schedules below)
    :args return_parts: if True, return (total, recon_loss, kl_loss)
    """
    # Summed (not averaged) MSE == Gaussian decoder NLL up to a constant.
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')
    logvar = torch.clamp(logvar, -20.0, 20.0)  # keep exp() finite
    std = torch.exp(0.5 * logvar)
    posterior = torch.distributions.Normal(mu, std + 1e-8)
    prior = torch.distributions.Normal(0.0, 1.0)
    kl_loss = torch.distributions.kl.kl_divergence(posterior, prior).sum()
    total = recon_loss + beta * kl_loss
    return (total, recon_loss, kl_loss) if return_parts else total


# ---- KL weight schedules ----
def linear_beta(epoch: int, warmup_epochs: int) -> float:
    """beta ramps 0 -> 1 over warmup_epochs, then stays at 1."""
    if warmup_epochs <= 0:
        return 1.0
    return min(1.0, float(epoch) / float(warmup_epochs))


def cyclical_beta(epoch: int, cycle_len: int, floor: float = 0.0, ceil: float = 1.0) -> float:
    """beta cycles floor -> ceil every cycle_len epochs (sawtooth reset)."""
    if cycle_len <= 0:
        return ceil
    t = (epoch % cycle_len) / float(cycle_len)
    return floor + (ceil - floor) * t


def train(epoch):
""" One training epoch """
model.train()
dataset_train.load_next_buffer()
train_loss = 0
beta = linear_beta(epoch, warmup_epochs=50) # warm-up for 50 epochs
for batch_idx, data in enumerate(train_loader):
data = data.to(device)
optimizer.zero_grad()
recon_batch, mu, logvar = model(data)
loss = loss_function(recon_batch, data, mu, logvar)
loss, recon_loss, kl_loss = loss_function(recon_batch, data, mu, logvar, beta=beta, return_parts=True)
loss.backward()
train_loss += loss.item()
optimizer.step()
if batch_idx % 20 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader),
loss.item() / len(data)))
lr = optimizer.param_groups[0]['lr']
bs = len(data)
print(f"Train Epoch: {epoch} [{batch_idx*bs}/{len(train_loader.dataset)} "
f"({100.*batch_idx/len(train_loader):.0f}%)] "
f"loss/img={loss.item()/bs:.4f} recon/img={recon_loss.item()/bs:.4f} "
f"kl/img={kl_loss.item()/bs:.4f} beta={beta:.3f} lr={lr:.2e}")

print('====> Epoch: {} Average loss: {:.4f}'.format(
epoch, train_loss / len(train_loader.dataset)))
Expand Down
Loading