File tree Expand file tree Collapse file tree 2 files changed +6
-3
lines changed
Expand file tree Collapse file tree 2 files changed +6
-3
lines changed Original file line number Diff line number Diff line change 44from torch import nn , Tensor
55from torch .nn import Module
66
7- from einops import pack , unpack
7+ from einops import repeat , pack , unpack
88
99# constants
1010
@@ -42,7 +42,7 @@ def __init__(
4242 iters = 5 ,
4343 factorize_k : int | None = None ,
4444 dim_hidden : int | None = None ,
45- hidden_factorize_k : int | None = None
45+ hidden_factorize_k : int | None = None ,
4646 ):
4747 super ().__init__ ()
4848 assert iters > 1
@@ -74,6 +74,9 @@ def forward(
7474 ):
7575 iters = default (iters , self .iters )
7676
77+ if inputs .ndim == 3 and hiddens .ndim == 2 :
78+ hiddens = repeat (hiddens , 'b d -> b n d' , n = inputs .shape [- 2 ])
79+
7780 assert inputs .shape [- 1 ] == self .dim
7881 assert hiddens .shape [- 1 ] == self .dim_hidden
7982 assert inputs .shape [:- 2 ] == hiddens .shape [:- 2 ]
Original file line number Diff line number Diff line change 33setup (
44 name = 'mogrifier' ,
55 packages = find_packages (),
6- version = '0.0.4 ' ,
6+ version = '0.0.5 ' ,
77 license = 'MIT' ,
88 description = 'Implementation of Mogrifier circuit from Deepmind' ,
99 long_description_content_type = 'text/markdown' ,
You can’t perform that action at this time.
0 commit comments