1+ from __future__ import annotations
2+
13import torch
2- from torch .nn import Module , ModuleList
4+ from torch import nn , Tensor , is_tensor
5+ import torch .nn .functional as F
6+ from torch .nn import Embedding , Linear , Module , ModuleList
7+ from torch .utils ._pytree import tree_map
38
49from einops import rearrange
510
@@ -15,20 +20,51 @@ def exists(v):
def default(v, d):
    # Fall back to `d` whenever `v` does not exist.
    if exists(v):
        return v

    return d
1722
def divisible_by(num, den):
    # True exactly when `num` is a whole multiple of `den`.
    return not (num % den)
25+
def tree_map_tensor(sample, fn):
    # Apply `fn` to every tensor leaf of the pytree `sample`,
    # passing all non-tensor leaves through untouched.

    def maybe_apply(leaf):
        return fn(leaf) if is_tensor(leaf) else leaf

    return tree_map(maybe_apply, sample)
28+
1829# modules
1930
class HRM(Module):
    """
    Work-in-progress skeleton of the Hierarchical Reasoning Model.

    For now it embeds a token sequence, projects the embeddings straight
    back to vocabulary logits, and — when `labels` is given — returns a
    cross entropy loss. The recurrent `hiddens` state is threaded through
    (and optionally detached) but not yet updated by any recurrent module.
    """

    def __init__(
        self,
        *,
        dim,
        num_tokens,
    ):
        super().__init__()

        # token id -> embedding
        self.to_input_embed = Embedding(num_tokens, dim)

        # embedding -> logits over the token vocabulary
        self.to_pred = Linear(dim, num_tokens, bias = False)

    def forward(
        self,
        seq,
        hiddens: tuple[Tensor, ...] | None = None,
        *,
        labels = None,
        detach_hiddens = True
    ):
        # Optionally cut gradient flow through the carried recurrent state.
        # (tree_map_tensor inlined: only tensor leaves are detached.)

        if detach_hiddens:
            hiddens = tree_map(lambda t: t.detach() if is_tensor(t) else t, hiddens)

        embeds = self.to_input_embed(seq)

        logits = self.to_pred(embeds)

        # No labels -> inference path: hand back logits and the (unchanged) state.
        # (inlined `exists(labels)` check)

        if labels is None:
            return logits, hiddens

        # cross_entropy wants (batch, classes, seq); transpose(1, 2) is the
        # einops rearrange 'b n l -> b l n' written natively

        loss = F.cross_entropy(logits.transpose(1, 2), labels)

        return loss, (logits, hiddens)
# (stray GitHub "0 commit comments" page footer removed from this source excerpt)