Skip to content

Commit bf12fa6

Browse files
Update records/track_10min_16mb/hardik_top5_run/train_gpt.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent c497bbf commit bf12fa6

1 file changed

Lines changed: 16 additions & 1 deletion

File tree

records/track_10min_16mb/hardik_top5_run/train_gpt.py

Lines changed: 16 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,22 @@
4949
except ImportError:
5050
def flash_attn_3_func(q: Tensor, k: Tensor, v: Tensor, causal: bool = True) -> Tensor:
5151
scale = 1.0 / math.sqrt(q.size(-1))
52-
return F.scaled_dot_product_attention(q, k, v, is_causal=causal, scale=scale)
52+
q_sdpa = q.transpose(1, 2)
53+
k_sdpa = k.transpose(1, 2)
54+
v_sdpa = v.transpose(1, 2)
55+
if q_sdpa.size(1) != k_sdpa.size(1):
56+
if q_sdpa.size(1) % k_sdpa.size(1) != 0:
57+
raise ValueError(
58+
f"Incompatible attention head counts for fallback SDPA: "
59+
f"q has {q_sdpa.size(1)} heads, k/v have {k_sdpa.size(1)} heads."
60+
)
61+
repeat_factor = q_sdpa.size(1) // k_sdpa.size(1)
62+
k_sdpa = k_sdpa.repeat_interleave(repeat_factor, dim=1)
63+
v_sdpa = v_sdpa.repeat_interleave(repeat_factor, dim=1)
64+
out = F.scaled_dot_product_attention(
65+
q_sdpa, k_sdpa, v_sdpa, is_causal=causal, scale=scale
66+
)
67+
return out.transpose(1, 2)
5368

5469

5570
# ──────────────────────────────────────────────────────────────────────────────

0 commit comments

Comments
 (0)