Skip to content

Commit 48baa5c

Browse files
committed
Update rollout
1 parent d983224 commit 48baa5c

1 file changed

Lines changed: 68 additions & 7 deletions

File tree

src/auto_cast/models/encoder_processor_decoder.py

Lines changed: 68 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,75 @@ def configure_optimizers(self):
8989
return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
9090

9191
def rollout(self, batch: Batch) -> RolloutOutput:
    """Roll the model forward over multiple steps with optional teacher forcing.

    Each iteration calls the model on the current batch, records the
    prediction, and then builds the next batch by feeding back either the
    ground truth (teacher forcing) or the detached prediction. The choice
    is a Bernoulli draw against ``self.teacher_forcing_ratio``.

    Teacher forcing is only applied while at least ``self.stride`` entries
    of ground truth remain along dimension 1 of ``output_fields``. Once the
    ground truth is exhausted (fully or partially) the model's own
    prediction is always fed back, so the rollout length does not depend on
    the random draw.

    Args:
        batch: Starting batch. Its tensors are cloned before use, so the
            caller's batch is not modified by the window shifting done here.

    Returns:
        A tuple ``(predictions, ground_truths)``. ``predictions`` is the
        per-step model outputs combined with ``torch.stack``.
        ``ground_truths`` is the stack of the full-stride ground-truth
        slices that were available, or ``None`` if there were none.
        NOTE(review): ``ground_truths`` can hold fewer steps than
        ``predictions`` when the rollout continues past the available
        ground truth -- confirm callers handle this.
    """
    stride = self.stride
    pred_outs: list[Tensor] = []
    gt_outs: list[Tensor] = []

    # Work on a private copy so the rollout never aliases the caller's tensors.
    current_batch = Batch(
        input_fields=batch.input_fields.clone(),
        output_fields=batch.output_fields.clone(),
        constant_scalars=(
            batch.constant_scalars.clone()
            if batch.constant_scalars is not None
            else None
        ),
        constant_fields=(
            batch.constant_fields.clone()
            if batch.constant_fields is not None
            else None
        ),
    )

    for _ in range(0, self.max_rollout_steps, stride):
        output = self(current_batch)
        pred_outs.append(output)

        # Only a full-stride slice counts as usable ground truth; a shorter
        # remainder can neither be stacked with the others nor fill the
        # next input window.
        gt_slice = None
        remaining_gt = current_batch.output_fields
        if remaining_gt.shape[1] >= stride:
            gt_slice = remaining_gt[:, :stride, ...]
            gt_outs.append(gt_slice)

        # Bernoulli draw deciding whether this step is teacher-forced.
        rand_val = torch.rand(1, device=output.device).item()
        teacher_force = (
            gt_slice is not None
            and gt_slice.numel() > 0
            and rand_val < self.teacher_forcing_ratio
        )
        # The prediction is detached before being fed back, so gradients
        # do not flow through the fed-back input.
        feedback = gt_slice if teacher_force else output.detach()

        # Not enough new entries to advance the window by a full stride.
        if feedback.shape[1] < stride:
            break

        current_batch = self._advance_batch(current_batch, feedback, stride)

    predictions = torch.stack(pred_outs)
    if gt_outs:
        return predictions, torch.stack(gt_outs)
    return predictions, None
140+
141+
@staticmethod
def _advance_batch(batch: Batch, feedback: Tensor, stride: int) -> Batch:
    """Build the batch for the next rollout step.

    The input window drops its first ``stride`` entries along dimension 1
    and gains the first ``stride`` entries of ``feedback`` at the end. The
    output window drops its first ``stride`` entries, becoming empty along
    dimension 1 when it holds no more than ``stride`` entries.

    Args:
        batch: Batch describing the current windows.
        feedback: Tensor whose leading ``stride`` entries along dimension 1
            are appended to the inputs.
        stride: Number of entries to shift both windows by.

    Returns:
        A new ``Batch`` with shifted input/output fields; the constant
        scalars and constant fields are passed through unchanged.
    """
    # Slide the input window: keep the tail, append the fed-back entries.
    kept_inputs = batch.input_fields[:, stride:, ...]
    appended = feedback[:, :stride, ...]
    shifted_inputs = torch.cat([kept_inputs, appended], dim=1)

    # Slide the output window, or leave a zero-length slice once it is used up.
    remaining_targets = batch.output_fields
    if remaining_targets.shape[1] > stride:
        shifted_targets = remaining_targets[:, stride:, ...]
    else:
        shifted_targets = remaining_targets[:, 0:0, ...]

    return Batch(
        input_fields=shifted_inputs,
        output_fields=shifted_targets,
        constant_scalars=batch.constant_scalars,
        constant_fields=batch.constant_fields,
    )
100161

101162

102163
# # TODO: consider if separate rollout class would be better

0 commit comments

Comments
 (0)