Skip to content

Commit 21d81c5

Browse files
authored
Merge branch 'main' into zejun/add_oot_ut
2 parents 2b41f0d + 2c4dc27 commit 21d81c5

File tree

6 files changed

+321
-108
lines changed

6 files changed

+321
-108
lines changed

.github/scripts/atom_test.sh

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,48 @@ if [ "$TYPE" == "launch" ]; then
2020

2121
echo ""
2222
echo "========== Waiting for ATOM server to start =========="
23+
# Phase 1: Wait for HTTP server to be up via /health endpoint
2324
max_retries=30
2425
retry_interval=60
26+
server_up=false
2527
for ((i=1; i<=max_retries; i++)); do
26-
if curl -s http://localhost:8000/v1/completions -o /dev/null; then
27-
echo "ATOM server is up."
28+
if curl -sf http://localhost:8000/health -o /dev/null; then
29+
echo "ATOM server HTTP endpoint is up."
30+
server_up=true
2831
break
2932
fi
3033
echo "Waiting for ATOM server to be ready... ($i/$max_retries)"
3134
sleep $retry_interval
3235
done
33-
if ! curl -s http://localhost:8000/v1/completions -o /dev/null; then
36+
if [ "$server_up" = false ]; then
3437
echo "ATOM server did not start after $((max_retries * retry_interval)) seconds."
3538
kill $atom_server_pid
3639
exit 1
3740
fi
41+
42+
# Phase 2: Warmup - send a real completion request to ensure the model is fully ready
# (CUDA graph capture, JIT compilation, etc. may still be in progress after /health returns OK).
echo "========== Warming up ATOM server =========="
warmup_retries=10
warmup_interval=30   # seconds to wait between two attempts
warmup_timeout=120   # seconds a single curl request may take (--max-time)
warmup_done=false
for ((i=1; i<=warmup_retries; i++)); do
    if curl -sf http://localhost:8000/v1/completions \
        -H "Content-Type: application/json" \
        -d '{"model":"'"$MODEL_PATH"'","prompt":"hi","max_tokens":1}' \
        -o /dev/null --max-time "$warmup_timeout"; then
        echo "ATOM server warmup completed successfully."
        warmup_done=true
        break
    fi
    if (( i < warmup_retries )); then
        echo "Warmup attempt $i/$warmup_retries failed, retrying in ${warmup_interval}s..."
        sleep "$warmup_interval"
    else
        # No point in sleeping after the last attempt; fall through to the failure path.
        echo "Warmup attempt $i/$warmup_retries failed."
    fi
done
if [ "$warmup_done" = false ]; then
    # Every attempt can block for up to warmup_timeout seconds inside curl, so the
    # worst-case wall time is retries * timeout plus the sleeps between attempts.
    echo "ATOM server warmup failed after $warmup_retries attempts (up to $(( warmup_retries * warmup_timeout + (warmup_retries - 1) * warmup_interval )) seconds)."
    kill "$atom_server_pid"
    exit 1
fi
3865
fi
3966

4067
if [ "$TYPE" == "accuracy" ]; then
@@ -51,7 +78,7 @@ if [ "$TYPE" == "accuracy" ]; then
5178
mkdir -p accuracy_test_results
5279
RESULT_FILENAME=accuracy_test_results/$(date +%Y%m%d%H%M%S).json
5380
lm_eval --model local-completions \
54-
--model_args model="$MODEL_PATH",base_url=http://localhost:8000/v1/completions,num_concurrent=65,max_retries=1,tokenized_requests=False \
81+
--model_args model="$MODEL_PATH",base_url=http://localhost:8000/v1/completions,num_concurrent=16,max_retries=3,tokenized_requests=False \
5582
--tasks gsm8k \
5683
--num_fewshot 3 \
5784
--output_path "${RESULT_FILENAME}"

.github/workflows/atom-test.yaml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,10 @@ jobs:
194194
CONTAINER_NAME: atom_test_${{ strategy.job-index }}
195195

196196
steps:
197+
- name: Set HF_TOKEN
  if: matrix.run_on_pr == true || github.event_name != 'pull_request'
  # Hand the secret to the shell through the step environment instead of letting
  # the ${{ }} expression paste its value into the script text, where special
  # characters in the value would be interpreted by the shell.
  env:
    AMD_HF_TOKEN: ${{ secrets.AMD_HF_TOKEN }}
  # Keep an HF_TOKEN already present in the runner environment; otherwise fall
  # back to the repository secret. The value is exported to all later steps.
  run: echo "HF_TOKEN=${HF_TOKEN:-$AMD_HF_TOKEN}" >> "$GITHUB_ENV"
