Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 37 additions & 16 deletions atom/model_ops/attention_mha.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

import os
from typing import Optional

import aiter
Expand Down Expand Up @@ -49,7 +50,7 @@ def __init__(
if self.kv_cache_dtype == "fp8"
else 1.0
)
self.kv_scale = torch.tensor(self.kv_scale_float, dtype=torch.float32)
self.kv_scale = torch.tensor(self.kv_scale_float, dtype=torch.float32).cuda()
self.sinks = sinks
self.sliding_window = sliding_window if sliding_window is not None else -1
self.rotary_emb = rotary_emb
Expand Down Expand Up @@ -100,8 +101,11 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext):
k_scale = kv_cache_data[f"layer_{self.layer_num}"].k_scale
v_scale = kv_cache_data[f"layer_{self.layer_num}"].v_scale

use_triton_attn = self.sliding_window != -1 or self.head_dim != 128
self.use_triton_attn = use_triton_attn
# PA dispatch decision (independent of cache update strategy)
use_asm_pa = self.sliding_window == -1 and self.head_dim == 128
if os.environ.get("AITER_FORCE_TRITON_ATTN", "0") == "1":
use_asm_pa = False
self.use_triton_attn = not use_asm_pa

if (
self.rotary_emb is not None
Expand Down Expand Up @@ -134,8 +138,9 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext):
q, k, v = qkv.split(
[self.num_heads, self.num_kv_heads, self.num_kv_heads], dim=1
)
elif use_triton_attn and self.rotary_emb is not None:
k_scale = v_scale = self.kv_scale
elif self.rotary_emb is not None:
# Always use Triton fused rope+cache (fast, no module_cache JIT)
triton_scale = self.kv_scale

q, k, k_cache, v_cache = fused_qk_rope_reshape_and_cache(
q,
Expand All @@ -147,8 +152,8 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext):
position,
self.rotary_emb.cos_cache,
self.rotary_emb.sin_cache,
k_scale,
v_scale,
triton_scale,
triton_scale,
self.rotary_emb.is_neox_style,
flash_layout=False,
apply_scale=self.kv_cache_dtype.startswith("fp8"),
Expand All @@ -157,14 +162,30 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext):
k_out=k,
output_zeros=False,
)

# Set scales for the PA backend
if use_asm_pa:
if self.kv_cache_dtype == "bf16":
# bf16 cache: no dequant needed
k_scale = None
v_scale = None
else:
# fp8 cache: Triton fused cache applied per-tensor scale
# inline. Fill per-token scale buffers with the uniform
# per-tensor value so ASM PA can dequant correctly.
slots = attn_metadata.slot_mapping.clamp(min=0)
block_size = k_cache.shape[3]
block_indices = slots // block_size
block_offsets = slots % block_size
# k_scale/v_scale from kv_cache_data: [blocks, heads, block_size]
k_scale[block_indices, :, block_offsets] = self.kv_scale
v_scale[block_indices, :, block_offsets] = self.kv_scale
else:
# Triton PA uses per-tensor scale directly
k_scale = triton_scale
v_scale = triton_scale
else:
# for asm paged attention
asm_layout = True
if use_triton_attn:
asm_layout = False
if self.rotary_emb is not None:
assert position is not None
q, k = self.rotary_emb(position, q, k)
# Non-rope fallback (models without rotary embedding)
if self.q_norm is not None:
q = self.q_norm(q)
if self.k_norm is not None:
Expand All @@ -178,7 +199,7 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext):
k_scale,
v_scale,
attn_metadata.slot_mapping,
asm_layout=asm_layout,
asm_layout=use_asm_pa,
)
else:
aiter.reshape_and_cache(
Expand All @@ -190,7 +211,7 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext):
kv_cache_dtype="auto",
k_scale=None,
v_scale=None,
asm_layout=asm_layout,
asm_layout=use_asm_pa,
)

return q, k, v, k_cache, v_cache, k_scale, v_scale
Expand Down
8 changes: 6 additions & 2 deletions atom/model_ops/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,13 @@ def __init__(
else None
)
self.layer_num = layer_num
self.use_triton_mla_decode = envs.ATOM_USE_TRITON_MLA_DECODE
self.use_triton_mla_decode = (
envs.ATOM_USE_TRITON_MLA_DECODE or envs.ATOM_CK_FREE
)
if self.use_triton_mla_decode:
logger.info("Using Triton MLA decode (ATOM_USE_TRITON_MLA_DECODE=1)")
logger.info(
"Using Triton MLA decode (ATOM_USE_TRITON_MLA_DECODE=1 or ATOM_CK_FREE=1)"
)

