Commit 21d12c9

Browse files
committed
Add support for NATTEN 0.17.0 (fused neighborhood attention)
1 parent 6ab5146

1 file changed (+20, -10 lines)

k_diffusion/models/image_transformer_v2.py

@@ -407,18 +407,28 @@ def forward(self, x, pos, cond):
         skip = x
         x = self.norm(x, cond)
         qkv = self.qkv_proj(x)
-        q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head)
-        q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6)
-        theta = self.pos_emb(pos).movedim(-2, -4)
-        q = apply_rotary_emb_(q, theta)
-        k = apply_rotary_emb_(k, theta)
         if natten is None:
             raise ModuleNotFoundError("natten is required for neighborhood attention")
-        flops.op(flops.op_natten, q.shape, k.shape, v.shape, self.kernel_size)
-        qk = natten.functional.natten2dqk(q, k, self.kernel_size, 1)
-        a = torch.softmax(qk, dim=-1).to(v.dtype)
-        x = natten.functional.natten2dav(a, v, self.kernel_size, 1)
-        x = rearrange(x, "n nh h w e -> n h w (nh e)")
+        if natten.has_fused_na():
+            q, k, v = rearrange(qkv, "n h w (t nh e) -> t n h w nh e", t=3, e=self.d_head)
+            q, k = scale_for_cosine_sim(q, k, self.scale[:, None], 1e-6)
+            theta = self.pos_emb(pos)
+            q = apply_rotary_emb_(q, theta)
+            k = apply_rotary_emb_(k, theta)
+            flops.op(flops.op_natten, q.shape, k.shape, v.shape, self.kernel_size)
+            x = natten.functional.na2d(q, k, v, self.kernel_size, scale=1.0)
+            x = rearrange(x, "n h w nh e -> n h w (nh e)")
+        else:
+            q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head)
+            q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6)
+            theta = self.pos_emb(pos).movedim(-2, -4)
+            q = apply_rotary_emb_(q, theta)
+            k = apply_rotary_emb_(k, theta)
+            flops.op(flops.op_natten, q.shape, k.shape, v.shape, self.kernel_size)
+            qk = natten.functional.na2d_qk(q, k, self.kernel_size)
+            a = torch.softmax(qk, dim=-1).to(v.dtype)
+            x = natten.functional.na2d_av(a, v, self.kernel_size)
+            x = rearrange(x, "n nh h w e -> n h w (nh e)")
         x = self.dropout(x)
         x = self.out_proj(x)
         return x + skip
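
For reference, a minimal standalone sketch of the two NATTEN call paths this commit switches between: the fused na2d kernel available on NATTEN 0.17.0+ (heads-last layout) and the split na2d_qk / na2d_av fallback (heads-first layout). The tensor sizes, dtypes, and the CUDA/half-precision setup for the fused path are illustrative assumptions, not part of the commit.

import torch
import natten

# Illustrative sizes only: batch, height, width, heads, head dim, window size.
n, h, w, nh, e, kernel_size = 1, 32, 32, 4, 64, 7

if natten.has_fused_na():
    # Fused neighborhood attention (NATTEN >= 0.17.0): q/k/v are heads-last,
    # shape (n, h, w, nh, e), and a single kernel computes softmax(q k^T) v.
    # Assumes a CUDA device and fp16, which the fused kernels typically need.
    q = torch.randn(n, h, w, nh, e, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    x = natten.functional.na2d(q, k, v, kernel_size, scale=1.0)
else:
    # Unfused fallback: q/k/v are heads-first, shape (n, nh, h, w, e), with
    # separate QK, softmax, and AV steps, as in the else branch of the diff.
    q = torch.randn(n, nh, h, w, e)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    qk = natten.functional.na2d_qk(q, k, kernel_size)
    a = torch.softmax(qk, dim=-1).to(v.dtype)
    x = natten.functional.na2d_av(a, v, kernel_size)

The scale=1.0 argument mirrors the diff: queries and keys are already normalized by scale_for_cosine_sim before the attention call, so NATTEN's default 1/sqrt(head_dim) scaling is not applied on top of it.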