200+
197201
- name: Kill all Docker containers and clean up workspace
198202
if: (matrix.run_on_pr == true || github.event_name != 'pull_request') && matrix.runner == 'atom-mi355-8gpu.predownload'
199203
run: |
@@ -341,7 +345,8 @@ jobs:
341345
fi
342346
343347
- name: Run ATOM simple inference
344-
if: matrix.run_on_pr == true || github.event_name != 'pull_request'
348+
# Skip simple inference; accuracy test already validates correctness
349+
if: (matrix.run_on_pr == true || github.event_name != 'pull_request') && false
345350
timeout-minutes: 30
346351
run: |
347352
# Run the inference and capture output

atom/plugin/attention.py

Lines changed: 129 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,25 @@
11
from typing import Generic, Optional, TypeVar
22
import logging
3-
import os
43

54
from dataclasses import dataclass
65

76
import torch
87

8+
from aiter import dtypes, get_mla_metadata_info_v1, get_mla_metadata_v1
9+
from aiter.dist.parallel_state import get_tp_group
910
from atom.plugin.prepare import is_vllm, is_sglang
10-
from atom.utils import CpuGpuBuffer
11+
from atom.utils import CpuGpuBuffer, envs
12+
from atom.config import get_current_atom_config
13+
1114
from atom.utils.forward_context import Context, AttentionMetaData
1215
from atom.model_ops.attention_mha import PagedAttentionImpl
13-
from atom.model_ops.attention_mla import MLAAttention
16+
from atom.model_ops.attention_mla import MLAAttention, _MLA_MIN_HEADS
1417

1518
logger = logging.getLogger("atom")
1619

1720
_PARTITION_SIZE_ROCM = 256
1821
_CP_TOKENS_PER_ITER_ROCM = 32 * 1024
22+
disable_vllm_plugin_attention = envs.ATOM_DISABLE_VLLM_PLUGIN_ATTENTION
1923

2024

2125
@dataclass
@@ -643,6 +647,7 @@ class AiterMLAChunkedContextMetadataForPluginMode:
643647
query_seq_lens: torch.Tensor | None = None
644648
workspace_buffer: torch.Tensor | None = None
645649
q_data_type: torch.dtype | None = None
650+
output_dtype: torch.dtype | None = None
646651

647652

648653
D = TypeVar("D", bound=AiterMLACommonDecodeMetadataForPluginMode)
@@ -699,6 +704,49 @@ def __init__(self):
699704
"Its methods are meant to be added to other classes via decorators."
700705
)
701706

707+
# TODO: support mtp and sparse
708+
def _set_mla_persistent_worker_buffers(
    self, bs: int, cu_seqlens_q: torch.Tensor, max_q_len: int = 1
):
    """Call ``get_mla_metadata_v1`` on the persistent MLA work buffers.

    The six tensors stored in ``self.mla_persistent_metadata`` are passed to
    ``get_mla_metadata_v1`` together with the leading ``bs`` entries of the
    paged-KV bookkeeping buffers (presumably the call fills the work/reduce
    tensors in place -- confirm against the aiter API).

    Args:
        bs: Number of requests; selects ``self.paged_kv_indptr[: bs + 1]`` and
            ``self.paged_kv_last_page_len[:bs]``.
        cu_seqlens_q: Forwarded unchanged as the first positional argument.
        max_q_len: Forwarded as both ``max_seqlen_qo`` and ``uni_seqlen_qo``.

    Returns:
        A new dict mapping each buffer name to the very tensor object held in
        ``self.mla_persistent_metadata``.
    """
    # Order matters: it is the positional order expected by the call below.
    buffer_names = (
        "work_meta_data",
        "work_info_set",
        "work_indptr",
        "reduce_indptr",
        "reduce_final_map",
        "reduce_partial_map",
    )
    page_size = self.block_size
    persistent = self.mla_persistent_metadata
    buffers = {name: persistent[name] for name in buffer_names}

    get_mla_metadata_v1(
        cu_seqlens_q,
        self.paged_kv_indptr[: bs + 1],  # TODO: support sparse
        self.paged_kv_last_page_len[:bs],
        self.padded_num_attention_heads,
        1,  # nhead_kv
        True,
        *buffers.values(),
        page_size=page_size,
        kv_granularity=max(page_size, 16),
        max_seqlen_qo=max_q_len,
        uni_seqlen_qo=max_q_len,
        fast_mode=1,
        max_split_per_batch=16,
    )
    return buffers
