diff --git a/internvl_chat/internvl/model/internvl_chat/modeling_intern_vit.py b/internvl_chat/internvl/model/internvl_chat/modeling_intern_vit.py
index 5d875f4b..197dc802 100644
--- a/internvl_chat/internvl/model/internvl_chat/modeling_intern_vit.py
+++ b/internvl_chat/internvl/model/internvl_chat/modeling_intern_vit.py
@@ -150,11 +150,9 @@ def _naive_attn(self, x):
             q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
             k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
 
-        attn = ((q * self.scale) @ k.transpose(-2, -1))
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
+        with torch.backends.cuda.sdp_kernel(enable_math=True):
+            x = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(B, N, C)
 
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
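
Not part of the patch: a minimal standalone sketch of the equivalence the change relies on, namely that F.scaled_dot_product_attention reproduces the removed explicit softmax attention. It assumes self.scale is the usual head_dim ** -0.5 (SDPA's default scaling) and that attn_drop is zero, e.g. in eval mode; shapes follow the (B, H, N, D) layout used in the hunk.

    # Illustrative check, not the repository's code; assumes scale = D ** -0.5
    # and no attention dropout.
    import torch
    import torch.nn.functional as F

    B, H, N, D = 2, 4, 16, 32
    q, k, v = (torch.randn(B, H, N, D) for _ in range(3))
    scale = D ** -0.5

    # Removed path: explicit scaled softmax attention.
    attn = (q * scale) @ k.transpose(-2, -1)
    out_naive = attn.softmax(dim=-1) @ v

    # Added path: fused scaled_dot_product_attention (default scaling = 1/sqrt(D)).
    out_sdpa = F.scaled_dot_product_attention(q, k, v)

    print(torch.allclose(out_naive, out_sdpa, atol=1e-5))  # expected: True

The patch additionally wraps the call in torch.backends.cuda.sdp_kernel(enable_math=True), which only toggles the allowed CUDA backends for the fused kernel; it does not change the computed result.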