1717from autocast .models .ae import AE , AELoss
1818from autocast .models .encoder_decoder import EncoderDecoder
1919from autocast .models .encoder_processor_decoder import EncoderProcessorDecoder
20+ from autocast .processors .utils import initialize_flow_matching_backbone
2021from autocast .train .configuration import (
2122 compose_training_config ,
2223 configure_module_dimensions ,
@@ -165,7 +166,7 @@ def instantiate_trainer(
165166 )
166167
167168
168- def main () -> None :
169+ def main () -> None : # noqa: PLR0915
169170 """CLI entrypoint for training the processor."""
170171 args = parse_args ()
171172 logging .basicConfig (level = logging .INFO )
@@ -175,6 +176,7 @@ def main() -> None:
175176
176177 cfg = compose_training_config (args )
177178 resolved_cfg = OmegaConf .to_container (cfg , resolve = True )
179+ model_cfg = cfg .get ("model" ) or cfg
178180 wandb_logger , watch_cfg = create_wandb_logger (
179181 cfg .get ("logging" ),
180182 experiment_name = cfg .get ("experiment_name" , "processor" ),
@@ -225,8 +227,8 @@ def main() -> None:
225227 normalize_processor_cfg (cfg )
226228
227229 encoder , decoder = build_autoencoder_modules (
228- cfg .encoder ,
229- cfg .decoder ,
230+ model_cfg .encoder ,
231+ model_cfg .decoder ,
230232 training_params .autoencoder_checkpoint ,
231233 )
232234 encoder_decoder = EncoderDecoder (encoder = encoder , decoder = decoder )
@@ -236,17 +238,33 @@ def main() -> None:
236238 _freeze_module (encoder_decoder .encoder )
237239 _freeze_module (encoder_decoder .decoder )
238240
239- processor = instantiate (cfg .processor )
241+ processor = instantiate (model_cfg .processor )
242+ spatial_shape = tuple (input_shape [2 :- 1 ])
243+ initialize_flow_matching_backbone (
244+ processor ,
245+ inferred_n_steps_input ,
246+ channel_count ,
247+ spatial_shape ,
248+ )
240249
241- epd_cfg = cfg .get ("encoder_processor_decoder" )
242- learning_rate = epd_cfg .get ("learning_rate" , 1e-3 ) if epd_cfg is not None else 1e-3
243- loss_cfg = epd_cfg .get ("loss_func" ) if epd_cfg is not None else None
250+ epd_cfg = model_cfg
251+ learning_rate = epd_cfg .get ("learning_rate" , 1e-3 )
252+ train_processor_only = epd_cfg .get ("train_processor_only" , False )
253+ teacher_forcing_ratio = epd_cfg .get ("teacher_forcing_ratio" , 0.5 )
254+ max_rollout_steps = epd_cfg .get ("max_rollout_steps" , 10 )
255+ loss_cfg = epd_cfg .get ("loss_func" )
244256 loss_func = instantiate (loss_cfg ) if loss_cfg is not None else nn .MSELoss ()
257+ training_cfg = cfg .get ("training" ) or {}
258+ stride = training_cfg .get ("stride" , 1 )
245259
246260 model = EncoderProcessorDecoder (
247261 encoder_decoder = encoder_decoder ,
248262 processor = processor ,
249263 learning_rate = learning_rate ,
264+ train_processor_only = train_processor_only ,
265+ stride = stride ,
266+ teacher_forcing_ratio = teacher_forcing_ratio ,
267+ max_rollout_steps = max_rollout_steps ,
250268 loss_func = loss_func ,
251269 )
252270
0 commit comments