Skip to content
Draft
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Keep the Docker build context small: exclude VCS metadata,
# Python bytecode caches, and packaging/build artifacts.
.git
__pycache__
*.pyc
*.egg-info
build/
dist/
42 changes: 42 additions & 0 deletions .github/workflows/sync-upstream.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Nightly job that keeps this fork's main branch in sync with upstream
# ROCm/ATOM by fast-merging upstream/main and pushing the result.
name: Sync upstream main

on:
  schedule:
    # Run nightly at 06:00 UTC (midnight CST)
    - cron: '0 6 * * *'
  workflow_dispatch: # Allow manual trigger

# The Push step writes to this repo with GITHUB_TOKEN. On repos where the
# default token is read-only, the push would fail without explicit write
# permission on contents.
permissions:
  contents: write

jobs:
  sync:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout fork
        uses: actions/checkout@v4
        with:
          ref: main
          # Full history is required so the merge below can find the
          # common ancestor with upstream.
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Add upstream remote
        run: git remote add upstream https://github.com/ROCm/ATOM.git

      - name: Fetch upstream
        run: git fetch upstream main

      - name: Check for new commits
        id: check
        run: |
          BEHIND=$(git rev-list --count HEAD..upstream/main)
          echo "behind=$BEHIND" >> "$GITHUB_OUTPUT"
          echo "Fork is $BEHIND commit(s) behind upstream"

      - name: Merge upstream
        # Skip the merge (and push) entirely when already up to date.
        if: steps.check.outputs.behind != '0'
        run: |
          git config user.name "github-actions[bot]"
          git config user.email "github-actions[bot]@users.noreply.github.com"
          git merge upstream/main --no-edit

      - name: Push
        if: steps.check.outputs.behind != '0'
        run: git push origin main
