Skip to content

Commit 5e7f81d

Browse files
committed
Remove obsolete backbone kwargs, update config
1 parent c0c6a46 commit 5e7f81d

5 files changed

Lines changed: 14 additions & 87 deletions

File tree

configs/processor/flow_matching.yaml

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,19 @@ learning_rate: 0.0001
66
flow_ode_steps: 4
77
n_steps_output: null
88
n_channels_out: null
9-
backbone_kwargs:
9+
backbone:
10+
_target_: autocast.nn.unet.TemporalUNetBackbone
11+
in_channels: null
12+
out_channels: null
13+
cond_channels: null
1014
mod_features: 256
11-
hid_channels: [32, 64, 128]
12-
hid_blocks: [2, 2, 2]
15+
hid_channels:
16+
_target_: builtins.tuple
17+
_args_:
18+
- [32, 64, 128]
19+
hid_blocks:
20+
_target_: builtins.tuple
21+
_args_:
22+
- [2, 2, 2]
1323
spatial: 2
1424
periodic: false

src/autocast/eval/processor.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
)
3131
from autocast.models.encoder_decoder import EncoderDecoder
3232
from autocast.models.encoder_processor_decoder import EncoderProcessorDecoder
33-
from autocast.processors.utils import initialize_flow_matching_backbone
3433
from autocast.train.configuration import (
3534
compose_training_config,
3635
configure_module_dimensions,
@@ -340,21 +339,12 @@ def _load_state_dict(checkpoint_path: Path) -> OrderedDict[str, torch.Tensor]:
340339
def _load_model(
341340
cfg: DictConfig,
342341
checkpoint_path: Path,
343-
n_steps_input: int,
344-
channel_count: int,
345-
spatial_shape: Sequence[int],
346342
) -> EncoderProcessorDecoder:
347343
model_cfg = cfg.get("model") or cfg
348344
encoder = instantiate(model_cfg.encoder)
349345
decoder = instantiate(model_cfg.decoder)
350346
encoder_decoder = EncoderDecoder(encoder=encoder, decoder=decoder)
351347
processor = instantiate(model_cfg.processor)
352-
initialize_flow_matching_backbone(
353-
processor,
354-
n_steps_input,
355-
channel_count,
356-
spatial_shape,
357-
)
358348
epd_cfg = model_cfg
359349
learning_rate = epd_cfg.get("learning_rate", 1e-3)
360350
training_cfg = cfg.get("training") or {}
@@ -486,7 +476,7 @@ def main() -> None:
486476
channel_count,
487477
inferred_n_steps_input,
488478
inferred_n_steps_output,
489-
input_shape,
479+
_,
490480
_,
491481
) = prepare_datamodule(cfg)
492482

@@ -500,13 +490,9 @@ def main() -> None:
500490

501491
metrics = _build_metrics(args.metrics or ("mse", "rmse"))
502492

503-
spatial_shape = tuple(input_shape[2:-1])
504493
model = _load_model(
505494
cfg,
506495
args.checkpoint,
507-
inferred_n_steps_input,
508-
channel_count,
509-
spatial_shape,
510496
)
511497
device = _resolve_device(args.device)
512498
model.to(device)

src/autocast/processors/flow_matching.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@
33
from typing import Any
44

55
import torch
6-
from omegaconf import DictConfig, OmegaConf
76
from torch import nn
87

9-
from autocast.nn.unet import TemporalUNetBackbone
108
from autocast.processors.base import Processor
119
from autocast.types import EncodedBatch, Tensor
1210

