33from typing import Any
44
55import torch
6- from omegaconf import DictConfig , OmegaConf
76from torch import nn
87
9- from autocast .nn .unet import TemporalUNetBackbone
108from autocast .processors .base import Processor
119from autocast .types import EncodedBatch , Tensor
1210
@@ -17,8 +15,7 @@ class FlowMatchingProcessor(Processor):
1715 def __init__ (
1816 self ,
1917 * ,
20- flow_matching_model : nn .Module | None = None ,
21- backbone : nn .Module | None = None ,
18+ backbone : nn .Module ,
2219 schedule : Any | None = None ,
2320 denoiser_type : str | None = None ,
2421 stride : int = 1 ,
@@ -29,7 +26,6 @@ def __init__(
2926 flow_ode_steps : int = 1 ,
3027 n_steps_output : int = 4 ,
3128 n_channels_out : int = 1 ,
32- backbone_kwargs : dict [str , Any ] | DictConfig | None = None ,
3329 ** kwargs : Any ,
3430 ) -> None :
3531 # Store core hyperparameters and optional prebuilt backbone.
@@ -40,44 +36,13 @@ def __init__(
4036 loss_func = loss_func or nn .MSELoss (),
4137 ** kwargs ,
4238 )
43- self .flow_matching_model = flow_matching_model or backbone
39+ self .flow_matching_model = backbone
4440 self .schedule = schedule # accepted for API compatibility
4541 self .denoiser_type = denoiser_type
4642 self .learning_rate = learning_rate
4743 self .flow_ode_steps = max (flow_ode_steps , 1 )
4844 self .n_steps_output = n_steps_output
4945 self .n_channels_out = n_channels_out
50- processed_kwargs : dict [str , Any ] = {}
51- raw_kwargs : Any | None
52- if isinstance (backbone_kwargs , DictConfig ):
53- raw_kwargs = OmegaConf .to_container (backbone_kwargs , resolve = True )
54- else :
55- raw_kwargs = backbone_kwargs
56- if isinstance (raw_kwargs , dict ):
57- processed_kwargs = {str (k ): v for k , v in raw_kwargs .items ()}
58- for field in ("hid_channels" , "hid_blocks" ):
59- value = processed_kwargs .get (field )
60- if isinstance (value , list ):
61- processed_kwargs [field ] = tuple (value )
62- self .backbone_kwargs = processed_kwargs
63-
64- def _maybe_build_backbone (self , x : Tensor ) -> None :
65- """Lazily build TemporalUNetBackbone when no model is provided."""
66- if self .flow_matching_model is not None :
67- return
68-
69- # Infer in/out channels from configured temporal/channel counts.
70- t_in = x .shape [1 ]
71- c_in = x .shape [- 1 ]
72- t_out = self .n_steps_output
73- c_out = self .n_channels_out
74-
75- self .flow_matching_model = TemporalUNetBackbone (
76- in_channels = t_out * c_out ,
77- out_channels = t_out * c_out ,
78- cond_channels = t_in * c_in ,
79- ** self .backbone_kwargs ,
80- )
8146
8247 def flow_field (self , z : Tensor , t : Tensor , x : Tensor ) -> Tensor :
8348 """Flow matching vector field.
@@ -94,8 +59,6 @@ def flow_field(self, z: Tensor, t: Tensor, x: Tensor) -> Tensor:
9459 -------
9560 Time derivative of output states with the same shape as `z`.
9661 """
97- self ._maybe_build_backbone (x )
98- assert self .flow_matching_model is not None # for type checkers
9962 return self .flow_matching_model (z , t , x )
10063
10164 def forward (self , x : Tensor ) -> Tensor :
@@ -146,8 +109,6 @@ def loss(self, batch: EncodedBatch) -> Tensor:
146109 )
147110 raise ValueError (msg )
148111
149- self ._maybe_build_backbone (input_states )
150-
151112 batch_size = target_states .shape [0 ]
152113
153114 z0 = torch .randn_like (target_states , requires_grad = True )
0 commit comments