Skip to content

Commit 64d8503

Browse files
committed
setup parameters for Q (continnue | halt) for their proposed adaptive computation time
1 parent 254e258 commit 64d8503

File tree

1 file changed

+30
-1
lines changed

1 file changed

+30
-1
lines changed

HRM/hrm.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
import torch
55
import torch.nn.functional as F
66
from torch import Tensor, tensor, is_tensor, cat, stack
7-
from torch.nn import Embedding, Linear, Module, ModuleList
7+
from torch.nn import Embedding, Linear, Sequential, Module, ModuleList
88
from torch.utils._pytree import tree_map
99

1010
from einops import rearrange, repeat
11+
from einops.layers.torch import Rearrange, Reduce
1112

1213
from x_transformers import Encoder, RMSNorm
1314

@@ -71,6 +72,8 @@ def __init__(
7172
num_tokens,
7273
reasoning_steps = 2, # N in the paper - the number of forward evals for the last network (highest hierarchy) above
7374
relative_period: int | tuple[int, ...] = 2, # the relative period for each network evaluation call to the one just previous - in the paper, they do 2 networks with a period of 2
75+
min_reasoning_steps = 1,
76+
max_reasoning_steps = 10,
7477
ignore_index = -1
7578
):
7679
super().__init__()
@@ -124,6 +127,18 @@ def __init__(
124127

125128
self.to_pred = Linear(dim, num_tokens, bias = False)
126129

130+
# Q(continue|halt) for their adaptive computation time setup
131+
132+
self.min_reasoning_steps = min_reasoning_steps
133+
self.max_reasoning_steps = max_reasoning_steps
134+
135+
self.to_q_continue_halt = Sequential(
136+
Reduce('b n d -> b d', 'mean'),
137+
RMSNorm(dim),
138+
Linear(dim, 2, bias = False),
139+
Rearrange('... continue_halt -> continue_halt ...')
140+
)
141+
127142
# loss related
128143

129144
self.ignore_index = ignore_index
@@ -191,6 +206,7 @@ def evaluate_network_(
191206

192207
with context():
193208
for index in range(reasoning_steps * self.lowest_steps_per_reasoning_step - 1):
209+
194210
iteration = index + 1
195211

196212
for network_index, (network, hidden_combine, evaluate_network_at) in enumerate(zip(self.networks, self.hidden_combiners, self.evaluate_networks_at)):
@@ -200,6 +216,19 @@ def evaluate_network_(
200216

201217
evaluate_network_(network, hidden_combine, network_index)
202218

219+
# adaptive computation time
220+
221+
is_reasoning_step_boundary = divisible_by(index, reasoning_steps)
222+
num_reasoning_steps = index // reasoning_steps
223+
224+
if is_reasoning_step_boundary and num_reasoning_steps > self.min_reasoning_steps:
225+
226+
highest_hidden = hiddens[self.num_networks - 1]
227+
228+
q_continue, q_halt = self.to_q_continue_halt(highest_hidden)
229+
230+
should_continue_ = q_halt > q_continue
231+
203232
# 1-step gradient learning
204233

205234
for network_index, (network, hidden_combine) in enumerate(zip(self.networks, self.hidden_combiners)):

0 commit comments

Comments
 (0)