@@ -29,7 +27,6 @@ def __init__(
2927
flow_ode_steps: int = 1,
3028
n_steps_output: int = 4,
3129
n_channels_out: int = 1,
32-
backbone_kwargs: dict[str, Any] | DictConfig | None = None,
3330
**kwargs: Any,
3431
) -> None:
3532
# Store core hyperparameters and optional prebuilt backbone.
@@ -47,37 +44,6 @@ def __init__(
4744
self.flow_ode_steps = max(flow_ode_steps, 1)
4845
self.n_steps_output = n_steps_output
4946
self.n_channels_out = n_channels_out
50-
processed_kwargs: dict[str, Any] = {}
51-
raw_kwargs: Any | None
52-
if isinstance(backbone_kwargs, DictConfig):
53-
raw_kwargs = OmegaConf.to_container(backbone_kwargs, resolve=True)
54-
else:
55-
raw_kwargs = backbone_kwargs
56-
if isinstance(raw_kwargs, dict):
57-
processed_kwargs = {str(k): v for k, v in raw_kwargs.items()}
58-
for field in ("hid_channels", "hid_blocks"):
59-
value = processed_kwargs.get(field)
60-
if isinstance(value, list):
61-
processed_kwargs[field] = tuple(value)
62-
self.backbone_kwargs = processed_kwargs
63-
64-
def _maybe_build_backbone(self, x: Tensor) -> None:
65-
"""Lazily build TemporalUNetBackbone when no model is provided."""
66-
if self.flow_matching_model is not None:
67-
return
68-
69-
# Infer in/out channels from configured temporal/channel counts.
70-
t_in = x.shape[1]
71-
c_in = x.shape[-1]
72-
t_out = self.n_steps_output
73-
c_out = self.n_channels_out
74-
75-
self.flow_matching_model = TemporalUNetBackbone(
76-
in_channels=t_out * c_out,
77-
out_channels=t_out * c_out,
78-
cond_channels=t_in * c_in,
79-
**self.backbone_kwargs,
80-
)
8147

8248
def flow_field(self, z: Tensor, t: Tensor, x: Tensor) -> Tensor:
8349
"""Flow matching vector field.
@@ -94,7 +60,6 @@ def flow_field(self, z: Tensor, t: Tensor, x: Tensor) -> Tensor:
9460
-------
9561
Time derivative of output states with the same shape as `z`.
9662
"""
97-
self._maybe_build_backbone(x)
9863
assert self.flow_matching_model is not None # for type checkers
9964
return self.flow_matching_model(z, t, x)
10065

@@ -146,8 +111,6 @@ def loss(self, batch: EncodedBatch) -> Tensor:
146111
)
147112
raise ValueError(msg)
148113

149-
self._maybe_build_backbone(input_states)
150-
151114
batch_size = target_states.shape[0]
152115

153116
z0 = torch.randn_like(target_states, requires_grad=True)

src/autocast/processors/utils.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1 @@
11
from __future__ import annotations
2-
3-
from collections.abc import Sequence
4-
5-
import torch
6-
7-
8-
def initialize_flow_matching_backbone(
9-
processor,
10-
n_steps_input: int | None,
11-
channel_count: int | None,
12-
spatial_shape: Sequence[int] | None,
13-
) -> None:
14-
"""Instantiate the flow-matching backbone before optimizers are created."""
15-
builder = getattr(processor, "_maybe_build_backbone", None)
16-
has_model = getattr(processor, "flow_matching_model", None) is not None
17-
if builder is None or has_model:
18-
return
19-
if n_steps_input is None or channel_count is None:
20-
return
21-
spatial = tuple(spatial_shape) if spatial_shape is not None else ()
22-
dummy = torch.zeros(
23-
(1, n_steps_input, *spatial, channel_count), dtype=torch.float32
24-
)
25-
builder(dummy)

src/autocast/train/processor.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from autocast.models.ae import AE, AELoss
1818
from autocast.models.encoder_decoder import EncoderDecoder
1919
from autocast.models.encoder_processor_decoder import EncoderProcessorDecoder
20-
from autocast.processors.utils import initialize_flow_matching_backbone
2120
from autocast.train.configuration import (
2221
compose_training_config,
2322
configure_module_dimensions,
@@ -245,13 +244,6 @@ def main() -> None: # noqa: PLR0915
245244
_freeze_module(encoder_decoder.decoder)
246245

247246
processor = instantiate(model_cfg.processor)
248-
spatial_shape = tuple(input_shape[2:-1])
249-
initialize_flow_matching_backbone(
250-
processor,
251-
inferred_n_steps_input,
252-
channel_count,
253-
spatial_shape,
254-
)
255247

256248
epd_cfg = model_cfg
257249
learning_rate = epd_cfg.get("learning_rate", 1e-3)

0 commit comments

Comments
 (0)