From a3fe08f18465fd11440c8f2d36c3595794092567 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tr=E1=BA=A7n=20B=E1=BA=A3o=20Ch=C3=AD?=
Date: Mon, 9 Sep 2024 08:49:59 +0000
Subject: [PATCH 1/2] sdpa

---
 .../model/internvl_chat/modeling_intern_vit.py     | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/internvl_chat/internvl/model/internvl_chat/modeling_intern_vit.py b/internvl_chat/internvl/model/internvl_chat/modeling_intern_vit.py
index 5d875f4b..55ce04a0 100644
--- a/internvl_chat/internvl/model/internvl_chat/modeling_intern_vit.py
+++ b/internvl_chat/internvl/model/internvl_chat/modeling_intern_vit.py
@@ -144,17 +144,13 @@ def _naive_attn(self, x):
         B, N, C = x.shape
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
-
+
         if self.qk_normalization:
             B_, H_, N_, D_ = q.shape
             q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
             k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
-
-        attn = ((q * self.scale) @ k.transpose(-2, -1))
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        with torch.backends.cuda.sdp_kernel(enable_math=True):
+            x = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x

From 918a90f84d03e098404e2265677a8eb1297f4cf2 Mon Sep 17 00:00:00 2001
From: Zhe Chen
Date: Sat, 23 Nov 2024 22:57:18 +0800
Subject: [PATCH 2/2] Update modeling_intern_vit.py

---
 .../internvl/model/internvl_chat/modeling_intern_vit.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/internvl_chat/internvl/model/internvl_chat/modeling_intern_vit.py b/internvl_chat/internvl/model/internvl_chat/modeling_intern_vit.py
index 55ce04a0..197dc802 100644
--- a/internvl_chat/internvl/model/internvl_chat/modeling_intern_vit.py
+++ b/internvl_chat/internvl/model/internvl_chat/modeling_intern_vit.py
@@ -144,13 +144,15 @@ def _naive_attn(self, x):
         B, N, C = x.shape
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
-
+
         if self.qk_normalization:
             B_, H_, N_, D_ = q.shape
             q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
             k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+
         with torch.backends.cuda.sdp_kernel(enable_math=True):
             x = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(B, N, C)
+
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
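
For reference, a minimal standalone sketch of the change above: it checks that the fused F.scaled_dot_product_attention call matches the removed explicit softmax attention. It assumes InternViT's self.scale is the standard head_dim ** -0.5 and that attention dropout is inactive (eval mode); the shapes are hypothetical, chosen only to mirror the (batch, heads, tokens, head_dim) layout of the patched code.

import torch
import torch.nn.functional as F

# Hypothetical shapes matching the patched code: (batch, heads, tokens, head_dim).
B, H, N, D = 2, 4, 16, 32
q, k, v = (torch.randn(B, H, N, D) for _ in range(3))
scale = D ** -0.5  # assumption: self.scale in InternViT is the usual 1 / sqrt(head_dim)

# Removed path: explicit scaled softmax attention (dropout omitted, i.e. eval mode).
attn = (q * scale) @ k.transpose(-2, -1)
naive = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, N, H * D)

# New path: fused kernel; scaled_dot_product_attention defaults to the same 1/sqrt(D) scale.
sdpa = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(B, N, H * D)

print(torch.allclose(naive, sdpa, atol=1e-5))  # expected: True

Note that the patch drops the explicit self.attn_drop(attn) call, so the two paths coincide only when attention dropout is disabled (p=0 or eval mode).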