33from typing import Any
44
55import torch
6- from omegaconf import DictConfig , OmegaConf
76from torch import nn
87
9- from autocast .nn .unet import TemporalUNetBackbone
108from autocast .processors .base import Processor
119from autocast .types import EncodedBatch , Tensor
1210
@@ -29,7 +27,6 @@ def __init__(
2927 flow_ode_steps : int = 1 ,
3028 n_steps_output : int = 4 ,
3129 n_channels_out : int = 1 ,
32- backbone_kwargs : dict [str , Any ] | DictConfig | None = None ,
3330 ** kwargs : Any ,
3431 ) -> None :
3532 # Store core hyperparameters and optional prebuilt backbone.
@@ -47,37 +44,6 @@ def __init__(
4744 self .flow_ode_steps = max (flow_ode_steps , 1 )
4845 self .n_steps_output = n_steps_output
4946 self .n_channels_out = n_channels_out
50- processed_kwargs : dict [str , Any ] = {}
51- raw_kwargs : Any | None
52- if isinstance (backbone_kwargs , DictConfig ):
53- raw_kwargs = OmegaConf .to_container (backbone_kwargs , resolve = True )
54- else :
55- raw_kwargs = backbone_kwargs
56- if isinstance (raw_kwargs , dict ):
57- processed_kwargs = {str (k ): v for k , v in raw_kwargs .items ()}
58- for field in ("hid_channels" , "hid_blocks" ):
59- value = processed_kwargs .get (field )
60- if isinstance (value , list ):
61- processed_kwargs [field ] = tuple (value )
62- self .backbone_kwargs = processed_kwargs
63-
64- def _maybe_build_backbone (self , x : Tensor ) -> None :
65- """Lazily build TemporalUNetBackbone when no model is provided."""
66- if self .flow_matching_model is not None :
67- return
68-
69- # Infer in/out channels from configured temporal/channel counts.
70- t_in = x .shape [1 ]
71- c_in = x .shape [- 1 ]
72- t_out = self .n_steps_output
73- c_out = self .n_channels_out
74-
75- self .flow_matching_model = TemporalUNetBackbone (
76- in_channels = t_out * c_out ,
77- out_channels = t_out * c_out ,
78- cond_channels = t_in * c_in ,
79- ** self .backbone_kwargs ,
80- )
8147
8248 def flow_field (self , z : Tensor , t : Tensor , x : Tensor ) -> Tensor :
8349 """Flow matching vector field.
@@ -94,7 +60,6 @@ def flow_field(self, z: Tensor, t: Tensor, x: Tensor) -> Tensor:
9460 -------
9561 Time derivative of output states with the same shape as `z`.
9662 """
97- self ._maybe_build_backbone (x )
9863 assert self .flow_matching_model is not None # for type checkers
9964 return self .flow_matching_model (z , t , x )
10065
@@ -146,8 +111,6 @@ def loss(self, batch: EncodedBatch) -> Tensor:
146111 )
147112 raise ValueError (msg )
148113
149- self ._maybe_build_backbone (input_states )
150-
151114 batch_size = target_states .shape [0 ]
152115
153116 z0 = torch .randn_like (target_states , requires_grad = True )
0 commit comments