@@ -157,67 +157,52 @@ def forward(
157157
158158 hiddens_dict = {index : hidden for index , hidden in enumerate (hiddens )}
159159
160- # network as they proposed - following figure 4
161-
162- def evaluate_network_ (
163- network : Module ,
164- hidden_combine : Module ,
165- network_index
166- ):
167-
168- all_hiddens = (
169- tokens ,
170- * hiddens_dict .values ()
171- )
172-
173- # combine with mean pool for now
174-
175- combined_input = hidden_combine (all_hiddens , network_index )
176-
177- # forward
178-
179- next_hidden = network (combined_input )
180-
181- # store hiddens at appropriate hierarchy, low to highest
182-
183- hiddens_dict [network_index ] = next_hidden
160+ # calculate total steps
184161
185- def evaluate_pred ():
186- # prediction is done from the hiddens of highest hierarchy
187-
188- highest_hidden = hiddens_dict [self .num_networks - 1 ]
162+ total_low_steps = reasoning_steps * self .lowest_steps_per_reasoning_step
189163
190- return self . to_pred ( highest_hidden )
164+ # network as they proposed - following figure 4
191165
192- # maybe 1-step
166+ for index in range ( total_low_steps ):
193167
194- context = torch . no_grad if one_step_grad else nullcontext
168+ iteration = index + 1
195169
196- total_low_steps = reasoning_steps * self . lowest_steps_per_reasoning_step
170+ # maybe 1-step gradient learning
197171
198- with context ():
199- for index in range (total_low_steps - 1 ): # -1 to omit last step for the proposed 1-step grad learning
172+ is_last_step = index == (total_low_steps - 1 )
200173
201- iteration = index + 1
174+ context = torch . no_grad if one_step_grad and is_last_step else nullcontext
202175
176+ with context ():
203177 # evaluate all networks depending on their period
204178
205179 for network_index , (network , hidden_combine , evaluate_network_at ) in enumerate (zip (self .networks , self .hidden_combiners , self .evaluate_networks_at )):
206180
207181 if not divisible_by (iteration , evaluate_network_at ):
208182 continue
209183
210- evaluate_network_ (network , hidden_combine , network_index )
184+ all_hiddens = (
185+ tokens ,
186+ * hiddens_dict .values ()
187+ )
211188
212- # 1-step gradient learning
189+ # combine with concat project
213190
214- for network_index , ( network , hidden_combine ) in enumerate ( zip ( self . networks , self . hidden_combiners )):
191+ combined_input = hidden_combine ( all_hiddens , network_index )
215192
216- evaluate_network_ (network , hidden_combine , network_index )
193+ # forward
194+
195+ next_hidden = network (combined_input )
196+
197+ # store hiddens at appropriate hierarchy, low to highest
198+
199+ hiddens_dict [network_index ] = next_hidden
217200
218201 # to output prediction, using the hiddens from the highest hierarchy
219202
220- pred = evaluate_pred ()
203+ highest_hidden = hiddens_dict [self .num_networks - 1 ]
204+
205+ pred = self .to_pred (highest_hidden )
221206
222207 # if labels passed in, cross entropy loss
223208
0 commit comments