@@ -160,6 +160,7 @@ def forward(
         hiddens: tuple[Tensor, ...] | None = None,
         *,
         labels = None,
+        compute_loss_across_reasoning_steps = False,
         detach_hiddens = True,
         one_step_grad = True,
         max_reasoning_steps = None,
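The new flag lands after the bare `*`, so like the neighboring arguments it is keyword-only. A minimal illustration of that calling convention (the stripped-down `forward` here is a stand-in, not the model's actual method):

```python
def forward(x, *, labels = None, compute_loss_across_reasoning_steps = False):
    # everything after the bare `*` must be passed by keyword
    return x, labels, compute_loss_across_reasoning_steps

forward('tokens', labels = 'labels', compute_loss_across_reasoning_steps = True)  # ok
# forward('tokens', 'labels', True)  # TypeError: takes 1 positional argument but 3 were given
```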
@@ -265,33 +266,46 @@ def forward(
 
         pred_q_halt_continues.append(q_halt_continue)
 
-        # to output prediction, using the hiddens from the highest hierarchy
+        # if labels passed in, cross entropy loss
 
-        highest_hidden = hiddens[self.num_networks - 1]
+        hiddens = list(hiddens.values())
 
-        logits = self.to_logits(highest_hidden)
+        if not return_loss:
+            # to output prediction, using the hiddens from the highest hierarchy
 
-        # if labels passed in, cross entropy loss
+            highest_hidden = hiddens[self.num_networks - 1]
 
-        hiddens = hiddens.values()
+            logits = self.to_logits(highest_hidden)
 
-        if not return_loss:
             return logits, hiddens
 
         # get main loss
 
-        main_pred_loss = F.cross_entropy(
-            rearrange(logits, 'b n c -> b c n'),
-            labels,
-            ignore_index = self.ignore_index
-        )
+        highest_hiddens = stack(highest_hiddens) # (l b n d)
+
+        if not compute_loss_across_reasoning_steps:
+            logits = self.to_logits(highest_hiddens[-1])
+
+            main_pred_loss = F.cross_entropy(
+                rearrange(logits, 'b n c -> b c n'),
+                labels,
+                ignore_index = self.ignore_index
+            )
+
+        else:
+            all_logits = self.to_logits(highest_hiddens)
+            num_layers = all_logits.shape[0]
+
+            main_pred_loss = F.cross_entropy(
+                rearrange(all_logits, 'l b n c -> b c n l'),
+                repeat(labels, 'b n -> b n l', l = num_layers),
+                ignore_index = self.ignore_index
+            )
 
         # compute the act loss
 
         q_halts, q_continues = rearrange(pred_q_halt_continues, 'l halt_continue b -> halt_continue l b')
 
-        highest_hiddens = stack(highest_hiddens) # (l b n d)
-
         # q halt loss is simply on whether the prediction is correct or not
 
         with torch.no_grad():
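To make the new `compute_loss_across_reasoning_steps = True` branch concrete, here is a minimal self-contained sketch of the objective it computes: the shared output head is applied to the stacked per-step hiddens, and the labels are repeated so every reasoning step is supervised against the same targets. The shapes, the `Linear` head, and the `ignore_index` value below are stand-in assumptions, not the model's own:

```python
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

l, b, n, d, c = 4, 2, 16, 32, 10            # reasoning steps, batch, seq len, dim, classes

to_logits = torch.nn.Linear(d, c)           # stand-in for the model's shared output head
highest_hiddens = torch.randn(l, b, n, d)   # stacked highest-hierarchy hiddens, one per step
labels = torch.randint(0, c, (b, n))

all_logits = to_logits(highest_hiddens)     # (l, b, n, c)

# move classes to dim 1 and fold the step axis into the extra dims that
# cross entropy averages over, repeating labels once per reasoning step
main_pred_loss = F.cross_entropy(
    rearrange(all_logits, 'l b n c -> b c n l'),
    repeat(labels, 'b n -> b n l', l = l),
    ignore_index = -1                       # stand-in for self.ignore_index
)
```

One smaller change worth noting in the hunk above: `hiddens.values()` returns a `dict_values` view, which is not subscriptable, so now that the conversion happens before the highest-hierarchy lookup, it is wrapped in `list(...)` so that `hiddens[self.num_networks - 1]` still works.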