Skip to content

Commit 0d21bcb

Browse files
committed
another step
1 parent d1d075b commit 0d21bcb

File tree

3 files changed

+15
-3
lines changed

3 files changed

+15
-3
lines changed

HRM/hrm.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -71,6 +71,7 @@ def __init__(
 71  71       ):
 72  72           super().__init__()
 73  73           attn_layers_klass = Encoder if not causal else Decoder
     74 +         self.causal = causal
 74  75
 75  76           # input
 76  77
@@ -133,9 +134,14 @@ def forward(
133 134           labels = None,
134 135           detach_hiddens = True,
135 136           one_step_grad = True,
136     -         reasoning_steps = None
    137 +         reasoning_steps = None,
    138 +         return_autoreg_loss = False
137 139       ):
138 140
    141 +         if return_autoreg_loss:
    142 +             assert self.causal
    143 +             seq, labels = seq[:, :-1], seq[:, 1:]
    144 +
139 145           return_loss = exists(labels)
140 146
141 147           reasoning_steps = default(reasoning_steps, self.reasoning_steps)

HRM/hrm_with_act.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -89,6 +89,7 @@ def __init__(
 89  89       ):
 90  90           super().__init__()
 91  91           attn_layers_klass = Encoder if not causal else Decoder
     92 +         self.causal = causal
 92  93
 93  94           # input
 94  95
@@ -169,12 +170,17 @@ def forward(
169 170           detach_hiddens = True,
170 171           one_step_grad = True,
171 172           max_reasoning_steps = None,
172     -         adaptive_compute = None
    173 +         adaptive_compute = None,
    174 +         return_autoreg_loss = False
173 175       ):
174 176           batch, device = seq.shape[0], seq.device
175 177
176 178           highest_hidden_index = self.num_networks - 1
177 179
    180 +         if return_autoreg_loss:
    181 +             assert self.causal
    182 +             seq, labels = seq[:, :-1], seq[:, 1:]
    183 +
178 184           return_loss = exists(labels)
179 185
180 186           # ACT related variables

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
  1   1   [project]
  2   2   name = "HRM-pytorch"
  3     - version = "0.1.4"
      3 + version = "0.1.5"
  4   4   description = "The proposal from a Singaporean AGI company"
  5   5   authors = [
  6   6       { name = "Phil Wang", email = "lucidrains@gmail.com" }

0 commit comments

Comments
 (0)