@@ -117,10 +117,20 @@ def _true_slice(
117117
118118 def _advance_batch (self , batch : Batch , next_inputs : Tensor , stride : int ) -> Batch :
119119 """Shift the input/output windows forward by `stride` using `next_inputs`."""
120- next_inputs = torch .cat (
121- [batch .input_fields [:, stride :, ...], next_inputs [:, :stride , ...]],
122- dim = 1 ,
123- )
120+ # Get the original number of input time steps to maintain consistency
121+ n_steps_input = batch .input_fields .shape [1 ]
122+
123+ # Concatenate remaining inputs with new predictions
124+ remaining_inputs = batch .input_fields [:, stride :, ...]
125+ new_predictions = next_inputs [:, :stride , ...]
126+
127+ if remaining_inputs .shape [1 ] == 0 :
128+ # No remaining inputs, use most recent n_steps_input from predictions
129+ combined = new_predictions [:, - n_steps_input :, ...]
130+ else :
131+ combined = torch .cat ([remaining_inputs , new_predictions ], dim = 1 )
132+ # Keep only the most recent n_steps_input time steps
133+ combined = combined [:, - n_steps_input :, ...]
124134
125135 next_outputs = (
126136 batch .output_fields [:, stride :, ...]
@@ -129,7 +139,7 @@ def _advance_batch(self, batch: Batch, next_inputs: Tensor, stride: int) -> Batc
129139 )
130140
131141 return Batch (
132- input_fields = next_inputs ,
142+ input_fields = combined ,
133143 output_fields = next_outputs ,
134144 constant_scalars = batch .constant_scalars ,
135145 constant_fields = batch .constant_fields ,
@@ -165,8 +175,17 @@ def __init__(
165175
def training_step(self, batch: Batch, batch_idx: int) -> Tensor:  # noqa: ARG002
    """Run one training step: encode the batch, score it, and log the loss.

    Args:
        batch: Raw batch to encode with the (frozen) encoder and score
            with the processor.
        batch_idx: Index of the batch within the epoch (unused; kept for
            the Lightning hook signature).

    Returns:
        The scalar training loss produced by ``self.processor.loss``.
    """
    # Encode under no_grad so no gradients propagate back through
    # encoder_decoder — implements the previous inline TODO. Only the
    # processor's loss below participates in backprop.
    with torch.no_grad():
        encoded_batch = self.encoder_decoder.encoder.encode_batch(batch)
    loss = self.processor.loss(encoded_batch)
    # batch_size is passed explicitly so Lightning weights the epoch
    # aggregate correctly for variable-size batches.
    self.log(
        "train_loss", loss, prog_bar=True, batch_size=batch.input_fields.shape[0]
    )
    return loss
184+
def validation_step(self, batch, batch_idx: int):  # noqa: ARG002
    """Compute and log the validation loss for one batch.

    Mirrors the training step: encode the batch, score it with the
    processor, and log the result under ``valid_loss``.
    """
    encoded = self.encoder_decoder.encoder.encode_batch(batch)
    val_loss = self.processor.loss(encoded)
    n_samples = batch.input_fields.shape[0]
    self.log("valid_loss", val_loss, prog_bar=True, batch_size=n_samples)
    return val_loss
0 commit comments