Skip to content

Commit a472866

Browse files
committed
Updating hydra config further to support flow matching and diffusion
1 parent edd2bcb commit a472866

9 files changed

Lines changed: 126 additions & 35 deletions

File tree

configs/decoder/identity.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
_target_: autocast.decoders.identity.IdentityDecoder

configs/encoder/identity.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
_target_: autocast.encoders.identity.IdentityEncoder
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
defaults:
2+
- /encoder: identity
3+
- /decoder: identity
4+
- /processor: flow_matching
5+
- _self_
6+
7+
learning_rate: 0.001
8+
train_processor_only: false
9+
teacher_forcing_ratio: 0.5
10+
max_rollout_steps: 10
11+
loss_func:
12+
_target_: torch.nn.MSELoss

configs/processor.yaml

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
defaults:
22
- data: reaction_diffusion
3-
- encoder: permute_concat
4-
- decoder: channels_last
5-
- processor: flow_matching
3+
- model: encoder_processor_decoder
64
- trainer: default
75
- logging: wandb
86
- _self_
@@ -14,18 +12,12 @@ output:
1412
save_config: true
1513

1614
training:
17-
n_steps_input: 4
15+
n_steps_input: 1
1816
n_steps_output: 4
1917
stride: 4
2018
autoencoder_checkpoint: null
2119
freeze_autoencoder: false
2220

23-
encoder_processor_decoder:
24-
learning_rate: 0.001
25-
train_processor_only: false
26-
loss_func:
27-
_target_: torch.nn.MSELoss
28-
2921
hydra:
3022
run:
3123
dir: outputs/${experiment_name}/${now:%Y-%m-%d_%H-%M-%S}

src/autocast/eval/processor.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
)
3131
from autocast.models.encoder_decoder import EncoderDecoder
3232
from autocast.models.encoder_processor_decoder import EncoderProcessorDecoder
33+
from autocast.processors.utils import initialize_flow_matching_backbone
3334
from autocast.train.configuration import (
3435
compose_training_config,
3536
configure_module_dimensions,
@@ -330,17 +331,28 @@ def _load_state_dict(checkpoint_path: Path) -> OrderedDict[str, torch.Tensor]:
330331
return state_dict
331332

332333

333-
def _load_model(cfg: DictConfig, checkpoint_path: Path) -> EncoderProcessorDecoder:
334-
encoder = instantiate(cfg.encoder)
335-
decoder = instantiate(cfg.decoder)
334+
def _load_model(
335+
cfg: DictConfig,
336+
checkpoint_path: Path,
337+
n_steps_input: int,
338+
channel_count: int,
339+
spatial_shape: Sequence[int],
340+
) -> EncoderProcessorDecoder:
341+
model_cfg = cfg.get("model") or cfg
342+
encoder = instantiate(model_cfg.encoder)
343+
decoder = instantiate(model_cfg.decoder)
336344
encoder_decoder = EncoderDecoder(encoder=encoder, decoder=decoder)
337-
processor = instantiate(cfg.processor)
338-
epd_cfg = cfg.get("encoder_processor_decoder") or {}
345+
processor = instantiate(model_cfg.processor)
346+
initialize_flow_matching_backbone(
347+
processor,
348+
n_steps_input,
349+
channel_count,
350+
spatial_shape,
351+
)
352+
epd_cfg = model_cfg
339353
learning_rate = epd_cfg.get("learning_rate", 1e-3)
340-
training_cfg = cfg.get("training")
341-
stride = 1
342-
if isinstance(training_cfg, DictConfig):
343-
stride = training_cfg.get("stride", 1)
354+
training_cfg = cfg.get("training") or {}
355+
stride = training_cfg.get("stride", 1)
344356
teacher_forcing_ratio = epd_cfg.get("teacher_forcing_ratio", 0.5)
345357
max_rollout_steps = epd_cfg.get("max_rollout_steps", 10)
346358
loss_cfg = epd_cfg.get("loss_func")
@@ -495,8 +507,8 @@ def main() -> None:
495507
channel_count,
496508
inferred_n_steps_input,
497509
inferred_n_steps_output,
498-
_,
499-
_,
510+
input_shape,
511+
output_shape,
500512
) = prepare_datamodule(cfg)
501513

