Skip to content

Commit c876013

Browse files
committed
Set up their scheme; presumably each hierarchy can only have access to the one directly below it. The lowest hidden state is implied to be the original token sequence, unchanging. Lots of inspiration from the old DEQ paper.
1 parent d3e38a0 commit c876013

File tree

1 file changed

+61
-16
lines changed

1 file changed

+61
-16
lines changed

HRM/hrm.py

Lines changed: 61 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -2,12 +2,12 @@
22
from contextlib import nullcontext
33

44
import torch
5-
from torch import nn, Tensor, tensor, is_tensor
65
import torch.nn.functional as F
6+
from torch import nn, Tensor, tensor, is_tensor, stack
77
from torch.nn import Embedding, Linear, Module, ModuleList
88
from torch.utils._pytree import tree_map
99

10-
from einops import rearrange
10+
from einops import rearrange, repeat
1111

1212
from x_transformers import Encoder
1313

@@ -18,6 +18,12 @@
1818
def exists(v):
1919
return v is not None
2020

21+
def first(arr):
22+
return arr[0]
23+
24+
def last(arr):
25+
return arr[-1]
26+
2127
def default(v, d):
2228
return v if exists(v) else d
2329

@@ -36,8 +42,9 @@ def __init__(
3642
*,
3743
dim,
3844
num_tokens,
39-
reasoning_steps = 2, # N in the paper - the number of forward evals for the last network (highest hierarchy) above
40-
relative_period: int | tuple[int, ...] = 2 # the relative period for each network evaluation call to the one just previous - in the paper, they do 2 networks with a period of 2
45+
reasoning_steps = 2, # N in the paper - the number of forward evals for the last network (highest hierarchy) above
46+
relative_period: int | tuple[int, ...] = 2, # the relative period for each network evaluation call to the one just previous - in the paper, they do 2 networks with a period of 2
47+
ignore_index = -1
4148
):
4249
super().__init__()
4350

@@ -57,34 +64,39 @@ def __init__(
5764

5865
self.networks.append(network)
5966

60-
assert len(self.networks) > 0
67+
self.num_networks = len(self.networks)
68+
assert self.num_networks > 0
6169

6270
# setup how frequent each network is called
6371
# the first network (lowest in the hierarchy) should be called every iteration
6472

65-
num_higher_networks = len(self.networks) - 1
73+
num_higher_networks = self.num_networks - 1
6674

6775
if not isinstance(relative_period, tuple):
6876
relative_period = (relative_period,) * num_higher_networks
6977

7078
# implied that first network is called always
7179

72-
if len(relative_period) == (len(self.networks) - 1):
80+
if len(relative_period) == (self.num_networks - 1):
7381
relative_period = (1, *relative_period)
7482

7583
# for the paper, they did (low: 1, high: 2) -
7684

77-
assert len(relative_period) == len(self.networks) and relative_period[0] == 1
85+
assert len(relative_period) == self.num_networks and relative_period[0] == 1
7886

7987
self.evaluate_networks_at = tensor(relative_period).cumprod(dim = -1).tolist()
8088

8189
self.reasoning_steps = reasoning_steps
82-
self.lowest_steps_per_reasoning_step = self.evaluate_networks_at[-1]
90+
self.lowest_steps_per_reasoning_step = last(self.evaluate_networks_at)
8391

8492
# output
8593

8694
self.to_pred = Linear(dim, num_tokens, bias = False)
8795

96+
# loss related
97+
98+
self.ignore_index = ignore_index
99+
88100
def forward(
89101
self,
90102
seq,
@@ -102,29 +114,61 @@ def forward(
102114

103115
tokens = self.to_input_embed(seq)
104116

117+
# handle hiddens
118+
119+
if not exists(hiddens):
120+
hiddens = torch.zeros_like(tokens)
121+
hiddens = repeat(hiddens, '... -> num_networks ...', num_networks = self.num_networks)
122+
123+
assert len(hiddens) == self.num_networks
124+
105125
# network as they proposed - following figure 4
106126

127+
def evaluate_network(
128+
network: Module,
129+
network_index
130+
):
131+
all_hiddens = (tokens, *hiddens)
132+
network_input = all_hiddens[network_index:-1]
133+
134+
# combine with mean pool for now
135+
136+
combined_input = stack(network_input).mean(dim = 0)
137+
138+
# forward
139+
140+
next_hidden = network(combined_input)
141+
142+
# store hiddens at appropriate hierarchy, low to highest
143+
144+
hiddens[network_index] = next_hidden
145+
146+
# maybe 1-step
147+
107148
context = torch.no_grad if one_step_grad else nullcontext
108149

109150
with context():
110151
for index in range(self.reasoning_steps * self.lowest_steps_per_reasoning_step - 1):
111152
iteration = index + 1
112153

113-
for network, evaluate_network_at in zip(self.networks, self.evaluate_networks_at):
154+
for network_index, (network, evaluate_network_at) in enumerate(zip(self.networks, self.evaluate_networks_at)):
114155

115156
if not divisible_by(iteration, evaluate_network_at):
116157
continue
117158

118-
tokens = network(tokens)
159+
evaluate_network(network, network_index)
119160

120161
# 1-step gradient learning
121162

122-
for network in self.networks:
123-
tokens = network(tokens)
163+
for network_index, network in enumerate(self.networks):
164+
165+
evaluate_network(network, network_index)
166+
167+
# to output prediction, using the hiddens from the highest hierarchy
124168

125-
# to output prediction
169+
highest_hidden = last(hiddens)
126170

127-
pred = self.to_pred(tokens)
171+
pred = self.to_pred(highest_hidden)
128172

129173
# if labels passed in, cross entropy loss
130174

@@ -133,7 +177,8 @@ def forward(
133177

134178
loss = F.cross_entropy(
135179
rearrange(pred, 'b n l -> b l n'),
136-
labels
180+
labels,
181+
ignore_index = self.ignore_index
137182
)
138183

139184
return loss, (pred, hiddens)

0 commit comments

Comments (0)