Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 50 additions & 27 deletions atom/model_ops/attention_mha.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

import logging
from typing import Optional

import aiter
Expand All @@ -15,6 +16,8 @@

from .attention_mla import MLAModules

logger = logging.getLogger("atom")

from atom.plugin.prepare import is_plugin_mode, is_vllm
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <E402> reported by reviewdog 🐶
Module level import not at top of file

from atom.plugin.attention_mha import PagedAttentionImplDecoratorForPluginMode
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <E402> reported by reviewdog 🐶
Module level import not at top of file


Expand Down Expand Up @@ -122,38 +125,58 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext):
use_triton_attn = self.sliding_window != -1 or self.head_dim != 128
self.use_triton_attn = use_triton_attn

_fused_ok = False
if (
self.rotary_emb is not None
and self.q_norm is not None
and self.k_norm is not None
):
fused_qk_norm_rope_cache_quant_shuffle(
qkv,
num_heads_q=self.num_heads,
num_heads_k=self.num_kv_heads,
num_heads_v=self.num_kv_heads,
head_dim=self.head_dim,
eps=self.q_norm.eps,
qw=self.q_norm.weight,
kw=self.k_norm.weight,
cos_sin_cache=self.rotary_emb.cos_sin_cache,
is_neox_style=self.rotary_emb.is_neox_style,
pos_ids=position,
k_cache=k_cache,
v_cache=v_cache,
slot_mapping=attn_metadata.slot_mapping,
kv_cache_dtype=(
"auto" if self.kv_cache_dtype == "bf16" else self.kv_cache_dtype
),
k_scale=k_scale,
v_scale=v_scale,
)
qkv_backup = qkv.clone()
try:
fused_qk_norm_rope_cache_quant_shuffle(
qkv,
num_heads_q=self.num_heads,
num_heads_k=self.num_kv_heads,
num_heads_v=self.num_kv_heads,
head_dim=self.head_dim,
eps=self.q_norm.eps,
qw=self.q_norm.weight,
kw=self.k_norm.weight,
cos_sin_cache=self.rotary_emb.cos_sin_cache,
is_neox_style=self.rotary_emb.is_neox_style,
pos_ids=position,
k_cache=k_cache,
v_cache=v_cache,
slot_mapping=attn_metadata.slot_mapping,
kv_cache_dtype=(
"auto" if self.kv_cache_dtype == "bf16" else self.kv_cache_dtype
),
k_scale=k_scale,
v_scale=v_scale,
)

qkv = qkv.view(qkv.shape[0], -1, self.head_dim)
q, k, v = qkv.split(
[self.num_heads, self.num_kv_heads, self.num_kv_heads], dim=1
)
elif use_triton_attn and self.rotary_emb is not None:
qkv = qkv.view(qkv.shape[0], -1, self.head_dim)
q, k, v = qkv.split(
[self.num_heads, self.num_kv_heads, self.num_kv_heads], dim=1
)
_fused_ok = True
except Exception as e:
if not getattr(PagedAttentionImpl, "_fused_rope_warned", False):
logger.warning(
"fused_qk_norm_rope_cache_quant_shuffle failed (%s), "
"falling back to non-fused path",
e,
)
PagedAttentionImpl._fused_rope_warned = True
qkv.copy_(qkv_backup)
del qkv_backup

if (
not _fused_ok
and use_triton_attn
and self.rotary_emb is not None
and self.q_norm is None
):
k_scale = v_scale = self.kv_scale

q, k, k_cache, v_cache = fused_qk_rope_reshape_and_cache(
Expand All @@ -176,7 +199,7 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext):
k_out=k,
output_zeros=False,
)
else:
elif not _fused_ok:
# for asm paged attention
asm_layout = True
if use_triton_attn:
Expand Down
Loading