Commit 270c067

when doing ACT, it makes sense to have prediction loss at each reasoning step
1 parent 3850007 commit 270c067

File tree

3 files changed: 61 additions, 14 deletions

  HRM/hrm_with_act.py
  pyproject.toml
  tests/test_hrm.py

HRM/hrm_with_act.py

Lines changed: 27 additions & 13 deletions

@@ -160,6 +160,7 @@ def forward(
         hiddens: tuple[Tensor, ...] | None = None,
         *,
         labels = None,
+        compute_loss_across_reasoning_steps = False,
         detach_hiddens = True,
         one_step_grad = True,
         max_reasoning_steps = None,
@@ -265,33 +266,46 @@ def forward(
 
             pred_q_halt_continues.append(q_halt_continue)
 
-        # to output prediction, using the hiddens from the highest hierarchy
+        # if labels passed in, cross entropy loss
 
-        highest_hidden = hiddens[self.num_networks - 1]
+        hiddens = list(hiddens.values())
 
-        logits = self.to_logits(highest_hidden)
+        if not return_loss:
+            # to output prediction, using the hiddens from the highest hierarchy
 
-        # if labels passed in, cross entropy loss
+            highest_hidden = hiddens[self.num_networks - 1]
 
-        hiddens = hiddens.values()
+            logits = self.to_logits(highest_hidden)
 
-        if not return_loss:
             return logits, hiddens
 
         # get main loss
 
-        main_pred_loss = F.cross_entropy(
-            rearrange(logits, 'b n c -> b c n'),
-            labels,
-            ignore_index = self.ignore_index
-        )
+        highest_hiddens = stack(highest_hiddens) # (l b n d)
+
+        if not compute_loss_across_reasoning_steps:
+            logits = self.to_logits(highest_hiddens[-1])
+
+            main_pred_loss = F.cross_entropy(
+                rearrange(logits, 'b n c -> b c n'),
+                labels,
+                ignore_index = self.ignore_index
+            )
+
+        else:
+            all_logits = self.to_logits(highest_hiddens)
+            num_layers = all_logits.shape[0]
+
+            main_pred_loss = F.cross_entropy(
+                rearrange(all_logits, 'l b n c -> b c n l'),
+                repeat(labels, 'b n -> b n l', l = num_layers),
+                ignore_index = self.ignore_index
+            )
 
         # compute the act loss
 
        q_halts, q_continues = rearrange(pred_q_halt_continues, 'l halt_continue b -> halt_continue l b')
 
-        highest_hiddens = stack(highest_hiddens) # (l b n d)
-
         # q halt loss is simply on whether the prediction is correct or not
 
         with torch.no_grad():
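
The new compute_loss_across_reasoning_steps branch above applies the cross entropy to the logits from every reasoning step, broadcasting the labels over the step dimension, rather than supervising only the final step. A minimal standalone sketch of that einops pattern (the shapes and ignore_index value here are illustrative, not taken from the module):

import torch
import torch.nn.functional as F
from einops import rearrange, repeat

# illustrative shapes: 4 reasoning steps (l), batch 3 (b), sequence 8 (n), 256 classes (c)
l, b, n, c = 4, 3, 8, 256

all_logits = torch.randn(l, b, n, c, requires_grad = True)  # stand-in for to_logits(highest_hiddens)
labels = torch.randint(0, c, (b, n))                        # one target token per position

# move classes to dim 1 as F.cross_entropy expects, and treat the step axis as an
# extra spatial dim, so every reasoning step is scored against the same labels

loss = F.cross_entropy(
    rearrange(all_logits, 'l b n c -> b c n l'),
    repeat(labels, 'b n -> b n l', l = l),
    ignore_index = -1  # the module uses self.ignore_index; -1 is only a placeholder
)

loss.backward()

Averaged this way, every intermediate reasoning step receives a direct learning signal, which is the point of this commit.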

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "HRM-pytorch"
-version = "0.0.14"
+version = "0.0.15"
 description = "The proposal from a Singaporean AGI company"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_hrm.py

Lines changed: 33 additions & 0 deletions

@@ -52,3 +52,36 @@ def test_hrm():
     # after much training
 
     pred = hrm(seq, reasoning_steps = 5)
+
+@pytest.mark.parametrize('compute_loss_across_reasoning_steps', (False, True))
+def test_hrm_with_act(
+    compute_loss_across_reasoning_steps
+):
+    from HRM.hrm_with_act import HRM
+
+    hrm = HRM(
+        networks = [
+            dict(
+                dim = 32,
+                depth = 2,
+                attn_dim_head = 8,
+                heads = 1,
+                use_rmsnorm = True,
+                rotary_pos_emb = True,
+                pre_norm = False
+            )
+        ],
+        num_tokens = 256,
+        dim = 32,
+        max_reasoning_steps = 10
+    )
+
+    seq = torch.randint(0, 256, (3, 1024))
+    labels = torch.randint(0, 256, (3, 1024))
+
+    loss, *_ = hrm(seq, labels = labels)
+    loss.backward()
+
+    # after much training
+
+    pred = hrm(seq, max_reasoning_steps = 5, compute_loss_across_reasoning_steps = compute_loss_across_reasoning_steps)
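
As a usage sketch (assuming the hrm, seq, and labels defined in the test above), training with the new flag would pair it with labels so the prediction loss covers every reasoning step; prediction afterwards is unchanged:

# minimal sketch, reusing `hrm`, `seq` and `labels` from the test above
loss, *_ = hrm(seq, labels = labels, compute_loss_across_reasoning_steps = True)
loss.backward()

pred = hrm(seq, max_reasoning_steps = 5)  # inference, as in the existing test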
