Skip to content

Commit 8e00b7e

Browse files
committed
Fix rollout
1 parent e8905e5 commit 8e00b7e

3 files changed

Lines changed: 42 additions & 11 deletions

File tree

src/auto_cast/models/encoder_processor_decoder.py

Lines changed: 24 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -117,10 +117,20 @@ def _true_slice(
117117

118118
def _advance_batch(self, batch: Batch, next_inputs: Tensor, stride: int) -> Batch:
119119
"""Shift the input/output windows forward by `stride` using `next_inputs`."""
120-
next_inputs = torch.cat(
121-
[batch.input_fields[:, stride:, ...], next_inputs[:, :stride, ...]],
122-
dim=1,
123-
)
120+
# Get the original number of input time steps to maintain consistency
121+
n_steps_input = batch.input_fields.shape[1]
122+
123+
# Concatenate remaining inputs with new predictions
124+
remaining_inputs = batch.input_fields[:, stride:, ...]
125+
new_predictions = next_inputs[:, :stride, ...]
126+
127+
if remaining_inputs.shape[1] == 0:
128+
# No remaining inputs, use most recent n_steps_input from predictions
129+
combined = new_predictions[:, -n_steps_input:, ...]
130+
else:
131+
combined = torch.cat([remaining_inputs, new_predictions], dim=1)
132+
# Keep only the most recent n_steps_input time steps
133+
combined = combined[:, -n_steps_input:, ...]
124134

125135
next_outputs = (
126136
batch.output_fields[:, stride:, ...]
@@ -129,7 +139,7 @@ def _advance_batch(self, batch: Batch, next_inputs: Tensor, stride: int) -> Batc
129139
)
130140

131141
return Batch(
132-
input_fields=next_inputs,
142+
input_fields=combined,
133143
output_fields=next_outputs,
134144
constant_scalars=batch.constant_scalars,
135145
constant_fields=batch.constant_fields,
@@ -165,8 +175,17 @@ def __init__(
165175

166176
def training_step(self, batch: Batch, batch_idx: int) -> Tensor: # noqa: ARG002
167177
encoded_batch = self.encoder_decoder.encoder.encode_batch(batch)
178+
# TODO: ensure no grads propagate through encoder_decoder
168179
loss = self.processor.loss(encoded_batch)
169180
self.log(
170181
"train_loss", loss, prog_bar=True, batch_size=batch.input_fields.shape[0]
171182
)
172183
return loss
184+
185+
def validation_step(self, batch, batch_idx: int): # noqa: ARG002
186+
encoded_batch = self.encoder_decoder.encoder.encode_batch(batch)
187+
loss = self.processor.loss(encoded_batch)
188+
self.log(
189+
"valid_loss", loss, prog_bar=True, batch_size=batch.input_fields.shape[0]
190+
)
191+
return loss

src/auto_cast/models/processor.py

Lines changed: 16 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -91,17 +91,28 @@ def _true_slice(self, batch: EncodedBatch, stride: int) -> tuple[Tensor, bool]:
9191
def _advance_batch(
9292
self, batch: EncodedBatch, next_inputs: Tensor, stride: int
9393
) -> EncodedBatch:
94-
next_inputs = torch.cat(
95-
[batch.encoded_inputs[:, stride:, ...], next_inputs[:, :stride, ...]],
96-
dim=1,
97-
)
94+
# Get the original number of input time steps to maintain consistency
95+
n_steps_input = batch.encoded_inputs.shape[1]
96+
97+
# Concatenate remaining inputs with new predictions
98+
remaining_inputs = batch.encoded_inputs[:, stride:, ...]
99+
new_predictions = next_inputs[:, :stride, ...]
100+
101+
if remaining_inputs.shape[1] == 0:
102+
# No remaining inputs, use most recent n_steps_input from predictions
103+
combined = new_predictions[:, -n_steps_input:, ...]
104+
else:
105+
combined = torch.cat([remaining_inputs, new_predictions], dim=1)
106+
# Keep only the most recent n_steps_input time steps
107+
combined = combined[:, -n_steps_input:, ...]
108+
98109
next_outputs = (
99110
batch.encoded_output_fields[:, stride:, ...]
100111
if batch.encoded_output_fields.shape[1] > stride
101112
else batch.encoded_output_fields[:, 0:0, ...]
102113
)
103114
return EncodedBatch(
104-
encoded_inputs=next_inputs,
115+
encoded_inputs=combined,
105116
encoded_output_fields=next_outputs,
106117
encoded_info=batch.encoded_info,
107118
)

src/auto_cast/processors/rollout.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,8 @@
66
import torch
77
from einops import rearrange
88

9-
from auto_cast.types import BatchT, RolloutOutput, Tensor
9+
from auto_cast.types import RolloutOutput, Tensor
10+
from auto_cast.types.batch import BatchT
1011

1112

1213
class RolloutMixin(ABC, Generic[BatchT]):

0 commit comments

Comments
 (0)