Skip to content

Commit d9c77e4

Browse files
committed
able to override reasoning steps on forward
1 parent 15fe9e1 commit d9c77e4

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

HRM/hrm.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import torch
55
import torch.nn.functional as F
6-
from torch import nn, Tensor, tensor, is_tensor, cat, stack
6+
from torch import Tensor, tensor, is_tensor, cat, stack
77
from torch.nn import Embedding, Linear, Module, ModuleList
88
from torch.utils._pytree import tree_map
99

@@ -38,9 +38,9 @@ def __init__(
3838
super().__init__()
3939
self.num_hiddens_to_concat = num_hiddens_to_concat
4040

41-
self.norms = ModuleList([nn.RMSNorm(dim) for _ in range(num_hiddens_to_concat)])
41+
self.norms = ModuleList([RMSNorm(dim) for _ in range(num_hiddens_to_concat)])
4242

43-
self.to_combined = nn.Linear(dim * self.num_hiddens_to_concat, dim, bias = False)
43+
self.to_combined = Linear(dim * self.num_hiddens_to_concat, dim, bias = False)
4444

4545
def forward(
4646
self,
@@ -132,9 +132,12 @@ def forward(
132132
*,
133133
labels = None,
134134
detach_hiddens = True,
135-
one_step_grad = True
135+
one_step_grad = True,
136+
reasoning_steps = None
136137
):
137138

139+
reasoning_steps = default(reasoning_steps, self.reasoning_steps)
140+
138141
if detach_hiddens:
139142
hiddens = tree_map_tensor(hiddens, lambda t: t.detach())
140143

@@ -184,7 +187,7 @@ def evaluate_network_(
184187
context = torch.no_grad if one_step_grad else nullcontext
185188

186189
with context():
187-
for index in range(self.reasoning_steps * self.lowest_steps_per_reasoning_step - 1):
190+
for index in range(reasoning_steps * self.lowest_steps_per_reasoning_step - 1):
188191
iteration = index + 1
189192

190193
for network_index, (network, hidden_combine, evaluate_network_at) in enumerate(zip(self.networks, self.hidden_combiners, self.evaluate_networks_at)):

tests/test_hrm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def test_hrm():
3030
],
3131
num_tokens = 256,
3232
dim = 512,
33+
reasoning_steps = 3
3334

3435
)
3536

@@ -38,3 +39,7 @@ def test_hrm():
3839

3940
loss, (logits, hiddens) = hrm(seq, labels = labels)
4041
loss.backward()
42+
43+
# after much training
44+
45+
pred = hrm(seq, reasoning_steps = 5)

0 commit comments

Comments
 (0)