diff --git a/local_hydra/local_experiment/ae_dc_large.yaml b/local_hydra/local_experiment/ae/ae_dc_large.yaml
similarity index 100%
rename from local_hydra/local_experiment/ae_dc_large.yaml
rename to local_hydra/local_experiment/ae/ae_dc_large.yaml
diff --git a/local_hydra/local_experiment/ae/ae_dc_large_128.yaml b/local_hydra/local_experiment/ae/ae_dc_large_128.yaml
new file mode 100644
index 00000000..6ff271eb
--- /dev/null
+++ b/local_hydra/local_experiment/ae/ae_dc_large_128.yaml
@@ -0,0 +1,26 @@
+# @package _global_
+defaults:
+  - /distributed: ddp_4gpu_slurm
+  - override /datamodule: advection_diffusion
+  - override /model: autoencoder_dc_large
+  - override /optimizer: psgd
+  - _self_
+
+experiment_name: ae_dc_large
+
+datamodule:
+  use_normalization: false
+  batch_size: 64
+
+
+logging:
+  wandb:
+    enabled: false
+
+optimizer:
+  learning_rate: 1e-5
+  weight_decay: 0.0
+  scheduler: cosine
+
+trainer:
+  gradient_clip_val: 1.0
diff --git a/local_hydra/local_experiment/ae_dc_large_periodic.yaml b/local_hydra/local_experiment/ae/ae_dc_large_periodic.yaml
similarity index 100%
rename from local_hydra/local_experiment/ae_dc_large_periodic.yaml
rename to local_hydra/local_experiment/ae/ae_dc_large_periodic.yaml
diff --git a/local_hydra/local_experiment/ae/ae_dc_large_periodic_128.yaml b/local_hydra/local_experiment/ae/ae_dc_large_periodic_128.yaml
new file mode 100644
index 00000000..5ecd5046
--- /dev/null
+++ b/local_hydra/local_experiment/ae/ae_dc_large_periodic_128.yaml
@@ -0,0 +1,31 @@
+# @package _global_
+defaults:
+  - /distributed: ddp_4gpu_slurm
+  - override /datamodule: advection_diffusion
+  - override /model: autoencoder_dc_large
+  - override /optimizer: psgd
+  - _self_
+
+experiment_name: ae_dc_large
+
+datamodule:
+  use_normalization: false
+  batch_size: 64
+
+model:
+  encoder:
+    periodic: true
+  decoder:
+    periodic: true
+
+logging:
+  wandb:
+    enabled: false
+
+optimizer:
+  learning_rate: 1e-5
+  weight_decay: 0.0
+  scheduler: cosine
+
+trainer:
+  gradient_clip_val: 1.0
diff --git a/local_hydra/local_experiment/ae/ae_dc_large_periodic_single_gpu.yaml b/local_hydra/local_experiment/ae/ae_dc_large_periodic_single_gpu.yaml
new file mode 100644
index 00000000..188991b1
--- /dev/null
+++ b/local_hydra/local_experiment/ae/ae_dc_large_periodic_single_gpu.yaml
@@ -0,0 +1,33 @@
+# @package _global_
+defaults:
+  - /distributed: single_gpu_slurm
+  - override /datamodule: advection_diffusion
+  - override /model: autoencoder_dc_large
+  - override /optimizer: psgd
+  - _self_
+
+experiment_name: ae_dc_large
+
+datamodule:
+  # use_normalization: false
+  # batch_size: 64
+  use_normalization: true
+  batch_size: 256
+
+model:
+  encoder:
+    periodic: true
+  decoder:
+    periodic: true
+
+logging:
+  wandb:
+    enabled: false
+
+optimizer:
+  learning_rate: 1e-5
+  weight_decay: 0.0
+  scheduler: cosine
+
+trainer:
+  gradient_clip_val: 1.0
\ No newline at end of file
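The experiment files above are Hydra configs whose `# @package _global_` header merges their keys at the root of the composed config. A minimal composition sketch follows; the root config name `train` and the `+local_experiment=...` override group are illustrative assumptions, not the repo's verified entry point:

```python
# Hypothetical composition sketch; config names are assumptions for illustration.
from hydra import compose, initialize

with initialize(config_path="src/autocast/configs", version_base=None):
    cfg = compose(
        config_name="train",
        overrides=["+local_experiment=ae/ae_dc_large_128"],
    )

# Because of `# @package _global_`, the experiment's keys land at the root:
assert cfg.datamodule.batch_size == 64
assert cfg.optimizer.scheduler == "cosine"
```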
diff --git a/local_hydra/local_experiment/epd_crps_vit_azula_4gpu.yaml b/local_hydra/local_experiment/epd_crps_vit_azula_4gpu.yaml
new file mode 100644
index 00000000..b6946415
--- /dev/null
+++ b/local_hydra/local_experiment/epd_crps_vit_azula_4gpu.yaml
@@ -0,0 +1,34 @@
+# @package _global_
+defaults:
+  - /distributed: ddp_4gpu_slurm
+  - override /datamodule: advection_diffusion_multichannel_64_64
+  - override /encoder@model.encoder: permute_concat
+  - override /decoder@model.decoder: channels_last
+  # - override /processor@model.processor: vit_azula_large
+  #- override /input_noise_injector@model.input_noise_injector: concat
+  - _self_
+
+experiment_name: epd_crps_vit_azula
+
+datamodule:
+  use_normalization: true
+  batch_size: 32
+
+logging:
+  wandb:
+    enabled: true
+
+optimizer:
+  learning_rate: 0.0002
+
+model:
+  train_in_latent_space: false
+  n_members: 10
+  encoder:
+    with_constants: true
+  loss_func:
+    _target_: autocast.losses.ensemble.CRPSLoss
+  train_metrics:
+    crps:
+      _target_: autocast.metrics.ensemble.CRPS
+
diff --git a/local_hydra/local_experiment/epd_crps_vit_azula_afcrps.yaml b/local_hydra/local_experiment/epd_crps_vit_azula_afcrps.yaml
new file mode 100644
index 00000000..27daea07
--- /dev/null
+++ b/local_hydra/local_experiment/epd_crps_vit_azula_afcrps.yaml
@@ -0,0 +1,34 @@
+# @package _global_
+defaults:
+  - /distributed: single_gpu_slurm
+  - override /datamodule: advection_diffusion_multichannel_64_64
+  - override /encoder@model.encoder: permute_concat
+  - override /decoder@model.decoder: channels_last
+  # - override /processor@model.processor: vit_azula_large
+  #- override /input_noise_injector@model.input_noise_injector: concat
+  - _self_
+
+experiment_name: epd_crps_vit_azula
+
+datamodule:
+  use_normalization: true
+  batch_size: 32
+
+logging:
+  wandb:
+    enabled: true
+
+optimizer:
+  learning_rate: 0.0002
+
+model:
+  train_in_latent_space: false
+  n_members: 10
+  encoder:
+    with_constants: true
+  loss_func:
+    _target_: autocast.losses.ensemble.AlphaFairCRPSLoss
+  train_metrics:
+    crps:
+      _target_: autocast.metrics.ensemble.CRPS
+
diff --git a/local_hydra/local_experiment/epd_crps_vit_azula_concat.yaml b/local_hydra/local_experiment/epd_crps_vit_azula_concat.yaml
new file mode 100644
index 00000000..60aac172
--- /dev/null
+++ b/local_hydra/local_experiment/epd_crps_vit_azula_concat.yaml
@@ -0,0 +1,34 @@
+# @package _global_
+defaults:
+  - /distributed: single_gpu_slurm
+  - override /datamodule: advection_diffusion_multichannel_64_64
+  - override /encoder@model.encoder: permute_concat
+  - override /decoder@model.decoder: channels_last
+  # - override /processor@model.processor: vit_azula_large
+  - override /input_noise_injector@model.input_noise_injector: concat
+  - _self_
+
+experiment_name: epd_crps_vit_azula
+
+datamodule:
+  use_normalization: true
+  batch_size: 32
+
+logging:
+  wandb:
+    enabled: true
+
+optimizer:
+  learning_rate: 0.0002
+
+model:
+  train_in_latent_space: false
+  n_members: 10
+  encoder:
+    with_constants: true
+  loss_func:
+    _target_: autocast.losses.ensemble.CRPSLoss
+  train_metrics:
+    crps:
+      _target_: autocast.metrics.ensemble.CRPS
+
diff --git a/local_hydra/local_experiment/epd_crps_vit_azula_mae.yaml b/local_hydra/local_experiment/epd_crps_vit_azula_mae.yaml
new file mode 100644
index 00000000..84c29cd0
--- /dev/null
+++ b/local_hydra/local_experiment/epd_crps_vit_azula_mae.yaml
@@ -0,0 +1,34 @@
+# @package _global_
+defaults:
+  - /distributed: single_gpu_slurm
+  - override /datamodule: advection_diffusion_multichannel_64_64
+  - override /encoder@model.encoder: permute_concat
+  - override /decoder@model.decoder: channels_last
+  # - override /processor@model.processor: vit_azula_large
+  #- override /input_noise_injector@model.input_noise_injector: concat
+  - _self_
+
+experiment_name: epd_crps_vit_azula
+
+datamodule:
+  use_normalization: true
+  batch_size: 32
+
+logging:
+  wandb:
+    enabled: true
+
+optimizer:
+  learning_rate: 0.0002
+
+model:
+  train_in_latent_space: false
+  n_members: 10
+  encoder:
+    with_constants: true
+  loss_func:
+    _target_: autocast.losses.ensemble.EnsembleMAELoss
+  train_metrics:
+    crps:
+      _target_: autocast.metrics.ensemble.CRPS
+
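Each variant above selects its loss through a `_target_` block, which Hydra turns into an object at setup time. A small sketch of that mechanism (the class path is taken from the configs above; everything else is generic Hydra API):

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# `instantiate` imports the dotted path in `_target_` and calls it, so this is
# roughly equivalent to `autocast.losses.ensemble.CRPSLoss()`.
cfg = OmegaConf.create({"loss_func": {"_target_": "autocast.losses.ensemble.CRPSLoss"}})
loss_fn = instantiate(cfg.loss_func)
```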
diff --git a/local_hydra/local_experiment/epd_crps_vit_azula_n_noise_1024.yaml b/local_hydra/local_experiment/epd_crps_vit_azula_n_noise_1024.yaml
new file mode 100644
index 00000000..2f77d7e7
--- /dev/null
+++ b/local_hydra/local_experiment/epd_crps_vit_azula_n_noise_1024.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+defaults:
+  - /distributed: single_gpu_slurm
+  - override /datamodule: advection_diffusion_multichannel_64_64
+  - override /encoder@model.encoder: permute_concat
+  - override /decoder@model.decoder: channels_last
+  # - override /processor@model.processor: vit_azula_large
+  #- override /input_noise_injector@model.input_noise_injector: concat
+  - _self_
+
+experiment_name: epd_crps_vit_azula
+
+datamodule:
+  use_normalization: true
+  batch_size: 32
+
+logging:
+  wandb:
+    enabled: true
+
+optimizer:
+  learning_rate: 0.0002
+
+model:
+  train_in_latent_space: false
+  n_members: 10
+  encoder:
+    with_constants: true
+  processor:
+    n_noise_channels: 1024
+  loss_func:
+    _target_: autocast.losses.ensemble.CRPSLoss
+  train_metrics:
+    crps:
+      _target_: autocast.metrics.ensemble.CRPS
+
diff --git a/pyproject.toml b/pyproject.toml
index 91ad079c..1b5d0d37 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -116,13 +116,16 @@ explicit = true
 [tool.uv.sources]
 autoemulate = { git = "https://github.com/alan-turing-institute/autoemulate.git" }
 torch = [
-  { index = "pytorch-cu126", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+  { index = "pytorch-cu126", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
 ]
 torchvision = [
-  { index = "pytorch-cu126", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+  { index = "pytorch-cu126", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
 ]
 
 [tool.pytest.ini_options]
+testpaths = ["tests"]
+pythonpath = ["tests"]
+norecursedirs = ["outputs", "slurm_scripts"]
 filterwarnings = [
     # Ignore Lightning warnings that are expected/benign in test environment
     "ignore:You are trying to `self.log\\(\\)` but the `self.trainer` reference is not registered:UserWarning",
diff --git a/src/autocast/configs/datamodule/conditioned_navier_stokes_128.yaml b/src/autocast/configs/datamodule/conditioned_navier_stokes_128.yaml
new file mode 100644
index 00000000..cf075649
--- /dev/null
+++ b/src/autocast/configs/datamodule/conditioned_navier_stokes_128.yaml
@@ -0,0 +1,10 @@
+_target_: autocast.data.datamodule.SpatioTemporalDataModule
+data_path: "${oc.env:AUTOCAST_DATASETS,./datasets}/128x128/conditioned_navier_stokes_2d_b106e4d"
+batch_size: 16
+n_steps_input: 1
+n_steps_output: 4
+stride: 1
+verbose: false
+use_normalization: false
+normalization_path: ${.data_path}/stats.yml
+num_workers: 0
diff --git a/src/autocast/configs/datamodule/gpe_laser_only_wake.yaml b/src/autocast/configs/datamodule/gpe_laser_only_wake.yaml
index ed629a85..2427ed1a 100644
--- a/src/autocast/configs/datamodule/gpe_laser_only_wake.yaml
+++ b/src/autocast/configs/datamodule/gpe_laser_only_wake.yaml
@@ -1,5 +1,5 @@
 _target_: autocast.data.datamodule.SpatioTemporalDataModule
-data_path: "${oc.env:AUTOCAST_DATASETS,./datasets}/gpe/laser_only_wake_9be0bfb"
+data_path: "${oc.env:AUTOCAST_DATASETS,./datasets}/gpe/laser_only_wake_5b51eac"
 batch_size: 16
 n_steps_input: 1
 n_steps_output: 4
diff --git a/src/autocast/configs/datamodule/gpe_laser_only_wake_128.yaml b/src/autocast/configs/datamodule/gpe_laser_only_wake_128.yaml
new file mode 100644
index 00000000..4c1f3f9b
--- /dev/null
+++ b/src/autocast/configs/datamodule/gpe_laser_only_wake_128.yaml
@@ -0,0 +1,10 @@
+_target_: autocast.data.datamodule.SpatioTemporalDataModule
+data_path: "${oc.env:AUTOCAST_DATASETS,./datasets}/128x128/gpe/laser_only_wake_f15bdbb"
+batch_size: 16
+n_steps_input: 1
+n_steps_output: 4
+stride: 1
+verbose: false
+use_normalization: false
+normalization_path: ${.data_path}/stats.yml
+num_workers: 0
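These datamodule configs all resolve `data_path` through OmegaConf's `oc.env` resolver, which falls back to `./datasets` when `AUTOCAST_DATASETS` is unset. A self-contained demonstration:

```python
import os
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {"data_path": "${oc.env:AUTOCAST_DATASETS,./datasets}/128x128/gpe/laser_only_wake_f15bdbb"}
)

os.environ.pop("AUTOCAST_DATASETS", None)
print(cfg.data_path)  # ./datasets/128x128/gpe/laser_only_wake_f15bdbb

os.environ["AUTOCAST_DATASETS"] = "/scratch/datasets"
print(cfg.data_path)  # /scratch/datasets/128x128/gpe/laser_only_wake_f15bdbb
```

The resolver is evaluated on access, so the second read picks up the new environment value.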
diff --git a/src/autocast/configs/datamodule/shallow_water2d_128.yaml b/src/autocast/configs/datamodule/shallow_water2d_128.yaml
new file mode 100644
index 00000000..543bb57f
--- /dev/null
+++ b/src/autocast/configs/datamodule/shallow_water2d_128.yaml
@@ -0,0 +1,10 @@
+_target_: autocast.data.datamodule.SpatioTemporalDataModule
+data_path: "${oc.env:AUTOCAST_DATASETS,./datasets}/128x128/shallow_water2d_433068d"
+batch_size: 16
+n_steps_input: 1
+n_steps_output: 4
+stride: 1
+verbose: false
+use_normalization: false
+normalization_path: ${.data_path}/stats.yml
+num_workers: 0
diff --git a/src/autocast/configs/distributed/single_gpu_slurm.yaml b/src/autocast/configs/distributed/single_gpu_slurm.yaml
index 2b4b5f8d..488c886e 100644
--- a/src/autocast/configs/distributed/single_gpu_slurm.yaml
+++ b/src/autocast/configs/distributed/single_gpu_slurm.yaml
@@ -10,6 +10,13 @@ trainer:
   strategy: auto
   num_nodes: 1
 
+hydra:
+  launcher:
+    gpus_per_node: 1
+    tasks_per_node: 1
+    additional_parameters:
+      ntasks: 1
+
 eval:
   accelerator: auto
   devices: 1
\ No newline at end of file
diff --git a/src/autocast/configs/optimizer/adamw.yaml b/src/autocast/configs/optimizer/adamw.yaml
index bb638514..3309921f 100644
--- a/src/autocast/configs/optimizer/adamw.yaml
+++ b/src/autocast/configs/optimizer/adamw.yaml
@@ -5,4 +5,15 @@ learning_rate: 1e-4
 weight_decay: 0.0
 warmup: 0
 scheduler: "cosine"
+# Optional scheduler controls:
+# scheduler_interval: "epoch"  # "epoch" or "step"
+# cosine_epochs: null  # Used when scheduler_interval="epoch"
+# cosine_steps: null  # Used when scheduler_interval="step"
+# cosine_t_max: 1  # Fallback horizon if trainer values are unavailable
+# Horizon precedence (highest to lowest):
+#   1) cosine_epochs/cosine_steps (based on scheduler_interval)
+#   2) trainer max value (max_epochs or estimated_stepping_batches)
+#   3) cosine_t_max
+# min_lr_ratio: 0.0  # Final LR ratio in [0, 1]
+# scheduler: "cosine_with_restarts"
 grad_clip: 1
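The precedence comments above correspond to `OptimizerMixin._resolve_cosine_horizon` later in this diff. A stand-alone sketch of the same three-tier lookup (mirroring, not importing, the repo code):

```python
def resolve_horizon(cfg: dict, interval: str, trainer_value: int | None) -> int:
    """Mirror of the documented precedence: explicit key > trainer > fallback."""
    explicit_key = "cosine_steps" if interval == "step" else "cosine_epochs"
    if cfg.get(explicit_key) is not None:   # 1) explicit interval key
        return int(cfg[explicit_key])
    if trainer_value:                       # 2) trainer-derived value
        return int(trainer_value)
    return int(cfg.get("cosine_t_max", 1))  # 3) configured fallback

assert resolve_horizon({"cosine_epochs": 100}, "epoch", 50) == 100
assert resolve_horizon({}, "epoch", 50) == 50
assert resolve_horizon({"cosine_t_max": 130}, "step", None) == 130
```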
diff --git a/src/autocast/configs/optimizer/adamw_half.yaml b/src/autocast/configs/optimizer/adamw_half.yaml
new file mode 100644
index 00000000..4569034f
--- /dev/null
+++ b/src/autocast/configs/optimizer/adamw_half.yaml
@@ -0,0 +1,24 @@
+# Compatibility with: https://github.com/francois-rozet/lola/blob/21a4354b327e6e5ee06da5075ba3bd1dd88c61f1/experiments/configs/optim/adamw.yaml
+name: "${.optimizer}_${.learning_rate}_${.scheduler}_half"
+optimizer: "adamw"
+betas: [0.9, 0.99]
+learning_rate: 1e-4
+weight_decay: 0.0
+scheduler: "cosine"
+
+# lr_lambda(t) = (1 + cos(pi * t / epochs)) / 2
+# cold_lr_lambda(t) = min(1, (t + 1) / (warmup + 1)) * lr_lambda(t)
+scheduler_interval: "epoch"
+cosine_epochs: ${trainer.max_epochs}
+
+# Horizon precedence (highest to lowest):
+#   1) Explicit interval key: cosine_epochs (epoch) or cosine_steps (step)
+#   2) Trainer-derived value (max_epochs / estimated_stepping_batches)
+#   3) cosine_t_max fallback
+# In this file, cosine_epochs is set, so cosine_t_max is used only if that
+# interpolation is unavailable.
+cosine_t_max: 130
+
+warmup: 0
+min_lr_ratio: 0.0
+grad_clip: 1
diff --git a/src/autocast/data/datamodule.py b/src/autocast/data/datamodule.py
index 92153b06..8b3c439b 100644
--- a/src/autocast/data/datamodule.py
+++ b/src/autocast/data/datamodule.py
@@ -310,6 +310,7 @@ def train_dataloader(self) -> DataLoader:
             shuffle=True,
             num_workers=self.num_workers,
             collate_fn=collate_batches,
+            pin_memory=True,
         )
 
     def val_dataloader(self) -> DataLoader:
@@ -320,6 +321,7 @@
             shuffle=False,
             num_workers=self.num_workers,
             collate_fn=collate_batches,
+            pin_memory=True,
         )
 
     def rollout_val_dataloader(self, batch_size: int | None = None) -> DataLoader:
@@ -336,6 +338,7 @@
             shuffle=False,
             num_workers=self.num_workers,
             collate_fn=collate_batches,
+            pin_memory=True,
         )
 
     def test_dataloader(self) -> DataLoader:
@@ -346,6 +349,7 @@
             shuffle=False,
             num_workers=self.num_workers,
             collate_fn=collate_batches,
+            pin_memory=True,
         )
 
     def rollout_test_dataloader(self, batch_size: int | None = None) -> DataLoader:
@@ -362,4 +366,5 @@
             shuffle=False,
             num_workers=self.num_workers,
             collate_fn=collate_batches,
+            pin_memory=True,
         )
diff --git a/src/autocast/losses/__init__.py b/src/autocast/losses/__init__.py
index 3ab435e3..1c5a1853 100644
--- a/src/autocast/losses/__init__.py
+++ b/src/autocast/losses/__init__.py
@@ -1,3 +1,13 @@
-from autocast.losses.ensemble import AlphaFairCRPSLoss, CRPSLoss, FairCRPSLoss
+from autocast.losses.ensemble import (
+    AlphaFairCRPSLoss,
+    CRPSLoss,
+    EnsembleMAELoss,
+    FairCRPSLoss,
+)
 
-__all__ = ["AlphaFairCRPSLoss", "CRPSLoss", "FairCRPSLoss"]
+__all__ = [
+    "AlphaFairCRPSLoss",
+    "CRPSLoss",
+    "EnsembleMAELoss",
+    "FairCRPSLoss",
+]
diff --git a/src/autocast/losses/ensemble.py b/src/autocast/losses/ensemble.py
index 3222c620..6a0a16a2 100644
--- a/src/autocast/losses/ensemble.py
+++ b/src/autocast/losses/ensemble.py
@@ -75,3 +75,11 @@ def __init__(self, alpha: float = 0.95, reduction: str = "mean") -> None:
 
     def _compute_score(self, preds: TensorBTSCM, targets: TensorBTSC) -> Tensor:
         return _alpha_fair_crps_score(preds, targets, alpha=self.alpha)
+
+
+class EnsembleMAELoss(EnsembleLoss):
+    """Mean absolute error computed from the ensemble mean forecast."""
+
+    def _compute_score(self, preds: TensorBTSCM, targets: TensorBTSC) -> Tensor:
+        ensemble_mean = preds.mean(dim=-1)
+        return (ensemble_mean - targets).abs()
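In isolation, `EnsembleMAELoss._compute_score` collapses the trailing member dimension to its mean and scores against the target; the reduction is then applied by the `EnsembleLoss` base class (not shown in this diff). A tiny numeric check of the score itself:

```python
import torch

# Two elements, two ensemble members each (member dimension last).
preds = torch.tensor([[0.0, 2.0], [2.0, 4.0]])
targets = torch.tensor([1.0, 2.0])

ensemble_mean = preds.mean(dim=-1)       # tensor([1., 3.])
score = (ensemble_mean - targets).abs()  # tensor([0., 1.])
```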
Got: {interval}" + ) + raise ValueError(msg) + return interval + + def _as_positive_int(self, value: Any, field_name: str) -> int: + """Convert value to a positive integer.""" + int_value = int(value) + if int_value <= 0: + msg = f"{field_name} must be a positive integer. Got: {value}" + raise ValueError(msg) + return int_value + + def _resolve_cosine_horizon(self, cfg: dict[str, Any], interval: str) -> int: + """Resolve the cosine horizon used in the scheduler lambda.""" + explicit_key = "cosine_steps" if interval == "step" else "cosine_epochs" + explicit_value = cfg.get(explicit_key) + if explicit_value is not None: + return self._as_positive_int(explicit_value, explicit_key) + + trainer = getattr(self, "trainer", None) + if interval == "step" and trainer is not None: + estimated_steps = getattr(trainer, "estimated_stepping_batches", None) + if estimated_steps is not None: + try: + estimated_int = int(estimated_steps) + except (TypeError, ValueError): + estimated_int = 0 + if estimated_int > 0: + return estimated_int + + if trainer is not None and trainer.max_epochs is not None: + return self._as_positive_int(trainer.max_epochs, "trainer.max_epochs") + + fallback = cfg.get("cosine_t_max", 1) + return self._as_positive_int(fallback, "cosine_t_max") + def _create_optimizer( self, cfg: DictConfig | dict[str, Any] ) -> torch.optim.Optimizer: @@ -133,15 +176,37 @@ def _create_scheduler( ) -> torch.optim.lr_scheduler.LRScheduler: """Create learning rate scheduler from config.""" scheduler_name = str(cfg.get("scheduler", "")).lower() + scheduler_interval = self._get_scheduler_interval(cfg) + + warmup = int(cfg.get("warmup", 0)) + warmup = max(warmup, 0) + min_lr_ratio = float(cfg.get("min_lr_ratio", 0.0)) + if not 0.0 <= min_lr_ratio <= 1.0: + msg = f"min_lr_ratio must be in [0, 1]. 
Got: {min_lr_ratio}" + raise ValueError(msg) + + if scheduler_name in {"cosine", "cosine_with_restarts"}: + horizon = self._resolve_cosine_horizon(cfg, scheduler_interval) + use_restarts = scheduler_name == "cosine_with_restarts" + + def cosine_lambda(t: int) -> float: + phase_t = t % horizon if use_restarts else t + cosine = 0.5 * ( + 1.0 + math.cos(math.pi * float(phase_t) / float(horizon)) + ) + return min_lr_ratio + (1.0 - min_lr_ratio) * cosine + + if warmup > 0: + + def lr_lambda(t: int) -> float: + warm = min(1.0, float(t + 1) / float(warmup + 1)) + return warm * cosine_lambda(t) + + else: + lr_lambda = cosine_lambda + + return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) - if scheduler_name == "cosine": - max_epochs = 1 - trainer = getattr(self, "trainer", None) - if trainer is not None and trainer.max_epochs is not None: - max_epochs = int(trainer.max_epochs) - return torch.optim.lr_scheduler.CosineAnnealingLR( - optimizer, T_max=max_epochs, eta_min=0 - ) if scheduler_name == "step": step_size = cfg.get("step_size", 30) gamma = cfg.get("gamma", 0.1) @@ -184,6 +249,7 @@ def configure_optimizers(self) -> OptimizerLRScheduler: return optimizer scheduler = self._create_scheduler(optimizer, cfg) + scheduler_interval = self._get_scheduler_interval(cfg) # ReduceLROnPlateau needs special handling if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): @@ -191,8 +257,16 @@ def configure_optimizers(self) -> OptimizerLRScheduler: "optimizer": optimizer, "lr_scheduler": { "scheduler": scheduler, + "interval": "epoch", "monitor": "val_loss", }, } - return {"optimizer": optimizer, "lr_scheduler": scheduler} + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": scheduler_interval, + "frequency": 1, + }, + } diff --git a/src/autocast/nn/base.py b/src/autocast/nn/base.py index f8b6bb05..6d2f2804 100644 --- a/src/autocast/nn/base.py +++ b/src/autocast/nn/base.py @@ -38,6 +38,7 @@ def __init__( # TCN parameters tcn_kernel_size: int = 3, tcn_num_layers: int = 2, + use_precomputed_modulation: bool = False, ): """Initialize Temporal Backbone Base. @@ -68,6 +69,7 @@ def __init__( self.n_steps_output = n_steps_output self.n_steps_input = n_steps_input self.mod_features = mod_features + self.use_precomputed_modulation = use_precomputed_modulation # Validate global conditioning configuration if include_global_cond and ( @@ -78,12 +80,18 @@ def __init__( self.global_cond_channels = global_cond_channels self.include_global_cond = include_global_cond - # Time embedding for diffusion timestep - self.time_embedding = nn.Sequential( - SineEncoding(mod_features), - nn.Linear(mod_features, mod_features), - nn.SiLU(), - nn.Linear(mod_features, mod_features), + # Time embedding for scalar diffusion timesteps. Some models pass + # precomputed modulation vectors directly and should not register + # unused embedding parameters under strict DDP. + self.time_embedding = ( + None + if self.use_precomputed_modulation + else nn.Sequential( + SineEncoding(mod_features), + nn.Linear(mod_features, mod_features), + nn.SiLU(), + nn.Linear(mod_features, mod_features), + ) ) self.global_cond_embedding = ( @@ -223,6 +231,12 @@ def forward( if t.ndim == 2 and t.shape[-1] == self.mod_features: t_emb = t else: + if self.time_embedding is None: + msg = ( + "Expected precomputed modulation vectors with shape " + "(B, mod_features), but received scalar timesteps." 
diff --git a/src/autocast/nn/base.py b/src/autocast/nn/base.py
index f8b6bb05..6d2f2804 100644
--- a/src/autocast/nn/base.py
+++ b/src/autocast/nn/base.py
@@ -38,6 +38,7 @@
         # TCN parameters
         tcn_kernel_size: int = 3,
         tcn_num_layers: int = 2,
+        use_precomputed_modulation: bool = False,
     ):
         """Initialize Temporal Backbone Base.
 
@@ -68,6 +69,7 @@
         self.n_steps_output = n_steps_output
         self.n_steps_input = n_steps_input
         self.mod_features = mod_features
+        self.use_precomputed_modulation = use_precomputed_modulation
 
         # Validate global conditioning configuration
         if include_global_cond and (
@@ -78,12 +80,18 @@
         self.global_cond_channels = global_cond_channels
         self.include_global_cond = include_global_cond
 
-        # Time embedding for diffusion timestep
-        self.time_embedding = nn.Sequential(
-            SineEncoding(mod_features),
-            nn.Linear(mod_features, mod_features),
-            nn.SiLU(),
-            nn.Linear(mod_features, mod_features),
+        # Time embedding for scalar diffusion timesteps. Some models pass
+        # precomputed modulation vectors directly and should not register
+        # unused embedding parameters under strict DDP.
+        self.time_embedding = (
+            None
+            if self.use_precomputed_modulation
+            else nn.Sequential(
+                SineEncoding(mod_features),
+                nn.Linear(mod_features, mod_features),
+                nn.SiLU(),
+                nn.Linear(mod_features, mod_features),
+            )
         )
 
         self.global_cond_embedding = (
@@ -223,6 +231,12 @@ def forward(
         if t.ndim == 2 and t.shape[-1] == self.mod_features:
             t_emb = t
         else:
+            if self.time_embedding is None:
+                msg = (
+                    "Expected precomputed modulation vectors with shape "
+                    "(B, mod_features), but received scalar timesteps."
+                )
+                raise ValueError(msg)
             t_emb = self.time_embedding(t)
 
         # Combine with global conditioning embedding if provided
diff --git a/src/autocast/nn/vit.py b/src/autocast/nn/vit.py
index 4bcac241..189496e2 100644
--- a/src/autocast/nn/vit.py
+++ b/src/autocast/nn/vit.py
@@ -40,6 +40,7 @@
         dropout: float = 0.0,
         ffn_factor: int = 4,
         checkpointing: bool = False,
+        use_precomputed_modulation: bool = False,
     ):
         """Initialize Temporal ViT Backbone.
 
@@ -84,6 +85,7 @@
             temporal_attention_hidden_dim=temporal_attention_hidden_dim,
             tcn_kernel_size=tcn_kernel_size,
             tcn_num_layers=tcn_num_layers,
+            use_precomputed_modulation=use_precomputed_modulation,
         )
 
         self.patch_size = patch_size
diff --git a/src/autocast/processors/azula_vit.py b/src/autocast/processors/azula_vit.py
index 0ec7b9dc..10e037a1 100644
--- a/src/autocast/processors/azula_vit.py
+++ b/src/autocast/processors/azula_vit.py
@@ -57,6 +57,7 @@
             temporal_method=temporal_method,
             temporal_attention_heads=num_heads,
             temporal_attention_hidden_dim=hidden_dim // num_heads,
+            use_precomputed_modulation=True,
        )
 
     def forward(self, x: Tensor, x_noise: Tensor | None = None) -> Tensor:
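With `use_precomputed_modulation=True` (as the Azula ViT processor now sets), the backbone accepts only modulation vectors that already have shape (B, mod_features); scalar timesteps would require the `time_embedding` that is deliberately not registered. A sketch of the two call conventions, using illustrative shapes:

```python
import torch

mod_features = 32

# Accepted when use_precomputed_modulation=True: t.ndim == 2 and
# t.shape[-1] == mod_features, so forward() uses it directly as t_emb.
t_precomputed = torch.randn(8, mod_features)

# Rejected in that mode: a scalar timestep per sample would need
# self.time_embedding, which is None, so forward() raises ValueError.
t_scalar = torch.rand(8)
```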
diff --git a/src/autocast/scripts/eval/encoder_processor_decoder.py b/src/autocast/scripts/eval/encoder_processor_decoder.py
index 73b66836..7f214e93 100644
--- a/src/autocast/scripts/eval/encoder_processor_decoder.py
+++ b/src/autocast/scripts/eval/encoder_processor_decoder.py
@@ -1188,10 +1188,6 @@ def run_evaluation(cfg: DictConfig, work_dir: Path | None = None) -> None:  # no
     log.info("Loading checkpoint from %s", checkpoint_path)
     checkpoint_payload = load_checkpoint_payload(checkpoint_path)
     processor_only = _is_processor_only_checkpoint(checkpoint_payload)
-    log.info(
-        "Checkpoint type: %s",
-        "processor-only" if processor_only else "encoder-processor-decoder",
-    )
 
     # Setup datamodule and resolve config
     datamodule, cfg, stats = setup_datamodule(cfg)
@@ -1226,6 +1222,28 @@
     # Metrics are computed in latent space.
     example_batch = stats.get("example_batch")
+
+    # Stateless encoders/decoders (e.g. PermuteConcat/ChannelsLast) contribute no
+    # `encoder_decoder.*` params, so a full EPD checkpoint can look processor-only.
+    if (
+        processor_only
+        and isinstance(example_batch, Batch)
+        and not cfg.get("autoencoder_checkpoint")
+        and cfg.get("model", {}).get("encoder") is not None
+        and cfg.get("model", {}).get("decoder") is not None
+    ):
+        log.info(
+            "Checkpoint contains no encoder_decoder.* params, but datamodule "
+            "returns raw Batch and model has encoder+decoder config. "
+            "Assuming full EPD checkpoint with stateless encoder/decoder."
+        )
+        processor_only = False
+
+    log.info(
+        "Checkpoint type: %s",
+        "processor-only" if processor_only else "encoder-processor-decoder",
+    )
+
     decode_fn = None  # optional callable: latent tensor → data-space tensor
     decoder_module = None  # keep reference for device placement
diff --git a/src/autocast/scripts/setup.py b/src/autocast/scripts/setup.py
index b703ed00..6542cc29 100644
--- a/src/autocast/scripts/setup.py
+++ b/src/autocast/scripts/setup.py
@@ -461,6 +461,15 @@ def _build_loss_func(model_config: DictConfig) -> nn.Module:
     return instantiate(loss_func_config)
 
 
+def _maybe_add_metric_overrides(
+    kwargs: dict[str, Any], model_config: DictConfig
+) -> None:
+    """Forward explicit metric overrides from model config when present."""
+    for key in ("train_metrics", "val_metrics", "test_metrics"):
+        if key in model_config:
+            kwargs[key] = model_config.get(key)
+
+
 def setup_processor_model(
     config: DictConfig,
     stats: dict,
@@ -500,6 +509,7 @@
         "noise_injector": noise_injector,
         "norm": norm,
     }
+    _maybe_add_metric_overrides(kwargs, model_config)
 
     if is_ensemble:
         kwargs["n_members"] = model_config.get("n_members")
@@ -607,6 +617,7 @@
         "input_noise_injector": noise_injector,
         "norm": norm,
     }
+    _maybe_add_metric_overrides(kwargs, model_config)
 
     if is_ensemble:
         kwargs["n_members"] = model_config.get("n_members")
diff --git a/src/autocast/utils/optimizer.py b/src/autocast/utils/optimizer.py
index 5c465cb2..29f82203 100644
--- a/src/autocast/utils/optimizer.py
+++ b/src/autocast/utils/optimizer.py
@@ -17,7 +17,8 @@ def get_optimizer_config(
     Args:
         learning_rate: Learning rate for the optimizer. Default 1e-4.
         optimizer: Optimizer name ('adam', 'adamw', 'sgd'). Default 'adam'.
-        scheduler: Optional scheduler name ('cosine', 'step', 'plateau').
+        scheduler: Optional scheduler name ('cosine', 'cosine_with_restarts',
+            'step', 'plateau').
             Default None (no scheduler).
         **kwargs: Additional optimizer parameters (e.g., betas, weight_decay,
             step_size, gamma).
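The `_maybe_add_metric_overrides` helper added above only forwards keys that are explicitly present, so configs without overrides keep the model defaults, while an explicit empty list disables metrics (as the new tests below verify). The core behavior, restated outside the repo:

```python
from omegaconf import OmegaConf

model_config = OmegaConf.create({"val_metrics": [], "test_metrics": []})
kwargs = {}
for key in ("train_metrics", "val_metrics", "test_metrics"):
    if key in model_config:
        kwargs[key] = model_config.get(key)

assert "train_metrics" not in kwargs  # absent keys are left to model defaults
assert kwargs["val_metrics"] == []    # explicit empty list is forwarded as-is
```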
diff --git a/tests/losses/__init__.py b/tests/losses/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/losses/test_ensemble.py b/tests/losses/test_ensemble.py
new file mode 100644
index 00000000..991be92d
--- /dev/null
+++ b/tests/losses/test_ensemble.py
@@ -0,0 +1,57 @@
+import pytest
+import torch
+
+from autocast.losses import EnsembleMAELoss
+
+
+def test_ensemble_mae_loss_uses_ensemble_mean():
+    preds = torch.tensor(
+        [
+            [[[[[0.0, 2.0], [2.0, 4.0]]]]],
+        ]
+    )
+    targets = torch.tensor(
+        [
+            [[[[1.0, 2.0]]]],
+        ]
+    )
+
+    loss = EnsembleMAELoss(reduction="none")(preds, targets)
+
+    expected = torch.tensor(
+        [
+            [[[[0.0, 1.0]]]],
+        ]
+    )
+    assert torch.allclose(loss, expected)
+
+
+@pytest.mark.parametrize(
+    ("reduction", "expected"),
+    [("mean", 0.5), ("sum", 1.0)],
+)
+def test_ensemble_mae_loss_reductions(reduction, expected):
+    preds = torch.tensor(
+        [
+            [[[[[0.0, 2.0], [2.0, 4.0]]]]],
+        ]
+    )
+    targets = torch.tensor(
+        [
+            [[[[1.0, 2.0]]]],
+        ]
+    )
+
+    loss = EnsembleMAELoss(reduction=reduction)(preds, targets)
+
+    assert torch.isclose(loss, torch.tensor(expected))
+
+
+def test_ensemble_mae_loss_rejects_targets_with_member_dim():
+    preds = torch.ones((1, 1, 1, 2, 3))
+    targets = torch.ones((1, 1, 1, 2, 3))
+
+    with pytest.raises(
+        ValueError, match="Targets should not have the ensemble dimension"
+    ):
+        EnsembleMAELoss()(preds, targets)
diff --git a/tests/metrics/__init__.py b/tests/metrics/__init__.py
new file mode 100644
index 00000000..e69de29b
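A note on the tensor literals in the tests above: predictions are `TensorBTSCM` with the ensemble member dimension last, while targets are `TensorBTSC` without it. Decoding the first test's shapes (axis naming is inferred from the type aliases, not stated in the diff):

```python
import torch

preds = torch.tensor([[[[[[0.0, 2.0], [2.0, 4.0]]]]]])  # shape (1, 1, 1, 1, 2, 2), members last
targets = torch.tensor([[[[[1.0, 2.0]]]]])              # shape (1, 1, 1, 1, 2), no member dim

member_mean = preds.mean(dim=-1)      # values [1., 3.]
print((member_mean - targets).abs())  # values [0., 1.] == the test's `expected`
```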
diff --git a/tests/models/test_processor.py b/tests/models/test_processor.py
index a76128a7..535b85a2 100644
--- a/tests/models/test_processor.py
+++ b/tests/models/test_processor.py
@@ -9,6 +9,7 @@
 from autocast.models.processor import ProcessorModel
 from autocast.nn.unet import TemporalUNetBackbone
+from autocast.nn.vit import TemporalViTBackbone
 from autocast.processors.diffusion import DiffusionProcessor
 from autocast.processors.flow_matching import FlowMatchingProcessor
 from autocast.types import EncodedBatch
@@ -489,3 +490,35 @@
     assert torch.allclose(output_no_cond, output_with_cond), (
         "Output changed despite global_cond being disabled"
     )
+
+
+def test_temporal_vit_uses_precomputed_modulation_without_embedding_params():
+    backbone = TemporalViTBackbone(
+        in_channels=4,
+        out_channels=4,
+        cond_channels=0,
+        n_steps_output=1,
+        n_steps_input=1,
+        global_cond_channels=None,
+        include_global_cond=False,
+        mod_features=32,
+        hid_channels=32,
+        hid_blocks=1,
+        attention_heads=4,
+        patch_size=4,
+        spatial=2,
+        temporal_method="none",
+        use_precomputed_modulation=True,
+    )
+
+    # In precomputed modulation mode, no trainable time embedding parameters
+    # should be registered, which avoids strict-DDP unused-parameter failures.
+    assert all(
+        not name.startswith("time_embedding") for name, _ in backbone.named_parameters()
+    )
+
+    x = torch.randn(2, 1, 8, 8, 4)
+    t = torch.randn(2, 32)
+    y = backbone(x_t=x, t=t, cond=None, global_cond=None)
+
+    assert y.shape == (2, 1, 8, 8, 4)
diff --git a/tests/scripts/test_training.py b/tests/scripts/test_training.py
index 12ea2c52..5f47c864 100644
--- a/tests/scripts/test_training.py
+++ b/tests/scripts/test_training.py
@@ -132,6 +132,43 @@ def test_processor_config_training_step_smoke(config_dir: str, dummy_datamodule)
     assert loss.ndim == 0
 
 
+def test_processor_metric_overrides_are_forwarded(config_dir: str, dummy_datamodule):
+    processor_cfg = _load_config(config_dir, "processor/flow_matching").processor
+    with open_dict(processor_cfg):
+        processor_cfg.backbone.include_global_cond = False
+        processor_cfg.backbone.global_cond_channels = 0
+
+    encoded_inputs = torch.randn(2, 2, 4, 4, 1)
+    encoded_outputs = torch.randn(2, 2, 4, 4, 1)
+    cfg = OmegaConf.create(
+        {
+            "model": {
+                "processor": processor_cfg,
+                "loss_func": {"_target_": "torch.nn.MSELoss"},
+                "val_metrics": [],
+                "test_metrics": [],
+            },
+            "optimizer": get_optimizer_config(learning_rate=1e-3),
+            "datamodule": {
+                "stride": 1,
+                "n_steps_input": encoded_inputs.shape[1],
+                "n_steps_output": encoded_outputs.shape[1],
+            },
+        }
+    )
+    batch = EncodedBatch(
+        encoded_inputs=encoded_inputs,
+        encoded_output_fields=encoded_outputs,
+        global_cond=None,
+        encoded_info={},
+    )
+    stats = _stats_from_encoded_batch(batch)
+    model = setup_processor_model(cfg, stats, dummy_datamodule)
+
+    assert model.val_metrics is None
+    assert model.test_metrics is None
+
+
 # --- TrainingTimerCallback ---
@@ -270,6 +307,37 @@ def test_epd_config_forward_smoke(config_dir: str, toy_batch: Batch, dummy_datam
     assert output.shape == toy_batch.output_fields.shape
 
 
+def test_epd_metric_overrides_are_forwarded(
+    config_dir: str, toy_batch: Batch, dummy_datamodule
+):
+    model_cfg = _load_config(
+        config_dir,
+        "model/encoder_processor_decoder",
+        overrides=[
+            "encoder@model.encoder=dc",
+            "decoder@model.decoder=dc",
+            "processor@model.processor=flow_matching",
+        ],
+    )
+    cfg = _wrap_model_config(model_cfg)
+    with open_dict(cfg):
+        cfg.optimizer = get_optimizer_config()
+        cfg.datamodule = {
+            "stride": 1,
+            "n_steps_input": toy_batch.input_fields.shape[1],
+            "n_steps_output": toy_batch.output_fields.shape[1],
+        }
+        cfg.model.processor.backbone.include_global_cond = False
+        cfg.model.processor.backbone.global_cond_channels = 0
+        cfg.model.val_metrics = []
+        cfg.model.test_metrics = []
+    stats = _stats_from_batch(toy_batch)
+    model = setup_epd_model(cfg, stats, dummy_datamodule)
+
+    assert model.val_metrics is None
+    assert model.test_metrics is None
+
+
 def test_infer_latent_spatial_resolution_channels_last_with_time():
     class DummyEncoder:
         channel_axis = -1