@@ -162,6 +162,8 @@ def forward(
162162 ):
163163
164164 act_losses = []
165+ prev_q_continue = None
166+
165167 return_loss = exists (labels )
166168
167169 max_reasoning_steps = default (max_reasoning_steps , self .max_reasoning_steps )
@@ -230,6 +232,10 @@ def evaluate_pred():
230232 for index in range (max_reasoning_steps * self .lowest_steps_per_reasoning_step - 1 ):
231233
232234 iteration = index + 1
235+ is_reasoning_step_boundary = divisible_by (index , self .lowest_steps_per_reasoning_step )
236+ num_reasoning_steps = index // self .lowest_steps_per_reasoning_step
237+
238+ # evaluate all networks depending on their period
233239
234240 for network_index , (network , hidden_combine , evaluate_network_at ) in enumerate (zip (self .networks , self .hidden_combiners , self .evaluate_networks_at )):
235241
@@ -240,10 +246,7 @@ def evaluate_pred():
240246
241247 # adaptive computation time
242248
243- is_reasoning_step_boundary = divisible_by (index , self .lowest_steps_per_reasoning_step )
244- num_reasoning_steps = index // self .lowest_steps_per_reasoning_step
245-
246- if is_reasoning_step_boundary and num_reasoning_steps >= min_reasoning_steps :
249+ if is_reasoning_step_boundary :
247250
248251 highest_hidden = hiddens [self .num_networks - 1 ]
249252
@@ -259,6 +262,11 @@ def evaluate_pred():
259262
260263 act_losses .append (halt_target_loss )
261264
265+ if exists (prev_q_continue ):
266+ continue_target_loss = F .binary_cross_entropy (prev_q_continue , torch .maximum (q_continue , q_halt ))
267+
268+ act_losses .append (continue_target_loss )
269+
262270 # 1-step gradient learning
263271
264272 for network_index , (network , hidden_combine ) in enumerate (zip (self .networks , self .hidden_combiners )):
0 commit comments