749+
702750
def _build_decode(
703751
self,
704752
block_table_tensor: torch.Tensor,
@@ -733,34 +781,29 @@ def _build_decode(
733781
qo_len = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
734782
max_qo_len = qo_len.max().item()
735783

736-
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
737-
num_actual_pages = paged_kv_indices.size(0)
784+
num_actual_pages = paged_kv_indices.size(0)
738785

739-
self.paged_kv_indices[:num_actual_pages].copy_(
740-
paged_kv_indices, non_blocking=True
741-
)
742-
self.paged_kv_indices[num_actual_pages:].fill_(-1)
743-
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
786+
self.paged_kv_indices[:num_actual_pages].copy_(
787+
paged_kv_indices, non_blocking=True
788+
)
789+
self.paged_kv_indices[num_actual_pages:].fill_(-1)
790+
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
744791

745-
self.paged_kv_indptr[: 1 + num_reqs].copy_(
746-
paged_kv_indptr, non_blocking=True
747-
)
748-
self.paged_kv_indptr[1 + num_reqs :].fill_(paged_kv_indptr[-1])
749-
paged_kv_indptr = self.paged_kv_indptr[: 1 + num_reqs]
792+
self.paged_kv_indptr[: 1 + num_reqs].copy_(paged_kv_indptr, non_blocking=True)
793+
self.paged_kv_indptr[1 + num_reqs :].fill_(paged_kv_indptr[-1])
794+
paged_kv_indptr = self.paged_kv_indptr[: 1 + num_reqs]
750795

751-
# paged_kv_last_page_len already uses the pre-initialized buffer slice
752-
# (set above), so no copy needed - buffer is always 1s.
796+
# paged_kv_last_page_len already uses the pre-initialized buffer slice
797+
# (set above), so no copy needed - buffer is always 1s.
753798

754-
self.qo_indptr[: 1 + num_reqs].copy_(
755-
query_start_loc_device, non_blocking=True
756-
)
757-
self.qo_indptr[1 + num_reqs :] = query_start_loc_device[-1]
758-
qo_indptr = self.qo_indptr[: 1 + num_reqs]
799+
self.qo_indptr[: 1 + num_reqs].copy_(query_start_loc_device, non_blocking=True)
800+
self.qo_indptr[1 + num_reqs :] = query_start_loc_device[-1]
801+
qo_indptr = self.qo_indptr[: 1 + num_reqs]
759802

760-
else:
761-
qo_indptr = torch.arange(
762-
0, num_reqs + 1, step=1, dtype=torch.int32, device=device
763-
)
803+
ctx_mla_ps = self._set_mla_persistent_worker_buffers(
804+
num_reqs, query_start_loc_device, 1
805+
)
806+
self.mla_persistent_metadata.update(ctx_mla_ps)
764807

765808
attn_metadata = AiterMLADecodeMetadataForPluginMode(
766809
block_table=block_table_tensor,
@@ -1056,11 +1099,15 @@ def build(
10561099
decode=decode_metadata,
10571100
)
10581101

1102+
# TODO: support mtp
1103+
ctx_mla_ps = self.mla_persistent_metadata
1104+
10591105
attn_metadata = AttentionMetaData(
10601106
max_seqlen_q=common_attn_metadata.max_query_len,
10611107
block_tables=common_attn_metadata.block_table_tensor,
10621108
slot_mapping=common_attn_metadata.slot_mapping,
10631109
plugin_metadata=attn_metadata_for_plugin_mode,
1110+
**ctx_mla_ps,
10641111
)
10651112
return attn_metadata
10661113

@@ -1095,6 +1142,13 @@ def init_method_under_plugin_mode(
10951142
max_num_pages_per_req = self.vllm_config.model_config.max_model_len
10961143
max_num_reqs = self.vllm_config.scheduler_config.max_num_seqs
10971144
max_num_pages = max_num_reqs * max_num_pages_per_req
1145+
self.num_attention_heads = (
1146+
config.model_config.hf_config.num_attention_heads
1147+
// get_tp_group().world_size
1148+
)
1149+
self.padded_num_attention_heads = max(self.num_attention_heads, _MLA_MIN_HEADS)
1150+
self.block_size = kv_cache_spec.block_size
1151+
self.max_bs = max_num_reqs
10981152

10991153
# Preparing persistent buffers
11001154
# TODO: we can disambiguate between decode and mixed-prefill decode here
@@ -1107,17 +1161,54 @@ def init_method_under_plugin_mode(
11071161
max_num_reqs, dtype=torch.int32, device=device
11081162
)
11091163

1110-
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
1111-
self.paged_kv_indptr = torch.zeros(
1112-
max_num_reqs + 1, dtype=torch.int32, device=device
1113-
)
1114-
self.paged_kv_indices = torch.zeros(
1115-
max_num_pages, dtype=torch.int32, device=device
1116-
)
1164+
self.paged_kv_indptr = torch.zeros(
1165+
max_num_reqs + 1, dtype=torch.int32, device=device
1166+
)
1167+
self.paged_kv_indices = torch.zeros(
1168+
max_num_pages, dtype=torch.int32, device=device
1169+
)
11171170

1118-
self.qo_indptr = torch.zeros(
1119-
max_num_reqs + 1, dtype=torch.int32, device=device
1120-
)
1171+
self.qo_indptr = torch.zeros(max_num_reqs + 1, dtype=torch.int32, device=device)
1172+
1173+
(
1174+
(work_meta_data_size, work_meta_data_type),
1175+
(work_indptr_size, work_indptr_type),
1176+
(work_info_set_size, work_info_set_type),
1177+
(reduce_indptr_size, reduce_indptr_type),
1178+
(reduce_final_map_size, reduce_final_map_type),
1179+
(reduce_partial_map_size, reduce_partial_map_type),
1180+
) = get_mla_metadata_info_v1(
1181+
max_num_reqs,
1182+
1,
1183+
self.padded_num_attention_heads,
1184+
torch.bfloat16,
1185+
dtypes.d_dtypes[config.cache_config.cache_dtype],
1186+
is_sparse=False, # TODO: support sparse
1187+
fast_mode=True,
1188+
)
1189+
1190+
self.mla_persistent_metadata = {
1191+
"work_meta_data": torch.empty(
1192+
work_meta_data_size, dtype=work_meta_data_type, device=self.device
1193+
),
1194+
"work_indptr": torch.empty(
1195+
work_indptr_size, dtype=work_indptr_type, device=self.device
1196+
),
1197+
"work_info_set": torch.empty(
1198+
work_info_set_size, dtype=work_info_set_type, device=self.device
1199+
),
1200+
"reduce_indptr": torch.empty(
1201+
reduce_indptr_size, dtype=reduce_indptr_type, device=self.device
1202+
),
1203+
"reduce_final_map": torch.empty(
1204+
reduce_final_map_size, dtype=reduce_final_map_type, device=self.device
1205+
),
1206+
"reduce_partial_map": torch.empty(
1207+
reduce_partial_map_size,
1208+
dtype=reduce_partial_map_type,
1209+
device=self.device,
1210+
),
1211+
}
11211212

11221213
return init_method_under_plugin_mode
11231214

@@ -1206,6 +1297,7 @@ def decorator(cls):
12061297
class vllmAiterMLABackendMethods:
12071298
accept_output_buffer: bool = True
12081299
supported_dtypes: list = [torch.float16, torch.bfloat16]
1300+
forward_includes_kv_cache_update: bool = True
12091301

12101302
def __init__(self):
12111303
raise TypeError(
@@ -1288,9 +1380,6 @@ def unified_attention_with_output_base_for_plugin_mode(
12881380
use_mla: bool,
12891381
qkv: torch.Tensor,
12901382
) -> torch.Tensor:
1291-
from atom.config import get_current_atom_config
1292-
from atom.utils import envs
1293-
12941383
atom_config = get_current_atom_config()
12951384
if use_mla:
12961385
# raise NotImplementedError("MLA is not supported for plugin mode for now")
@@ -1300,7 +1389,7 @@ def unified_attention_with_output_base_for_plugin_mode(
13001389
q = self.q_proj(q, q_scale)
13011390
q = q.view(-1, self.num_heads, self.qk_head_dim)
13021391
# Add head dim of 1 to k_pe
1303-
if os.getenv("ATOM_DISABLE_VLLM_PLUGIN_ATTENTION", "0").lower() == "1":
1392+
if disable_vllm_plugin_attention:
13041393
k_pe = k_pe.unsqueeze(1)
13051394
if self.rotary_emb is not None:
13061395
q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb(

0 commit comments

Comments
 (0)