@@ -89,14 +89,75 @@ def configure_optimizers(self):
8989 return torch .optim .Adam (self .parameters (), lr = self .learning_rate )
9090
9191 def rollout (self , batch : Batch ) -> RolloutOutput :
92- """Rollout over multiple time steps."""
93- pred_outs , gt_outs = [], []
92+ """Rollout over multiple time steps with optional teacher forcing."""
93+ pred_outs : list [Tensor ] = []
94+ gt_outs : list [Tensor ] = []
95+
96+ # Initialize the current batch for rollout
97+ current_batch = Batch (
98+ input_fields = batch .input_fields .clone (),
99+ output_fields = batch .output_fields .clone (),
100+ constant_scalars = (
101+ batch .constant_scalars .clone ()
102+ if batch .constant_scalars is not None
103+ else None
104+ ),
105+ constant_fields = (
106+ batch .constant_fields .clone ()
107+ if batch .constant_fields is not None
108+ else None
109+ ),
110+ )
111+
112+ # Rollout loop with teacher forcing
94113 for _ in range (0 , self .max_rollout_steps , self .stride ):
95- x = self .encoder_decoder .encoder (batch )
96- pred_outs .append (self .processor (x ))
97- # TODO: combining teacher forcing logic
98- gt_outs .append (batch .output_fields ) # This assumes we have output fields
99- return torch .stack (pred_outs ), torch .stack (gt_outs )
114+ output = self (current_batch )
115+ pred_outs .append (output )
116+
117+ if current_batch .output_fields .shape [1 ] >= self .stride :
118+ gt_slice = current_batch .output_fields [:, : self .stride , ...]
119+ gt_outs .append (gt_slice )
120+ else :
121+ gt_slice = current_batch .output_fields
122+
123+ # Simple teacher forcing logic with Bernoulli sampling
124+ rand_val = torch .rand (1 , device = output .device ).item ()
125+ teacher_force = (
126+ gt_slice .numel () > 0 and rand_val < self .teacher_forcing_ratio
127+ )
128+ feedback = gt_slice if teacher_force else output .detach ()
129+
130+ if feedback .shape [1 ] < self .stride :
131+ break
132+
133+ current_batch = self ._advance_batch (current_batch , feedback , self .stride )
134+
135+ # Stack predictions and ground truths and return
136+ predictions = torch .stack (pred_outs )
137+ if gt_outs :
138+ return predictions , torch .stack (gt_outs )
139+ return predictions , None
140+
141+ @staticmethod
142+ def _advance_batch (batch : Batch , feedback : Tensor , stride : int ) -> Batch :
143+ """Shift the input/output windows forward by `stride` using `feedback`."""
144+ next_inputs = torch .cat (
145+ [batch .input_fields [:, stride :, ...], feedback [:, :stride , ...]],
146+ dim = 1 ,
147+ )
148+
149+ next_outputs = (
150+ batch .output_fields [:, stride :, ...]
151+ if batch .output_fields .shape [1 ] > stride
152+ else batch .output_fields [:, 0 :0 , ...] # Empty tensor with correct shape
153+ )
154+
155+ return Batch (
156+ input_fields = next_inputs ,
157+ output_fields = next_outputs ,
158+ constant_scalars = batch .constant_scalars ,
159+ constant_fields = batch .constant_fields ,
160+ )
100161
101162
102163# # TODO: consider if separate rollout class would be better
0 commit comments