Skip to content

Commit d1d075b

Browse files
committed
add hparam
1 parent fe45604 commit d1d075b

File tree

4 files changed

+20
-9
lines changed

4 files changed

+20
-9
lines changed

HRM/hrm.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
from einops import rearrange, repeat
1010
from einops.layers.torch import Rearrange, Reduce
1111

12-
from x_transformers import Encoder, RMSNorm
12+
from x_transformers import Encoder, Decoder, RMSNorm
1313

1414
# helper functions
1515

@@ -66,9 +66,11 @@ def __init__(
6666
num_tokens,
6767
reasoning_steps = 2, # N in the paper - the number of forward evals for the last network (highest hierarchy) above
6868
relative_period: int | tuple[int, ...] = 2, # the relative period for each network evaluation call to the one just previous - in the paper, they do 2 networks with a period of 2
69+
causal = False,
6970
ignore_index = -1,
7071
):
7172
super().__init__()
73+
attn_layers_klass = Encoder if not causal else Decoder
7274

7375
# input
7476

@@ -82,7 +84,7 @@ def __init__(
8284

8385
for network in networks:
8486
if isinstance(network, dict):
85-
network = Encoder(**network)
87+
network = attn_layers_klass(**network)
8688

8789
self.networks.append(network)
8890

HRM/hrm_with_act.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from einops import rearrange, repeat
1414
from einops.layers.torch import Rearrange, Reduce
1515

16-
from x_transformers import Encoder, RMSNorm
16+
from x_transformers import Encoder, Decoder, RMSNorm
1717

1818
# constants
1919

@@ -80,13 +80,15 @@ def __init__(
8080
num_tokens,
8181
reasoning_steps = 2, # N in the paper - the number of forward evals for the last network (highest hierarchy) above
8282
relative_period: int | tuple[int, ...] = 2, # the relative period for each network evaluation call to the one just previous - in the paper, they do 2 networks with a period of 2
83+
causal = False,
8384
min_reasoning_steps_epsilon_prob = 0.5, # they stochastically choose the minimum segment from 2 .. max with this probability, and 1 step the rest of the time
8485
max_reasoning_steps = 10,
8586
act_loss_weight = 1.,
8687
discount_factor = 1.,
8788
ignore_index = -1,
8889
):
8990
super().__init__()
91+
attn_layers_klass = Encoder if not causal else Decoder
9092

9193
# input
9294

@@ -100,7 +102,7 @@ def __init__(
100102

101103
for network in networks:
102104
if isinstance(network, dict):
103-
network = Encoder(**network)
105+
network = attn_layers_klass(**network)
104106

105107
self.networks.append(network)
106108

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "HRM-pytorch"
3-
version = "0.1.2"
3+
version = "0.1.4"
44
description = "The proposal from a Singaporean AGI company"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_hrm.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import pytest
2+
param = pytest.mark.parametrize
3+
24
import torch
35

4-
def test_hrm():
6+
@param('causal', (False, True))
7+
def test_hrm(causal):
58
from HRM.hrm import HRM
69
from x_transformers import Encoder
710

@@ -35,6 +38,7 @@ def test_hrm():
3538
pre_norm = False
3639
)
3740
],
41+
causal = causal,
3842
num_tokens = 256,
3943
dim = 32,
4044
reasoning_steps = 10
@@ -53,9 +57,11 @@ def test_hrm():
5357

5458
pred = hrm(seq, reasoning_steps = 5)
5559

56-
@pytest.mark.parametrize('compute_loss_across_reasoning_steps', (False, True))
60+
@param('compute_loss_across_reasoning_steps', (False, True))
61+
@param('causal', (False, True))
5762
def test_hrm_with_act(
58-
compute_loss_across_reasoning_steps
63+
compute_loss_across_reasoning_steps,
64+
causal
5965
):
6066
from HRM.hrm_with_act import HRM
6167

@@ -73,7 +79,8 @@ def test_hrm_with_act(
7379
],
7480
num_tokens = 256,
7581
dim = 32,
76-
max_reasoning_steps = 10
82+
max_reasoning_steps = 10,
83+
causal = causal
7784
)
7885

7986
seq = torch.randint(0, 256, (3, 1024))

0 commit comments

Comments
 (0)