1 change: 1 addition & 0 deletions atom/model_engine/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def postprocess(
continue
token_ids = prev_token_ids[seq.id]
num_new_token = len(token_ids)
num_rejected = 0
self.update_spec_stats(num_new_token)
idx = fwd_output.req_ids.index(seq.id)
if is_deferred_out or self.use_spec:
Expand Down
17 changes: 15 additions & 2 deletions atom/model_ops/attention_mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,19 @@ def prefill_attention_triton(
if ctx.is_prefill:
k_cache = k.unsqueeze(1)
v_cache = v.unsqueeze(1)
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unified_attention expects the V cache to be in the [num_blocks, num_kv_heads, head_dim, block_size] layout (consistent with how v_cache is created in ModelRunner.allocate_kv_cache). For the prefill block_size=1 case, v_cache should therefore be v.unsqueeze(-1) rather than v.unsqueeze(1), otherwise the value layout is incorrect.

Suggested change
v_cache = v.unsqueeze(1)
v_cache = v.unsqueeze(-1)

Copilot uses AI. Check for mistakes.
block_tables = attn_metadata.fake_block_tables
# Create fake block_tables for prefill: each token is its own
# "block" (block_size=1). Shape [num_seqs, max_seqlen_k].
batch_size = attn_metadata.cu_seqlens_k.shape[0] - 1
max_len = attn_metadata.max_seqlen_k
block_tables = torch.zeros(
batch_size, max_len, dtype=torch.int32, device=q.device
)
for i in range(batch_size):
s = attn_metadata.cu_seqlens_k[i].item()
e = attn_metadata.cu_seqlens_k[i + 1].item()
block_tables[i, : e - s] = torch.arange(
s, e, dtype=torch.int32, device=q.device
)
Comment on lines +377 to +385
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Constructing block_tables for every prefill call allocates a potentially large [batch_size, max_seqlen_k] tensor and fills it in a Python loop, which is likely to be a significant prefill-time overhead. Consider generating this table once in metadata preparation (or caching it on attn_metadata), and/or vectorizing the fill to avoid per-sequence torch.arange in Python.

Suggested change
block_tables = torch.zeros(
batch_size, max_len, dtype=torch.int32, device=q.device
)
for i in range(batch_size):
s = attn_metadata.cu_seqlens_k[i].item()
e = attn_metadata.cu_seqlens_k[i + 1].item()
block_tables[i, : e - s] = torch.arange(
s, e, dtype=torch.int32, device=q.device
)
# Vectorized construction of block_tables to avoid Python loop
cu_seqlens_k = attn_metadata.cu_seqlens_k.to(device=q.device)
starts = cu_seqlens_k[:-1].to(dtype=torch.int32) # [batch_size]
lengths = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).to(dtype=torch.int32)
positions = torch.arange(
max_len, dtype=torch.int32, device=q.device
) # [max_len]
# Broadcast to [batch_size, max_len]
start_grid = starts.unsqueeze(1)
pos_grid = positions.unsqueeze(0)
indices = start_grid + pos_grid
valid_mask = pos_grid < lengths.unsqueeze(1)
block_tables = torch.where(
valid_mask, indices, torch.zeros_like(indices)
)

Copilot uses AI. Check for mistakes.

o = torch.empty_like(q)
descale_shape = (attn_metadata.cu_seqlens_q.shape[0] - 1, k.shape[1])
Expand Down Expand Up @@ -407,7 +419,8 @@ def dispatch_backend(self, fwd_ctx: ForwardContext):
ctx = fwd_ctx.context

if ctx.is_prefill:
return self.prefill_attention
# Always use Triton prefill (no CK/flash_attn_varlen_func dependency)
return self.prefill_attention_triton
Comment on lines +422 to +423
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dispatch_backend() now always routes prefill through prefill_attention_triton, which makes prefill_attention() (the aiter.flash_attn_varlen_func path) unused. If this change is only intended as a CK-unavailable fallback, it should be conditional (e.g., only use the Triton prefill path when the required kernels are present, otherwise keep the existing varlen flash attention path).

Suggested change
# Always use Triton prefill (no CK/flash_attn_varlen_func dependency)
return self.prefill_attention_triton
# Use Triton prefill when Triton attention is enabled; otherwise, use varlen flash attention
if self.use_triton_attn:
return self.prefill_attention_triton
return self.prefill_attention

Copilot uses AI. Check for mistakes.
else:
if self.use_triton_attn:
return self.paged_attention_triton
Expand Down
29 changes: 16 additions & 13 deletions atom/model_ops/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,19 +396,22 @@ def _forward_prefill_mha(

k = torch.cat((k_nope, k_rope.expand((*k_nope.shape[:-1], -1))), dim=-1)

output = flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=attn_metadata.cu_seqlens_q,
cu_seqlens_k=attn_metadata.cu_seqlens_k,
max_seqlen_q=attn_metadata.max_seqlen_q,
max_seqlen_k=attn_metadata.max_seqlen_k,
min_seqlen_q=attn_metadata.min_seqlen_q,
dropout_p=attn_metadata.dropout_p,
softmax_scale=self.scale,
causal=True,
)
# Use PyTorch SDPA for MLA prefill attention (no CK dependency)
import torch.nn.functional as F

cu_q = attn_metadata.cu_seqlens_q
cu_k = attn_metadata.cu_seqlens_k
num_seqs = cu_q.shape[0] - 1
Comment on lines +399 to +404
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR description focuses on enabling Triton MoE on gfx950 / CK-unavailable fallback for MoE, but this diff also changes the MLA prefill attention implementation (switching kernels/backends). Please update the PR description (and test plan) to cover this additional behavioral/performance change, or split it into a separate PR if it’s not required for the MoE enablement.

Copilot uses AI. Check for mistakes.
outputs = []
for i in range(num_seqs):
qi = q[cu_q[i] : cu_q[i + 1]].transpose(0, 1).unsqueeze(0)
ki = k[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0)
vi = v[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0)
oi = F.scaled_dot_product_attention(
qi, ki, vi, is_causal=True, scale=self.scale
)
outputs.append(oi.squeeze(0).transpose(0, 1))
output = torch.cat(outputs, dim=0)

Comment on lines +399 to 415
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This replaces a single flash_attn_varlen_func call with a Python loop over sequences and per-sequence scaled_dot_product_attention calls, which will scale poorly with batch size and likely be a major prefill performance regression. Consider keeping flash_attn_varlen_func as the fast path and only falling back to the SDPA loop when the varlen kernel is unavailable.

Suggested change
# Use PyTorch SDPA for MLA prefill attention (no CK dependency)
import torch.nn.functional as F
cu_q = attn_metadata.cu_seqlens_q
cu_k = attn_metadata.cu_seqlens_k
num_seqs = cu_q.shape[0] - 1
outputs = []
for i in range(num_seqs):
qi = q[cu_q[i] : cu_q[i + 1]].transpose(0, 1).unsqueeze(0)
ki = k[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0)
vi = v[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0)
oi = F.scaled_dot_product_attention(
qi, ki, vi, is_causal=True, scale=self.scale
)
outputs.append(oi.squeeze(0).transpose(0, 1))
output = torch.cat(outputs, dim=0)
# Prefer FlashAttention varlen kernel for MLA prefill; fall back to PyTorch SDPA if unavailable.
cu_q = attn_metadata.cu_seqlens_q
cu_k = attn_metadata.cu_seqlens_k
try:
# flash_attn_varlen_func expects [total_tokens, n_heads, head_dim] tensors with varlen metadata.
output = flash_attn_varlen_func(
q,
k,
v,
cu_q,
cu_k,
attn_metadata.max_seqlen_q,
attn_metadata.max_seqlen_k,
0.0, # dropout_p
softmax_scale=self.scale,
causal=True,
)
except Exception:
# Fallback: per-sequence PyTorch SDPA (slower, but no specialized kernel required).
import torch.nn.functional as F
num_seqs = cu_q.shape[0] - 1
outputs = []
for i in range(num_seqs):
qi = q[cu_q[i] : cu_q[i + 1]].transpose(0, 1).unsqueeze(0)
ki = k[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0)
vi = v[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0)
oi = F.scaled_dot_product_attention(
qi, ki, vi, is_causal=True, scale=self.scale
)
outputs.append(oi.squeeze(0).transpose(0, 1))
output = torch.cat(outputs, dim=0)

Copilot uses AI. Check for mistakes.
return self.o_proj(output.flatten(start_dim=-2))

Expand Down
50 changes: 50 additions & 0 deletions atom/model_ops/attentions/aiter_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,56 @@ def prepare_prefill(self, batch: ScheduledBatch):
bs = batch.total_seqs_num_prefill
sum_scheduled_tokens = batch.total_tokens_num_prefill
var = self.model_runner.forward_vars

# Prepare paged KV metadata for MLA prefill paths
# (needed by mla_prefill_fwd for bf16, unified_attention for fp8)
if batch.block_tables:
context_lens = np.asarray(batch.context_lens[:bs], dtype=np.int32)
num_blocks_per_seq = cdiv(context_lens, self.block_size)
kv_indptr = np.cumsum(num_blocks_per_seq)
sum_blocks = kv_indptr[-1]

dst = var["kv_indices"].np
offset = 0
for i in range(bs):
bt = batch.block_tables[i]
n = len(bt)
dst[offset : offset + n] = bt
offset += n
sum_blocks_before_converted = offset

var["kv_indptr"].np[0] = 0
var["kv_indptr"].np[1 : bs + 1] = kv_indptr

attn_metadata.kv_indptr = var["kv_indptr"].copy_to_gpu(bs + 1)
attn_metadata.kv_indices = var["kv_indices"].copy_to_gpu(
sum_blocks_before_converted
)
attn_metadata.kv_last_page_lens = var["kv_last_page_lens"].gpu[:bs]

if self.block_ratio > 1:
kv_indices_convert_triton(
var["kv_indices"].gpu[:sum_blocks_before_converted],
var["kv_indices_converted"].gpu[:sum_blocks],
var["kv_indptr"].gpu[: bs + 1],
self.block_ratio,
self.block_size,
)
attn_metadata.kv_indices = var["kv_indices_converted"].gpu[:sum_blocks]

# Prepare block_tables for unified_attention (fp8 prefill)
if attn_metadata.block_tables is None:
self.prepare_block_tables(batch)
attn_metadata.block_tables = var["block_tables"].copy_to_gpu(bs)
if self.block_ratio > 1:
block_table_convert_triton(
var["block_tables"].gpu[:bs],
var["block_tables_converted"].gpu[:bs],
var["context_lens"].gpu[:bs],
self.block_ratio,
)
attn_metadata.block_tables = var["block_tables_converted"].gpu[:bs]

if self.is_sparse and attn_metadata.max_seqlen_k > self.index_topk:
if attn_metadata.block_tables is None:
self.prepare_block_tables(batch)
Expand Down
4 changes: 4 additions & 0 deletions atom/model_ops/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,10 @@ def weight_loader(
elif self.quant_type == QuantType.per_Tensor:
shard_offset = loaded_shard_id
shard_size = 1
else:
# per_Token and per_1x32: scale dim 0 matches output_size
shard_offset = sum(self.output_sizes[:loaded_shard_id])
shard_size = self.output_sizes[loaded_shard_id]
else:
shard_offset = sum(self.output_sizes[:loaded_shard_id])
shard_size = self.output_sizes[loaded_shard_id]
Expand Down
Loading
Loading