Skip to content

Commit 2fc9b65

Browse files
committed
add the loss for Q continue
1 parent 89b22f8 commit 2fc9b65

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

HRM/hrm.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,8 @@ def forward(
162162
):
163163

164164
act_losses = []
165+
prev_q_continue = None
166+
165167
return_loss = exists(labels)
166168

167169
max_reasoning_steps = default(max_reasoning_steps, self.max_reasoning_steps)
@@ -230,6 +232,10 @@ def evaluate_pred():
230232
for index in range(max_reasoning_steps * self.lowest_steps_per_reasoning_step - 1):
231233

232234
iteration = index + 1
235+
is_reasoning_step_boundary = divisible_by(index, self.lowest_steps_per_reasoning_step)
236+
num_reasoning_steps = index // self.lowest_steps_per_reasoning_step
237+
238+
# evaluate all networks depending on their period
233239

234240
for network_index, (network, hidden_combine, evaluate_network_at) in enumerate(zip(self.networks, self.hidden_combiners, self.evaluate_networks_at)):
235241

@@ -240,10 +246,7 @@ def evaluate_pred():
240246

241247
# adaptive computation time
242248

243-
is_reasoning_step_boundary = divisible_by(index, self.lowest_steps_per_reasoning_step)
244-
num_reasoning_steps = index // self.lowest_steps_per_reasoning_step
245-
246-
if is_reasoning_step_boundary and num_reasoning_steps >= min_reasoning_steps:
249+
if is_reasoning_step_boundary:
247250

248251
highest_hidden = hiddens[self.num_networks - 1]
249252

@@ -259,6 +262,11 @@ def evaluate_pred():
259262

260263
act_losses.append(halt_target_loss)
261264

265+
if exists(prev_q_continue):
266+
continue_target_loss = F.binary_cross_entropy(prev_q_continue, torch.maximum(q_continue, q_halt))
267+
268+
act_losses.append(continue_target_loss)
269+
262270
# 1-step gradient learning
263271

264272
for network_index, (network, hidden_combine) in enumerate(zip(self.networks, self.hidden_combiners)):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "HRM-pytorch"
3-
version = "0.0.6"
3+
version = "0.0.7"
44
description = "The proposal from a Singaporean AGI company"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

0 commit comments

Comments
 (0)