502514
configure_module_dimensions(
@@ -509,7 +521,14 @@ def main() -> None:
509521

510522
metrics = _build_metrics(args.metrics or ("mse", "rmse"))
511523

512-
model = _load_model(cfg, args.checkpoint)
524+
spatial_shape = tuple(input_shape[2:-1])
525+
model = _load_model(
526+
cfg,
527+
args.checkpoint,
528+
inferred_n_steps_input,
529+
channel_count,
530+
spatial_shape,
531+
)
513532
device = _resolve_device(args.device)
514533
model.to(device)
515534

src/autocast/processors/flow_matching.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Any
44

55
import torch
6+
from omegaconf import DictConfig, OmegaConf
67
from torch import nn
78

89
from autocast.nn.unet import TemporalUNetBackbone
@@ -28,7 +29,7 @@ def __init__(
2829
flow_ode_steps: int = 1,
2930
n_steps_output: int = 4,
3031
n_channels_out: int = 1,
31-
backbone_kwargs: dict[str, Any] | None = None,
32+
backbone_kwargs: dict[str, Any] | DictConfig | None = None,
3233
**kwargs: Any,
3334
) -> None:
3435
# Store core hyperparameters and optional prebuilt backbone.
@@ -46,7 +47,19 @@ def __init__(
4647
self.flow_ode_steps = max(flow_ode_steps, 1)
4748
self.n_steps_output = n_steps_output
4849
self.n_channels_out = n_channels_out
49-
self.backbone_kwargs = backbone_kwargs or {}
50+
processed_kwargs: dict[str, Any] = {}
51+
raw_kwargs: Any | None
52+
if isinstance(backbone_kwargs, DictConfig):
53+
raw_kwargs = OmegaConf.to_container(backbone_kwargs, resolve=True)
54+
else:
55+
raw_kwargs = backbone_kwargs
56+
if isinstance(raw_kwargs, dict):
57+
processed_kwargs = {str(k): v for k, v in raw_kwargs.items()}
58+
for field in ("hid_channels", "hid_blocks"):
59+
value = processed_kwargs.get(field)
60+
if isinstance(value, list):
61+
processed_kwargs[field] = tuple(value)
62+
self.backbone_kwargs = processed_kwargs
5063

5164
def _maybe_build_backbone(self, x: Tensor) -> None:
5265
"""Lazily build TemporalUNetBackbone when no model is provided."""

src/autocast/processors/utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Sequence
4+
5+
import torch
6+
7+
8+
def initialize_flow_matching_backbone(
    processor,
    n_steps_input: int | None,
    channel_count: int | None,
    spatial_shape: Sequence[int] | None,
) -> None:
    """Instantiate the flow-matching backbone before optimizers are created.

    Invokes the processor's ``_maybe_build_backbone`` hook once with a
    zero-filled ``float32`` tensor shaped
    ``(1, n_steps_input, *spatial_shape, channel_count)``.

    The call is best-effort and silently does nothing when:

    * ``processor`` has no callable ``_maybe_build_backbone`` attribute,
    * ``processor.flow_matching_model`` is already set (not ``None``), or
    * ``n_steps_input`` or ``channel_count`` is ``None``.

    Args:
        processor: Object that may expose ``_maybe_build_backbone`` and
            ``flow_matching_model``; any other object is ignored.
        n_steps_input: Size of the second axis of the dummy tensor.
        channel_count: Size of the last axis of the dummy tensor.
        spatial_shape: Sizes of the axes placed between the step axis and the
            channel axis; ``None`` is treated as "no spatial axes".
    """
    build_backbone = getattr(processor, "_maybe_build_backbone", None)
    # A missing or non-callable hook means there is nothing to build; treat
    # both the same way instead of raising TypeError on the call below.
    if not callable(build_backbone):
        return
    # A model is already attached, so there is nothing left to build.
    if getattr(processor, "flow_matching_model", None) is not None:
        return
    # Without both sizes the dummy input shape cannot be formed.
    if n_steps_input is None or channel_count is None:
        return

    spatial = () if spatial_shape is None else tuple(spatial_shape)
    dummy = torch.zeros(
        (1, n_steps_input, *spatial, channel_count),
        dtype=torch.float32,
    )
    # NOTE(review): presumably the hook only reads the tensor's shape to size
    # the backbone — confirm against the processor implementation.
    build_backbone(dummy)

src/autocast/train/configuration.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,16 +65,26 @@ def _maybe_set(cfg_node: DictConfig | None, key: str, value: int) -> None:
6565
cfg_node[key] = value
6666

6767

68+
def _model_cfg(cfg: DictConfig) -> DictConfig:
    """Select the config node that carries the model settings.

    Looks up the ``"model"`` entry of ``cfg``. If that entry is a
    ``DictConfig`` it is returned; in every other case (entry missing or of a
    different type) ``cfg`` itself is returned, so callers can read model keys
    from the result under either layout.
    """
    nested = cfg.get("model")
    return nested if isinstance(nested, DictConfig) else cfg
74+
75+
6876
def configure_module_dimensions(
6977
cfg: DictConfig,
7078
channel_count: int,
7179
n_steps_input: int,
7280
n_steps_output: int,
7381
) -> None:
7482
"""Populate missing dimension hints for encoder/decoder/processor modules."""
75-
_maybe_set(cfg.decoder, "output_channels", channel_count)
76-
_maybe_set(cfg.decoder, "time_steps", n_steps_output)
77-
processor_cfg = cfg.get("processor")
83+
model_cfg = _model_cfg(cfg)
84+
decoder_cfg = model_cfg.get("decoder")
85+
_maybe_set(decoder_cfg, "output_channels", channel_count)
86+
_maybe_set(decoder_cfg, "time_steps", n_steps_output)
87+
processor_cfg = model_cfg.get("processor")
7888
_maybe_set(processor_cfg, "in_channels", channel_count * n_steps_input)
7989
_maybe_set(processor_cfg, "out_channels", channel_count * n_steps_output)
8090
_maybe_set(processor_cfg, "n_steps_output", n_steps_output)
@@ -88,7 +98,7 @@ def configure_module_dimensions(
8898

8999
def normalize_processor_cfg(cfg: DictConfig) -> None:
90100
"""Force config values into the shapes expected by processor classes."""
91-
processor_cfg = cfg.get("processor")
101+
processor_cfg = _model_cfg(cfg).get("processor")
92102
if processor_cfg is None:
93103
return
94104
tuple_fields = ("n_modes",)

src/autocast/train/processor.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from autocast.models.ae import AE, AELoss
1818
from autocast.models.encoder_decoder import EncoderDecoder
1919
from autocast.models.encoder_processor_decoder import EncoderProcessorDecoder
20+
from autocast.processors.utils import initialize_flow_matching_backbone
2021
from autocast.train.configuration import (
2122
compose_training_config,
2223
configure_module_dimensions,
@@ -165,7 +166,7 @@ def instantiate_trainer(
165166
)
166167

167168

168-
def main() -> None:
169+
def main() -> None: # noqa: PLR0915
169170
"""CLI entrypoint for training the processor."""
170171
args = parse_args()
171172
logging.basicConfig(level=logging.INFO)
@@ -175,6 +176,7 @@ def main() -> None:
175176

176177
cfg = compose_training_config(args)
177178
resolved_cfg = OmegaConf.to_container(cfg, resolve=True)
179+
model_cfg = cfg.get("model") or cfg
178180
wandb_logger, watch_cfg = create_wandb_logger(
179181
cfg.get("logging"),
180182
experiment_name=cfg.get("experiment_name", "processor"),
@@ -225,8 +227,8 @@ def main() -> None:
225227
normalize_processor_cfg(cfg)
226228

227229
encoder, decoder = build_autoencoder_modules(
228-
cfg.encoder,
229-
cfg.decoder,
230+
model_cfg.encoder,
231+
model_cfg.decoder,
230232
training_params.autoencoder_checkpoint,
231233
)
232234
encoder_decoder = EncoderDecoder(encoder=encoder, decoder=decoder)
@@ -236,17 +238,33 @@ def main() -> None:
236238
_freeze_module(encoder_decoder.encoder)
237239
_freeze_module(encoder_decoder.decoder)
238240

239-
processor = instantiate(cfg.processor)
241+
processor = instantiate(model_cfg.processor)
242+
spatial_shape = tuple(input_shape[2:-1])
243+
initialize_flow_matching_backbone(
244+
processor,
245+
inferred_n_steps_input,
246+
channel_count,
247+
spatial_shape,
248+
)
240249

241-
epd_cfg = cfg.get("encoder_processor_decoder")
242-
learning_rate = epd_cfg.get("learning_rate", 1e-3) if epd_cfg is not None else 1e-3
243-
loss_cfg = epd_cfg.get("loss_func") if epd_cfg is not None else None
250+
epd_cfg = model_cfg
251+
learning_rate = epd_cfg.get("learning_rate", 1e-3)
252+
train_processor_only = epd_cfg.get("train_processor_only", False)
253+
teacher_forcing_ratio = epd_cfg.get("teacher_forcing_ratio", 0.5)
254+
max_rollout_steps = epd_cfg.get("max_rollout_steps", 10)
255+
loss_cfg = epd_cfg.get("loss_func")
244256
loss_func = instantiate(loss_cfg) if loss_cfg is not None else nn.MSELoss()
257+
training_cfg = cfg.get("training") or {}
258+
stride = training_cfg.get("stride", 1)
245259

246260
model = EncoderProcessorDecoder(
247261
encoder_decoder=encoder_decoder,
248262
processor=processor,
249263
learning_rate=learning_rate,
264+
train_processor_only=train_processor_only,
265+
stride=stride,
266+
teacher_forcing_ratio=teacher_forcing_ratio,
267+
max_rollout_steps=max_rollout_steps,
250268
loss_func=loss_func,
251269
)
252270

0 commit comments

Comments
 (0)