Skip to content

Commit 158ce70

Browse files
committed
stray from their pseudocode for something cleaner
1 parent 5cec9b0 commit 158ce70

File tree

1 file changed

+25
-40
lines changed

1 file changed

+25
-40
lines changed

HRM/hrm.py

Lines changed: 25 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -157,67 +157,52 @@ def forward(
157157

158158
hiddens_dict = {index: hidden for index, hidden in enumerate(hiddens)}
159159

160-
# network as they proposed - following figure 4
161-
162-
def evaluate_network_(
163-
network: Module,
164-
hidden_combine: Module,
165-
network_index
166-
):
167-
168-
all_hiddens = (
169-
tokens,
170-
*hiddens_dict.values()
171-
)
172-
173-
# combine with mean pool for now
174-
175-
combined_input = hidden_combine(all_hiddens, network_index)
176-
177-
# forward
178-
179-
next_hidden = network(combined_input)
180-
181-
# store hiddens at appropriate hierarchy, low to highest
182-
183-
hiddens_dict[network_index] = next_hidden
160+
# calculate total steps
184161

185-
def evaluate_pred():
186-
# prediction is done from the hiddens of highest hierarchy
187-
188-
highest_hidden = hiddens_dict[self.num_networks - 1]
162+
total_low_steps = reasoning_steps * self.lowest_steps_per_reasoning_step
189163

190-
return self.to_pred(highest_hidden)
164+
# network as they proposed - following figure 4
191165

192-
# maybe 1-step
166+
for index in range(total_low_steps):
193167

194-
context = torch.no_grad if one_step_grad else nullcontext
168+
iteration = index + 1
195169

196-
total_low_steps = reasoning_steps * self.lowest_steps_per_reasoning_step
170+
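# maybe 1-step gradient learning - only the final step should carry gradients, all earlier steps run under no_grad
# NOTE(review): the condition below picks no_grad when `is_last_step` is true, which looks inverted relative to the removed code (likely wants `not is_last_step`) - confirm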
197171

198-
with context():
199-
for index in range(total_low_steps - 1): # -1 to omit last step for the proposed 1-step grad learning
172+
is_last_step = index == (total_low_steps - 1)
200173

201-
iteration = index + 1
174+
context = torch.no_grad if one_step_grad and is_last_step else nullcontext
202175

176+
with context():
203177
# evaluate all networks depending on their period
204178

205179
for network_index, (network, hidden_combine, evaluate_network_at) in enumerate(zip(self.networks, self.hidden_combiners, self.evaluate_networks_at)):
206180

207181
if not divisible_by(iteration, evaluate_network_at):
208182
continue
209183

210-
evaluate_network_(network, hidden_combine, network_index)
184+
all_hiddens = (
185+
tokens,
186+
*hiddens_dict.values()
187+
)
211188

212-
# 1-step gradient learning
189+
# combine with concat project
213190

214-
for network_index, (network, hidden_combine) in enumerate(zip(self.networks, self.hidden_combiners)):
191+
combined_input = hidden_combine(all_hiddens, network_index)
215192

216-
evaluate_network_(network, hidden_combine, network_index)
193+
# forward
194+
195+
next_hidden = network(combined_input)
196+
197+
# store hiddens at appropriate hierarchy, low to highest
198+
199+
hiddens_dict[network_index] = next_hidden
217200

218201
# to output prediction, using the hiddens from the highest hierarchy
219202

220-
pred = evaluate_pred()
203+
highest_hidden = hiddens_dict[self.num_networks - 1]
204+
205+
pred = self.to_pred(highest_hidden)
221206

222207
# if labels passed in, cross entropy loss
223208

0 commit comments

Comments
 (0)