def process_weights_after_loading(self):
if is_rocm_aiter_fp4bmm_enabled():
Expand Down
7 changes: 4 additions & 3 deletions atom/model_ops/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
per_tensor_dequantize,
shuffle_weights,
)
from atom.utils import envs
from atom.utils.custom_register import direct_register_custom_op
from atom.utils.forward_context import get_forward_context
from torch import nn
Expand Down Expand Up @@ -840,7 +841,7 @@ def __init__(self, quant_config: QuantizationConfig, moe: FusedMoEConfig):
from atom.model_ops.utils import has_triton_kernels

self.use_triton = get_gfx() in ("gfx942", "gfx950") and has_triton_kernels()
if not self.use_triton and not _has_ck_moe_sorting():
if not self.use_triton and (not _has_ck_moe_sorting() or envs.ATOM_CK_FREE):
if has_triton_kernels():
self.use_triton = True
_moe_logger.info(
Expand Down Expand Up @@ -1191,7 +1192,7 @@ def __init__(self, quant_config: QuantizationConfig, moe: FusedMoEConfig):
# Detect CK MOE availability; fall back to FlyDSL or Triton if unavailable
self.use_flydsl_moe = False
self.use_triton_moe = False
if not _has_ck_moe_sorting():
if not _has_ck_moe_sorting() or envs.ATOM_CK_FREE:
if not self.block_quant and _has_flydsl_moe():
self.use_flydsl_moe = True
_moe_logger.info(
Expand Down Expand Up @@ -1622,7 +1623,7 @@ def __init__(self, quant_config: QuantizationConfig, moe: FusedMoEConfig):
# Detect CK MOE availability; fall back to FlyDSL or Triton if unavailable
self.use_flydsl_moe = False
self.use_triton_moe = False
if not _has_ck_moe_sorting():
if not _has_ck_moe_sorting() or envs.ATOM_CK_FREE:
if not self.block_quant and _has_flydsl_moe():
self.use_flydsl_moe = True
_moe_logger.info("CK unavailable, using FlyDSL MOE for FP8")
Expand Down
1 change: 1 addition & 0 deletions atom/utils/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
)
== "1",
"ATOM_USE_FLYDSL_MOE": lambda: os.getenv("ATOM_USE_FLYDSL_MOE", "0") == "1",
"ATOM_CK_FREE": lambda: os.getenv("ATOM_CK_FREE", "0") == "1",
}


Expand Down
150 changes: 150 additions & 0 deletions tests/test_ck_free_mode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
#!/usr/bin/env python3
"""
Tests for ATOM_CK_FREE=1 routing logic.
No GPU or model weights required — tests env var detection and routing conditions.

Run (from the repository root): python3 -m pytest tests/test_ck_free_mode.py -v
"""

import pytest
import importlib

Comment on lines +10 to +11

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F401> reported by reviewdog 🐶
importlib imported but unused

Suggested change
import importlib

# All ATOM_* env vars that could affect tests.
# The autouse _clean_env fixture deletes each of these before every test so
# the defaults declared in atom/utils/envs.py are exercised reliably.
_ATOM_ENV_VARS = [
    "ATOM_CK_FREE",
    "ATOM_USE_TRITON_MLA_DECODE",
    "ATOM_USE_FLYDSL_MOE",
    "ATOM_USE_TRITON_GEMM",
    "ATOM_USE_TRITON_MXFP4_BMM",
]


@pytest.fixture(autouse=True)
def _clean_env(monkeypatch):
    """Remove every ATOM_* variable before each test so defaults are exercised."""
    for name in _ATOM_ENV_VARS:
        monkeypatch.delenv(name, raising=False)


def _get_envs():
    """Return the atom envs module; its lazy ``__getattr__`` re-reads the
    environment on every attribute access, so no reload is needed."""
    from atom.utils import envs as _envs

    return _envs


