@@ -265,3 +265,86 @@ def test_encoder_processor_decoder_rollout_handles_short_trajectory(
265265 # Ground truth only for windows where data was available
266266 assert gts is not None
267267 assert gts .shape == (batch_size , expected_gt_windows * n_steps_output , 32 , 32 , 1 )
268+
269+
270+ class CountingPermuteConcat (PermuteConcat ):
271+ """PermuteConcat encoder that tracks how many times ``encode`` is called."""
272+
273+ def __init__ (
274+ self , in_channels : int , n_steps_input : int , with_constants : bool = False
275+ ) -> None :
276+ super ().__init__ (
277+ in_channels = in_channels ,
278+ n_steps_input = n_steps_input ,
279+ with_constants = with_constants ,
280+ )
281+ self .encode_calls = 0
282+
283+ def encode (self , batch : Batch ) -> Tensor : # type: ignore[override]
284+ self .encode_calls += 1
285+ return super ().encode (batch )
286+
287+
288+ def test_encoder_processor_decoder_rollout_re_encodes_each_step (make_toy_batch ):
289+ """Ambient rollout must re-invoke the encoder at every rollout step.
290+
291+ This is the invariant the whole ``eval.mode=ambient`` path rests on: in
292+ ambient rollout each step decodes the prediction and re-encodes it as
293+ the next input, so decode/encode drift accumulates. If a future refactor
294+ ever collapsed this into a latent-only loop, latent and ambient eval
295+ would silently report the same numbers and ambient-vs-latent ablations
296+ would be meaningless. This test pins the contract.
297+ """
298+ max_rollout_steps = 3
299+ n_steps_input = 2
300+ n_steps_output = 2
301+ stride = 2
302+ batch_size = 2
303+ trajectory_length = 20
304+
305+ batch = make_toy_batch (
306+ batch_size = batch_size ,
307+ t_in = n_steps_input ,
308+ t_out = trajectory_length - n_steps_input ,
309+ )
310+ output_channels = batch .output_fields .shape [- 1 ]
311+ merged_input_channels = output_channels * n_steps_input
312+ merged_output_channels = output_channels * n_steps_output
313+
314+ encoder = CountingPermuteConcat (
315+ in_channels = output_channels ,
316+ n_steps_input = n_steps_input ,
317+ with_constants = False ,
318+ )
319+ decoder = ChannelsLast (output_channels = output_channels , time_steps = n_steps_output )
320+ loss = nn .MSELoss ()
321+ encoder_decoder = EncoderDecoder (encoder = encoder , decoder = decoder , loss_func = loss )
322+ processor = TinyProcessor (
323+ in_channels = merged_input_channels , out_channels = merged_output_channels
324+ )
325+ model = EncoderProcessorDecoder (
326+ encoder_decoder = encoder_decoder ,
327+ processor = processor ,
328+ loss_func = loss ,
329+ optimizer_config = get_optimizer_config (),
330+ stride = stride ,
331+ max_rollout_steps = max_rollout_steps ,
332+ )
333+ model .eval ()
334+
335+ calls_before = encoder .encode_calls
336+ preds , _ = model .rollout (
337+ batch ,
338+ stride = stride ,
339+ max_rollout_steps = max_rollout_steps ,
340+ free_running_only = True ,
341+ )
342+ calls_during = encoder .encode_calls - calls_before
343+
344+ assert calls_during >= max_rollout_steps , (
345+ "Ambient rollout must invoke the encoder at least once per rollout "
346+ f"step; got { calls_during } encode calls for "
347+ f"{ max_rollout_steps } rollout steps."
348+ )
349+ assert preds .shape [0 ] == batch_size
350+ assert preds .shape [1 ] == max_rollout_steps * n_steps_output
0 commit comments