Skip to content

Commit 16bd329

Browse files
committed
set up the min reasoning steps correctly
1 parent da69b7d commit 16bd329

File tree

2 files changed

+18
-10
lines changed

2 files changed

+18
-10
lines changed

HRM/hrm.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22
from contextlib import nullcontext
3+
from random import randrange, random
34

45
import torch
56
import torch.nn.functional as F
@@ -25,6 +26,9 @@ def default(v, d):
2526
def last(arr):
2627
return arr[-1]
2728

29+
def satisfy_prob(prob):
30+
return random() < prob
31+
2832
def divisible_by(num, den):
2933
return (num % den) == 0
3034

@@ -72,7 +76,7 @@ def __init__(
7276
num_tokens,
7377
reasoning_steps = 2, # N in the paper - the number of forward evals for the last network (highest hierarchy) above
7478
relative_period: int | tuple[int, ...] = 2, # the relative period for each network evaluation call to the one just previous - in the paper, they do 2 networks with a period of 2
75-
min_reasoning_steps = 1,
79+
min_reasoning_steps_epsilon_prob = 0.5, # they stochastically choose the minimum segment from 2 .. max with this probability, and 1 step the rest of the time
7680
max_reasoning_steps = 10,
7781
act_binary_ce_loss_weight = 1.,
7882
ignore_index = -1
@@ -132,7 +136,7 @@ def __init__(
132136

133137
self.act_binary_ce_loss_weight = act_binary_ce_loss_weight
134138

135-
self.min_reasoning_steps = min_reasoning_steps
139+
self.min_reasoning_steps_epsilon_prob = min_reasoning_steps_epsilon_prob
136140
self.max_reasoning_steps = max_reasoning_steps
137141

138142
self.to_q_continue_halt = Sequential(
@@ -154,10 +158,10 @@ def forward(
154158
labels = None,
155159
detach_hiddens = True,
156160
one_step_grad = True,
157-
reasoning_steps = None
161+
max_reasoning_steps = None
158162
):
159163

160-
reasoning_steps = default(reasoning_steps, self.reasoning_steps)
164+
max_reasoning_steps = default(max_reasoning_steps, self.max_reasoning_steps)
161165

162166
if detach_hiddens:
163167
hiddens = tree_map_tensor(hiddens, lambda t: t.detach())
@@ -207,8 +211,13 @@ def evaluate_network_(
207211

208212
context = torch.no_grad if one_step_grad else nullcontext
209213

214+
min_reasoning_steps = self.max_reasoning_steps
215+
216+
if self.training:
217+
min_reasoning_steps = randrange(2, self.max_reasoning_steps + 1) if satisfy_prob(self.min_reasoning_steps_epsilon_prob) else 1
218+
210219
with context():
211-
for index in range(reasoning_steps * self.lowest_steps_per_reasoning_step - 1):
220+
for index in range(max_reasoning_steps * self.lowest_steps_per_reasoning_step - 1):
212221

213222
iteration = index + 1
214223

@@ -221,10 +230,10 @@ def evaluate_network_(
221230

222231
# adaptive computation time
223232

224-
is_reasoning_step_boundary = divisible_by(index, reasoning_steps)
225-
num_reasoning_steps = index // reasoning_steps
233+
is_reasoning_step_boundary = divisible_by(index, self.lowest_steps_per_reasoning_step)
234+
num_reasoning_steps = index // self.lowest_steps_per_reasoning_step
226235

227-
if is_reasoning_step_boundary and num_reasoning_steps > self.min_reasoning_steps:
236+
if is_reasoning_step_boundary and num_reasoning_steps > min_reasoning_steps:
228237

229238
highest_hidden = hiddens[self.num_networks - 1]
230239

tests/test_hrm.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ def test_hrm():
3636
],
3737
num_tokens = 256,
3838
dim = 32,
39-
reasoning_steps = 3
40-
39+
reasoning_steps = 10
4140
)
4241

4342
seq = torch.randint(0, 256, (3, 1024))

0 commit comments

Comments (0)