 import torch.nn.functional as F
 
 from einops import rearrange
+from einops.layers.torch import Rearrange
 
 # helper functions
 
@@ -22,15 +23,6 @@ class Swish(nn.Module):
     def forward(self, x):
         return x * x.sigmoid()
 
-class Transpose(nn.Module):
-    def __init__(self, dims):
-        super().__init__()
-        assert len(dims) == 2, 'dims must be a tuple of two dimensions'
-        self.dims = dims
-
-    def forward(self, x):
-        return x.transpose(*self.dims)
-
 class GLU(nn.Module):
     def __init__(self, dim):
         super().__init__()
@@ -104,7 +96,7 @@ def forward(self, x, context = None, mask = None, context_mask = None):
104
96
105
97
# shaw's relative positional embedding
106
98
seq = torch .arange (n , device = device )
107
- dist = seq [:, None ] - seq [ None , :]
99
+ dist = rearrange ( seq , 'i -> i ()' ) - rearrange ( seq , 'j -> () j' )
108
100
dist = dist .clip (- max_pos_emb , max_pos_emb ) + max_pos_emb
109
101
rel_pos_emb = self .rel_pos_emb (dist ).to (q )
110
102
pos_attn = einsum ('b h n d, n r d -> b h n r' , q , rel_pos_emb ) * self .scale
@@ -114,7 +106,7 @@ def forward(self, x, context = None, mask = None, context_mask = None):
             mask = default(mask, lambda: torch.ones(*x.shape[:2], device = device))
             context_mask = default(context_mask, mask) if not has_context else default(context_mask, lambda: torch.ones(*context.shape[:2], device = device))
             mask_value = -torch.finfo(dots.dtype).max
-            mask = mask[:, None, :, None] * context_mask[:, None, None, :]
+            mask = rearrange(mask, 'b i -> b () i ()') * rearrange(context_mask, 'b j -> b () () j')
             dots.masked_fill_(~mask, mask_value)
 
         attn = dots.softmax(dim = -1)
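
The two rearrange rewrites in the hunks above are pure notation changes: the () unit axes insert exactly the singleton dimensions that the [:, None] indexing produced, so broadcasting behaves identically. A small sanity check of that equivalence (not part of the commit; it assumes only that torch and einops are installed):

import torch
from einops import rearrange

# relative-position offsets: 'i -> i ()' and 'j -> () j' mirror seq[:, None] and seq[None, :]
seq = torch.arange(5)
assert torch.equal(
    rearrange(seq, 'i -> i ()') - rearrange(seq, 'j -> () j'),
    seq[:, None] - seq[None, :]
)

# attention mask: a (b, i) query mask broadcast against a (b, j) context mask
mask = torch.ones(2, 5).bool()
context_mask = torch.ones(2, 7).bool()
assert torch.equal(
    rearrange(mask, 'b i -> b () i ()') * rearrange(context_mask, 'b j -> b () () j'),
    mask[:, None, :, None] * context_mask[:, None, None, :]
)
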
@@ -158,14 +150,14 @@ def __init__(
 
         self.net = nn.Sequential(
             nn.LayerNorm(dim),
-            Transpose((1, 2)),
+            Rearrange('b n c -> b c n'),
             nn.Conv1d(dim, inner_dim * 2, 1),
             GLU(dim=1),
             DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding),
             nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
             Swish(),
             nn.Conv1d(inner_dim, dim, 1),
-            Transpose((1, 2)),
+            Rearrange('b c n -> b n c'),
             nn.Dropout(dropout)
         )
 
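
A minimal sketch (not from the commit; it assumes only torch and einops) of why Rearrange can stand in for the deleted Transpose module: both are nn.Modules usable inside nn.Sequential, and the pattern 'b n c -> b c n' performs the same axis swap that Transpose((1, 2)) did for the Conv1d layers:

import torch
from einops.layers.torch import Rearrange

x = torch.randn(2, 100, 64)                       # (batch, sequence, channels)
to_channels_first = Rearrange('b n c -> b c n')   # drop-in nn.Module, no custom class needed
assert to_channels_first(x).shape == (2, 64, 100)
assert torch.equal(to_channels_first(x), x.transpose(1, 2))

The Rearrange pattern also documents the expected tensor layout in the string itself, which is the usual motivation for replacing a hand-rolled Transpose with the einops layer.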