Skip to content

Commit d53d411

Browse files
committed
Add ambient rollout re-encode invariant test
Pin that EncoderProcessorDecoder.rollout invokes the encoder at least once per rollout step by counting encode calls via a counting PermuteConcat subclass. This guards eval.mode=ambient against a silent regression where a refactor collapses the rollout into a latent-only loop: in that case ambient and latent eval would report identical numbers, so the ablation (the whole reason the mode exists) would be meaningless.
1 parent d5abf7e commit d53d411

1 file changed

Lines changed: 83 additions & 0 deletions

File tree

tests/models/test_encoder_processor_decoder.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,3 +265,86 @@ def test_encoder_processor_decoder_rollout_handles_short_trajectory(
265265
# Ground truth only for windows where data was available
266266
assert gts is not None
267267
assert gts.shape == (batch_size, expected_gt_windows * n_steps_output, 32, 32, 1)
268+
269+
270+
class CountingPermuteConcat(PermuteConcat):
    """``PermuteConcat`` variant that records how often ``encode`` runs.

    The running total is kept in ``encode_calls`` and starts at zero.
    """

    def __init__(
        self, in_channels: int, n_steps_input: int, with_constants: bool = False
    ) -> None:
        """Forward every argument to ``PermuteConcat`` and zero the counter."""
        parent_kwargs = {
            "in_channels": in_channels,
            "n_steps_input": n_steps_input,
            "with_constants": with_constants,
        }
        super().__init__(**parent_kwargs)
        self.encode_calls = 0

    def encode(self, batch: Batch) -> Tensor:  # type: ignore[override]
        """Increment ``encode_calls`` by one, then defer to the parent ``encode``."""
        self.encode_calls = self.encode_calls + 1
        encoded = super().encode(batch)
        return encoded
286+
287+
288+
def test_encoder_processor_decoder_rollout_re_encodes_each_step(make_toy_batch):
    """Pin that a free-running rollout calls the encoder for every rollout step.

    ``eval.mode=ambient`` depends on this contract: each ambient step decodes
    the prediction and feeds it back through the encoder, which is how
    decode/encode drift builds up. Should a refactor ever turn the rollout
    into a latent-only loop, ambient and latent eval would quietly produce
    identical numbers and any ambient-vs-latent ablation would lose its
    meaning. Counting ``encode`` calls on the encoder catches that.
    """
    # Rollout configuration.
    max_rollout_steps = 3
    stride = 2
    n_steps_input = 2
    n_steps_output = 2

    # Toy data configuration.
    batch_size = 2
    trajectory_length = 20
    t_out = trajectory_length - n_steps_input

    batch = make_toy_batch(batch_size=batch_size, t_in=n_steps_input, t_out=t_out)

    # Channel bookkeeping: the processor sees time steps folded into channels.
    output_channels = batch.output_fields.shape[-1]
    merged_input_channels = output_channels * n_steps_input
    merged_output_channels = output_channels * n_steps_output

    # Assemble the model around the call-counting encoder.
    encoder = CountingPermuteConcat(
        in_channels=output_channels,
        n_steps_input=n_steps_input,
        with_constants=False,
    )
    decoder = ChannelsLast(output_channels=output_channels, time_steps=n_steps_output)
    loss = nn.MSELoss()
    encoder_decoder = EncoderDecoder(encoder=encoder, decoder=decoder, loss_func=loss)
    processor = TinyProcessor(
        in_channels=merged_input_channels, out_channels=merged_output_channels
    )
    model_kwargs = {
        "encoder_decoder": encoder_decoder,
        "processor": processor,
        "loss_func": loss,
        "optimizer_config": get_optimizer_config(),
        "stride": stride,
        "max_rollout_steps": max_rollout_steps,
    }
    model = EncoderProcessorDecoder(**model_kwargs)
    model.eval()

    # Only count the encode calls that happen inside the rollout itself.
    calls_before = encoder.encode_calls
    rollout_kwargs = {
        "stride": stride,
        "max_rollout_steps": max_rollout_steps,
        "free_running_only": True,
    }
    preds, _ = model.rollout(batch, **rollout_kwargs)
    calls_after = encoder.encode_calls
    calls_during = calls_after - calls_before

    assert calls_during >= max_rollout_steps, (
        "Ambient rollout must invoke the encoder at least once per rollout "
        f"step; got {calls_during} encode calls for "
        f"{max_rollout_steps} rollout steps."
    )

    # Sanity-check the leading (batch, time) dimensions of the predictions.
    expected_time_steps = max_rollout_steps * n_steps_output
    assert preds.shape[0] == batch_size
    assert preds.shape[1] == expected_time_steps

0 commit comments

Comments
 (0)