Skip to content

Commit 6fe4886

Browse files
authored (author name not captured in extraction)
Merge pull request #101 from alan-turing-institute/77-flow-matching-and-diffusion-config-remove-backbone-kwargs
Remove obsolete backbone kwargs, update config (#77)
2 parents 528d035 + e0526bb commit 6fe4886

9 files changed

Lines changed: 20 additions & 95 deletions

File tree

configs/data/reaction_diffusion.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
data_path: null
2-
use_simulator: true
1+
data_path: ./datasets/reaction_diffusion/
2+
use_simulator: false
33
split:
44
n_train: 4
55
n_valid: 2

configs/processor/flow_matching.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@ learning_rate: 0.0001
66
flow_ode_steps: 4
77
n_steps_output: null
88
n_channels_out: null
9-
backbone_kwargs:
9+
backbone:
10+
_target_: autocast.nn.unet.TemporalUNetBackbone
11+
in_channels: null
12+
out_channels: null
13+
cond_channels: null
1014
mod_features: 256
1115
hid_channels: [32, 64, 128]
1216
hid_blocks: [2, 2, 2]

src/autocast/data/dataset.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,8 @@ def _from_f(self, f):
180180
self.constant_scalars = (
181181
torch.Tensor(f["constant_scalars"][:]).to(self.dtype) # type: ignore # noqa: PGH003
182182
if "constant_scalars" in f
183+
and f["constant_scalars"] is not None
184+
and f["constant_scalars"] != {}
183185
else None
184186
) # [N, C]
185187

@@ -188,7 +190,9 @@ def _from_f(self, f):
188190
torch.Tensor(f["constant_fields"][:]).to( # type: ignore # noqa: PGH003
189191
self.dtype
190192
) # [N, W, H, C]
191-
if "constant_fields" in f and f["constant_fields"] != {}
193+
if "constant_fields" in f
194+
and f["constant_fields"] is not None
195+
and f["constant_fields"] != {}
192196
else None
193197
)
194198

src/autocast/eval/processor.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
)
3131
from autocast.models.encoder_decoder import EncoderDecoder
3232
from autocast.models.encoder_processor_decoder import EncoderProcessorDecoder
33-
from autocast.processors.utils import initialize_flow_matching_backbone
3433
from autocast.train.configuration import (
3534
compose_training_config,
3635
configure_module_dimensions,
@@ -340,21 +339,12 @@ def _load_state_dict(checkpoint_path: Path) -> OrderedDict[str, torch.Tensor]:
340339
def _load_model(
341340
cfg: DictConfig,
342341
checkpoint_path: Path,
343-
n_steps_input: int,
344-
channel_count: int,
345-
spatial_shape: Sequence[int],
346342
) -> EncoderProcessorDecoder:
347343
model_cfg = cfg.get("model") or cfg
348344
encoder = instantiate(model_cfg.encoder)
349345
decoder = instantiate(model_cfg.decoder)
350346
encoder_decoder = EncoderDecoder(encoder=encoder, decoder=decoder)
351347
processor = instantiate(model_cfg.processor)
352-
initialize_flow_matching_backbone(
353-
processor,
354-
n_steps_input,
355-
channel_count,
356-
spatial_shape,
357-
)
358348
epd_cfg = model_cfg
359349
learning_rate = epd_cfg.get("learning_rate", 1e-3)
360350
training_cfg = cfg.get("training") or {}
@@ -486,7 +476,7 @@ def main() -> None:
486476
channel_count,
487477
inferred_n_steps_input,
488478
inferred_n_steps_output,
489-
input_shape,
479+
_,
490480
_,
491481
) = prepare_datamodule(cfg)
492482

@@ -500,13 +490,9 @@ def main() -> None:
500490

501491
metrics = _build_metrics(args.metrics or ("mse", "rmse"))
502492

503-
spatial_shape = tuple(input_shape[2:-1])
504493
model = _load_model(
505494
cfg,
506495
args.checkpoint,
507-
inferred_n_steps_input,
508-
channel_count,
509-
spatial_shape,
510496
)
511497
device = _resolve_device(args.device)
512498
model.to(device)

src/autocast/nn/unet.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from collections.abc import Sequence
2+
13
from azula.nn.embedding import SineEncoding
24
from azula.nn.unet import UNet
35
from einops import rearrange
@@ -15,8 +17,8 @@ def __init__(
1517
out_channels: int = 1,
1618
cond_channels: int = 1,
1719
mod_features: int = 256,
18-
hid_channels: tuple = (32, 64, 128),
19-
hid_blocks: tuple = (2, 2, 2),
20+
hid_channels: Sequence[int] = (32, 64, 128),
21+
hid_blocks: Sequence[int] = (2, 2, 2),
2022
spatial: int = 2,
2123
periodic: bool = False,
2224
):

src/autocast/processors/flow_matching.py

Lines changed: 2 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@
33
from typing import Any
44

55
import torch
6-
from omegaconf import DictConfig, OmegaConf
76
from torch import nn
87

9-
from autocast.nn.unet import TemporalUNetBackbone
108
from autocast.processors.base import Processor
119
from autocast.types import EncodedBatch, Tensor
1210

