Skip to content

Commit 2356f3f

Browse files
committed
Handle a single conditioning vector (2-D hiddens) by broadcasting it across the entire sequence of features
1 parent 88b2be4 commit 2356f3f

File tree

2 files changed: +6 additions, -3 deletions

2 files changed

+6
-3
lines changed

mogrifier/mogrifier.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch import nn, Tensor
55
from torch.nn import Module
66

7-
from einops import pack, unpack
7+
from einops import repeat, pack, unpack
88

99
# constants
1010

@@ -42,7 +42,7 @@ def __init__(
4242
iters = 5,
4343
factorize_k: int | None = None,
4444
dim_hidden: int | None = None,
45-
hidden_factorize_k: int | None = None
45+
hidden_factorize_k: int | None = None,
4646
):
4747
super().__init__()
4848
assert iters > 1
@@ -74,6 +74,9 @@ def forward(
7474
):
7575
iters = default(iters, self.iters)
7676

77+
if inputs.ndim == 3 and hiddens.ndim == 2:
78+
hiddens = repeat(hiddens, 'b d -> b n d', n = inputs.shape[-2])
79+
7780
assert inputs.shape[-1] == self.dim
7881
assert hiddens.shape[-1] == self.dim_hidden
7982
assert inputs.shape[:-2] == hiddens.shape[:-2]

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'mogrifier',
55
packages = find_packages(),
6-
version = '0.0.4',
6+
version = '0.0.5',
77
license='MIT',
88
description = 'Implementation of Mogrifier circuit from Deepmind',
99
long_description_content_type = 'text/markdown',

0 commit comments

Comments
 (0)