Skip to content

Commit 5b283c8

Browse files
committed
some fixes, resume with ACT tomorrow
1 parent 16bd329 commit 5b283c8

File tree

3 files changed

+10
-5
lines changed

3 files changed

+10
-5
lines changed

HRM/hrm.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def evaluate_network_(
192 192
193 193   all_hiddens = (
194 194   tokens,
195     - *[hiddens[i] for i in range(self.num_networks)]
    195 + *hiddens.values()
196 196   )
197 197
198 198   # combine with mean pool for now
@@ -239,7 +239,7 @@ def evaluate_network_(
239 239
240 240   q_continue, q_halt = self.to_q_continue_halt(highest_hidden).sigmoid()
241 241
242     - should_continue = q_halt > q_continue
    242 + should_halt = q_halt > q_continue
243 243
244 244   # 1-step gradient learning
245 245

@@ -255,6 +255,8 @@ def evaluate_network_(
255 255
256 256   # if labels passed in, cross entropy loss
257 257
    258 + hiddens = hiddens.values()
    259 +
258 260   if not exists(labels):
259 261   return pred, hiddens
260 262

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1 1   [project]
2 2   name = "HRM-pytorch"
3   - version = "0.0.3"
  3 + version = "0.0.4"
4 4   description = "The proposal from a Singaporean AGI company"
5 5   authors = [
6 6   { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_hrm.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,12 @@ def test_hrm():
42 42   seq = torch.randint(0, 256, (3, 1024))
43 43   labels = torch.randint(0, 256, (3, 1024))
44 44
45    - loss, (logits, hiddens) = hrm(seq, labels = labels)
   45 + loss, (_, hiddens) = hrm(seq, labels = labels)
   46 + loss.backward()
   47 +
   48 + loss, (_, hiddens) = hrm(seq, hiddens = hiddens, labels = labels)
46 49   loss.backward()
47 50
48 51   # after much training
49 52
50    - pred = hrm(seq, reasoning_steps = 5)
   53 + pred = hrm(seq, max_reasoning_steps = 5)

0 commit comments

Comments
 (0)