@@ -17,8 +15,7 @@ class FlowMatchingProcessor(Processor):
1715
def __init__(
1816
self,
1917
*,
20-
flow_matching_model: nn.Module | None = None,
21-
backbone: nn.Module | None = None,
18+
backbone: nn.Module,
2219
schedule: Any | None = None,
2320
denoiser_type: str | None = None,
2421
stride: int = 1,
@@ -29,7 +26,6 @@ def __init__(
2926
flow_ode_steps: int = 1,
3027
n_steps_output: int = 4,
3128
n_channels_out: int = 1,
32-
backbone_kwargs: dict[str, Any] | DictConfig | None = None,
3329
**kwargs: Any,
3430
) -> None:
3531
# Store core hyperparameters and optional prebuilt backbone.
@@ -40,44 +36,13 @@ def __init__(
4036
loss_func=loss_func or nn.MSELoss(),
4137
**kwargs,
4238
)
43-
self.flow_matching_model = flow_matching_model or backbone
39+
self.flow_matching_model = backbone
4440
self.schedule = schedule # accepted for API compatibility
4541
self.denoiser_type = denoiser_type
4642
self.learning_rate = learning_rate
4743
self.flow_ode_steps = max(flow_ode_steps, 1)
4844
self.n_steps_output = n_steps_output
4945
self.n_channels_out = n_channels_out
50-
processed_kwargs: dict[str, Any] = {}
51-
raw_kwargs: Any | None
52-
if isinstance(backbone_kwargs, DictConfig):
53-
raw_kwargs = OmegaConf.to_container(backbone_kwargs, resolve=True)
54-
else:
55-
raw_kwargs = backbone_kwargs
56-
if isinstance(raw_kwargs, dict):
57-
processed_kwargs = {str(k): v for k, v in raw_kwargs.items()}
58-
for field in ("hid_channels", "hid_blocks"):
59-
value = processed_kwargs.get(field)
60-
if isinstance(value, list):
61-
processed_kwargs[field] = tuple(value)
62-
self.backbone_kwargs = processed_kwargs
63-
64-
def _maybe_build_backbone(self, x: Tensor) -> None:
65-
"""Lazily build TemporalUNetBackbone when no model is provided."""
66-
if self.flow_matching_model is not None:
67-
return
68-
69-
# Infer in/out channels from configured temporal/channel counts.
70-
t_in = x.shape[1]
71-
c_in = x.shape[-1]
72-
t_out = self.n_steps_output
73-
c_out = self.n_channels_out
74-
75-
self.flow_matching_model = TemporalUNetBackbone(
76-
in_channels=t_out * c_out,
77-
out_channels=t_out * c_out,
78-
cond_channels=t_in * c_in,
79-
**self.backbone_kwargs,
80-
)
8146

8247
def flow_field(self, z: Tensor, t: Tensor, x: Tensor) -> Tensor:
8348
"""Flow matching vector field.
@@ -94,8 +59,6 @@ def flow_field(self, z: Tensor, t: Tensor, x: Tensor) -> Tensor:
9459
-------
9560
Time derivative of output states with the same shape as `z`.
9661
"""
97-
self._maybe_build_backbone(x)
98-
assert self.flow_matching_model is not None # for type checkers
9962
return self.flow_matching_model(z, t, x)
10063

10164
def forward(self, x: Tensor) -> Tensor:
@@ -146,8 +109,6 @@ def loss(self, batch: EncodedBatch) -> Tensor:
146109
)
147110
raise ValueError(msg)
148111

149-
self._maybe_build_backbone(input_states)
150-
151112
batch_size = target_states.shape[0]
152113

153114
z0 = torch.randn_like(target_states, requires_grad=True)

src/autocast/processors/utils.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1 @@
11
from __future__ import annotations
2-
3-
from collections.abc import Sequence
4-
5-
import torch
6-
7-
8-
def initialize_flow_matching_backbone(
9-
processor,
10-
n_steps_input: int | None,
11-
channel_count: int | None,
12-
spatial_shape: Sequence[int] | None,
13-
) -> None:
14-
"""Instantiate the flow-matching backbone before optimizers are created."""
15-
builder = getattr(processor, "_maybe_build_backbone", None)
16-
has_model = getattr(processor, "flow_matching_model", None) is not None
17-
if builder is None or has_model:
18-
return
19-
if n_steps_input is None or channel_count is None:
20-
return
21-
spatial = tuple(spatial_shape) if spatial_shape is not None else ()
22-
dummy = torch.zeros(
23-
(1, n_steps_input, *spatial, channel_count), dtype=torch.float32
24-
)
25-
builder(dummy)

src/autocast/train/processor.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from autocast.models.ae import AE, AELoss
1818
from autocast.models.encoder_decoder import EncoderDecoder
1919
from autocast.models.encoder_processor_decoder import EncoderProcessorDecoder
20-
from autocast.processors.utils import initialize_flow_matching_backbone
2120
from autocast.train.configuration import (
2221
compose_training_config,
2322
configure_module_dimensions,
@@ -245,13 +244,6 @@ def main() -> None: # noqa: PLR0915
245244
_freeze_module(encoder_decoder.decoder)
246245

247246
processor = instantiate(model_cfg.processor)
248-
spatial_shape = tuple(input_shape[2:-1])
249-
initialize_flow_matching_backbone(
250-
processor,
251-
inferred_n_steps_input,
252-
channel_count,
253-
spatial_shape,
254-
)
255247

256248
epd_cfg = model_cfg
257249
learning_rate = epd_cfg.get("learning_rate", 1e-3)

tests/processors/test_flow_matching.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def test_flow_matching_processor(
106106
encoded_batch = next(iter(encoded_loader))
107107

108108
processor = FlowMatchingProcessor(
109-
flow_matching_model=TemporalUNetBackbone(
109+
backbone=TemporalUNetBackbone(
110110
in_channels=n_steps_output * n_channels_out,
111111
out_channels=n_steps_output * n_channels_out,
112112
cond_channels=n_steps_input * n_channels_in,

0 commit comments

Comments (0)