@@ -407,18 +407,28 @@ def forward(self, x, pos, cond):
407
407
skip = x
408
408
x = self .norm (x , cond )
409
409
qkv = self .qkv_proj (x )
410
- q , k , v = rearrange (qkv , "n h w (t nh e) -> t n nh h w e" , t = 3 , e = self .d_head )
411
- q , k = scale_for_cosine_sim (q , k , self .scale [:, None , None , None ], 1e-6 )
412
- theta = self .pos_emb (pos ).movedim (- 2 , - 4 )
413
- q = apply_rotary_emb_ (q , theta )
414
- k = apply_rotary_emb_ (k , theta )
415
410
if natten is None :
416
411
raise ModuleNotFoundError ("natten is required for neighborhood attention" )
417
- flops .op (flops .op_natten , q .shape , k .shape , v .shape , self .kernel_size )
418
- qk = natten .functional .natten2dqk (q , k , self .kernel_size , 1 )
419
- a = torch .softmax (qk , dim = - 1 ).to (v .dtype )
420
- x = natten .functional .natten2dav (a , v , self .kernel_size , 1 )
421
- x = rearrange (x , "n nh h w e -> n h w (nh e)" )
412
+ if natten .has_fused_na ():
413
+ q , k , v = rearrange (qkv , "n h w (t nh e) -> t n h w nh e" , t = 3 , e = self .d_head )
414
+ q , k = scale_for_cosine_sim (q , k , self .scale [:, None ], 1e-6 )
415
+ theta = self .pos_emb (pos )
416
+ q = apply_rotary_emb_ (q , theta )
417
+ k = apply_rotary_emb_ (k , theta )
418
+ flops .op (flops .op_natten , q .shape , k .shape , v .shape , self .kernel_size )
419
+ x = natten .functional .na2d (q , k , v , self .kernel_size , scale = 1.0 )
420
+ x = rearrange (x , "n h w nh e -> n h w (nh e)" )
421
+ else :
422
+ q , k , v = rearrange (qkv , "n h w (t nh e) -> t n nh h w e" , t = 3 , e = self .d_head )
423
+ q , k = scale_for_cosine_sim (q , k , self .scale [:, None , None , None ], 1e-6 )
424
+ theta = self .pos_emb (pos ).movedim (- 2 , - 4 )
425
+ q = apply_rotary_emb_ (q , theta )
426
+ k = apply_rotary_emb_ (k , theta )
427
+ flops .op (flops .op_natten , q .shape , k .shape , v .shape , self .kernel_size )
428
+ qk = natten .functional .na2d_qk (q , k , self .kernel_size )
429
+ a = torch .softmax (qk , dim = - 1 ).to (v .dtype )
430
+ x = natten .functional .na2d_av (a , v , self .kernel_size )
431
+ x = rearrange (x , "n nh h w e -> n h w (nh e)" )
422
432
x = self .dropout (x )
423
433
x = self .out_proj (x )
424
434
return x + skip
0 commit comments