class TestAtomCkFreeEnvVar:
    """Parsing of the ATOM_CK_FREE environment variable (only "1" enables it)."""

    def test_default_is_false(self):
        # The autouse fixture removed the variable, so the default applies.
        env = _get_envs()
        assert env.ATOM_CK_FREE is False

    def test_set_to_1_is_true(self, monkeypatch):
        monkeypatch.setenv("ATOM_CK_FREE", "1")
        env = _get_envs()
        assert env.ATOM_CK_FREE is True

    def test_set_to_0_is_false(self, monkeypatch):
        monkeypatch.setenv("ATOM_CK_FREE", "0")
        env = _get_envs()
        assert env.ATOM_CK_FREE is False

    def test_set_to_empty_is_false(self, monkeypatch):
        # An empty string is not "1" and therefore parses as disabled.
        monkeypatch.setenv("ATOM_CK_FREE", "")
        env = _get_envs()
        assert env.ATOM_CK_FREE is False


class TestMoeRouting:
"""Test the MOE CK-free condition logic (without importing heavy moe.py)."""

def _has_ck_moe_sorting(self) -> bool:
"""Replicate the check from moe.py without importing it."""
try:
import importlib

return importlib.util.find_spec("aiter.jit.module_moe_sorting") is not None
except Exception:
return False

def test_ck_free_forces_non_ck_in_condition(self, monkeypatch):
"""Verify the 'or envs.ATOM_CK_FREE' condition works."""
monkeypatch.setenv("ATOM_CK_FREE", "1")
envs = _get_envs()

# The condition in moe.py is:
# if not _has_ck_moe_sorting() or envs.ATOM_CK_FREE:
# When ATOM_CK_FREE=1, this should be True regardless of _has_ck_moe_sorting()
result = not self._has_ck_moe_sorting() or envs.ATOM_CK_FREE
assert result is True

def test_ck_free_off_respects_ck_availability(self, monkeypatch):
"""When ATOM_CK_FREE=0, the condition depends on _has_ck_moe_sorting()."""
monkeypatch.setenv("ATOM_CK_FREE", "0")
envs = _get_envs()

has_ck = self._has_ck_moe_sorting()
result = not has_ck or envs.ATOM_CK_FREE
# If CK is available, result should be False (use CK path)
# If CK is not available, result should be True (use fallback)
assert result == (not has_ck)


class TestMhaRouting:
    """MHA attention routing under ATOM_CK_FREE."""

    def test_ck_free_forces_triton_attn(self, monkeypatch):
        """CK-free mode must pick the Triton kernel even when ASM PA is eligible."""
        monkeypatch.setenv("ATOM_CK_FREE", "1")
        env = _get_envs()

        # Mirrors attention_mha.py: ASM PA is eligible only for a full
        # attention window (sliding_window == -1) with head_dim == 128.
        window, dim = -1, 128
        triton_selected = not (window == -1 and dim == 128)
        assert triton_selected is False  # ASM PA would normally win
        triton_selected = True if env.ATOM_CK_FREE else triton_selected
        assert triton_selected is True  # ...but CK-free overrides it

    def test_ck_free_off_normal_routing(self, monkeypatch):
        """Without CK-free, the sliding_window/head_dim rule decides alone."""
        monkeypatch.setenv("ATOM_CK_FREE", "0")
        env = _get_envs()

        window, dim = -1, 128
        triton_selected = not (window == -1 and dim == 128)
        triton_selected = True if env.ATOM_CK_FREE else triton_selected
        assert triton_selected is False  # no override applied


class TestMlaRouting:
    """MLA decode routing under ATOM_CK_FREE."""

    @staticmethod
    def _decode_uses_triton():
        """Mirror attention_mla.py: either env flag selects the Triton decode."""
        env = _get_envs()
        return env.ATOM_USE_TRITON_MLA_DECODE or env.ATOM_CK_FREE

    def test_ck_free_forces_triton_mla_decode(self, monkeypatch):
        """ATOM_CK_FREE=1 selects Triton MLA decode even with the direct flag off."""
        monkeypatch.setenv("ATOM_CK_FREE", "1")
        monkeypatch.setenv("ATOM_USE_TRITON_MLA_DECODE", "0")
        assert self._decode_uses_triton() is True

    def test_triton_mla_decode_standalone(self, monkeypatch):
        """ATOM_USE_TRITON_MLA_DECODE=1 still works independently."""
        monkeypatch.setenv("ATOM_CK_FREE", "0")
        monkeypatch.setenv("ATOM_USE_TRITON_MLA_DECODE", "1")
        assert self._decode_uses_triton() is True

    def test_both_off_no_triton(self, monkeypatch):
        """With both flags off, the non-Triton decode path remains selected."""
        monkeypatch.setenv("ATOM_CK_FREE", "0")
        monkeypatch.setenv("ATOM_USE_TRITON_MLA_DECODE", "0")
        assert self._decode_uses_triton() is False
Loading