Skip to content

Commit 8a562b0

Browse files
committed
setup peripherals
1 parent d58e047 commit 8a562b0

File tree

2 files changed

+64
-12
lines changed

2 files changed

+64
-12
lines changed

HRM/hrm.py

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1+
from __future__ import annotations
2+
13
import torch
2-
from torch.nn import Module, ModuleList
4+
from torch import nn, Tensor, is_tensor
5+
import torch.nn.functional as F
6+
from torch.nn import Embedding, Linear, Module, ModuleList
7+
from torch.utils._pytree import tree_map
38

49
from einops import rearrange
510

@@ -15,20 +20,51 @@ def exists(v):
1520
def default(v, d):
1621
return v if exists(v) else d
1722

23+
def divisible_by(num, den):
24+
return (num % den) == 0
25+
26+
def tree_map_tensor(sample, fn):
27+
return tree_map(lambda t: t if not is_tensor(t) else fn(t), sample)
28+
1829
# modules
1930

20-
class Input(Module):
21-
def __init__(self):
31+
class HRM(Module):
32+
def __init__(
33+
self,
34+
*,
35+
dim,
36+
num_tokens,
37+
):
2238
super().__init__()
2339

24-
class SlowHighLevelRecurrent(Module):
25-
def __init__(self):
26-
super().__init__()
40+
self.to_input_embed = Embedding(num_tokens, dim)
2741

28-
class FastLowLevelRecurrent(Module):
29-
def __init__(self):
30-
super().__init__()
42+
self.to_pred = Linear(dim, num_tokens, bias = False)
3143

32-
class Output(Module):
33-
def __init__(self):
34-
super().__init__()
44+
def forward(
45+
self,
46+
seq,
47+
hiddens: tuple[Tensor, ...] | None = None,
48+
*,
49+
labels = None,
50+
detach_hiddens = True
51+
):
52+
53+
if detach_hiddens:
54+
hiddens = tree_map_tensor(hiddens, lambda t: t.detach())
55+
56+
tokens = self.to_input_embed(seq)
57+
58+
pred = self.to_pred(tokens)
59+
60+
# if labels passed in, cross entropy loss
61+
62+
if not exists(labels):
63+
return pred, hiddens
64+
65+
loss = F.cross_entropy(
66+
rearrange(pred, 'b n l -> b l n'),
67+
labels
68+
)
69+
70+
return loss, (pred, hiddens)

tests/test_hrm.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import pytest
2+
import torch
3+
4+
def test_hrm():
5+
from HRM.hrm import HRM
6+
7+
hrm = HRM(
8+
num_tokens = 256,
9+
dim = 512,
10+
)
11+
12+
seq = torch.randint(0, 256, (3, 1024))
13+
labels = torch.randint(0, 256, (3, 1024))
14+
15+
loss, (logits, hiddens) = hrm(seq, labels = labels)
16+
loss.backward()

0 commit comments

Comments
 (0)