@@ -97,9 +97,9 @@ def forward(self, x, mask = None):
97
97
98
98
if n < seq_len :
99
99
padding = seq_len - n
100
+ mask = default (mask , lambda : torch .ones (b , n , device = device ).bool ())
100
101
x = F .pad (x , (0 , 0 , 0 , padding ), value = 0 )
101
- if exists (mask ):
102
- mask = F .pad (x , (0 , padding ), value = False )
102
+ mask = F .pad (x , (0 , padding ), value = False )
103
103
104
104
qkv = self .to_qkv (x ).chunk (3 , dim = - 1 )
105
105
q , k , v = map (lambda t : rearrange (t , 'b n (h d) -> (b h) n d' , h = h ), qkv )
@@ -196,10 +196,9 @@ def forward(self, x, mask = None):
196
196
197
197
if n < seq_len :
198
198
padding = seq_len - n
199
+ mask = default (mask , lambda : torch .ones (b , n , device = device ).bool ())
199
200
x = F .pad (x , (0 , 0 , 0 , padding ), value = 0 )
200
-
201
- if exists (mask ):
202
- mask = F .pad (x , (0 , padding ), value = False )
201
+ mask = F .pad (x , (0 , padding ), value = False )
203
202
204
203
qkv = self .to_qkv (x ).chunk (3 , dim = - 1 )
205
204
q , k , v = map (lambda t : rearrange (t , 'b n (h d) -> (b h) n d' , h = h ), qkv )
0 commit comments