Skip to content

Commit 2b41f0d

Browse files
add
Signed-off-by: zejunchen-zejun <zejun.chen@amd.com>
1 parent e9452b5 commit 2b41f0d

File tree

5 files changed

+371
-152
lines changed

5 files changed

+371
-152
lines changed

.github/scripts/atom_oot_test.sh

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
#!/bin/bash
set -euo pipefail

# Usage:
# .github/scripts/atom_oot_test.sh launch <mode> [model_name]
# .github/scripts/atom_oot_test.sh accuracy <mode> [model_name]
#
# TYPE:
# launch - launch vLLM server and wait until ready
# accuracy - run gsm8k accuracy test (and threshold check)
#
# MODE:
# ci - only Kimi-K2
# full - all OOT-supported models
#
# Optional model_name can be used to run a single model in full mode.

TYPE=${1:-launch}
MODE=${2:-ci}
SELECTED_MODEL=${3:-}

# Argument validation.  Diagnostics go to stderr so a caller capturing
# stdout never mistakes an error message for script output.
if [[ "$TYPE" != "launch" && "$TYPE" != "accuracy" ]]; then
  echo "Invalid TYPE: $TYPE. Expected: launch or accuracy" >&2
  exit 2
fi

if [[ "$MODE" != "ci" && "$MODE" != "full" ]]; then
  echo "Invalid MODE: $MODE. Expected: ci or full" >&2
  exit 2
fi

# Tunables — every value can be overridden from the environment.
MAX_WAIT_RETRIES=${MAX_WAIT_RETRIES:-60}    # readiness poll attempts
WAIT_INTERVAL_SEC=${WAIT_INTERVAL_SEC:-30}  # seconds between polls
VLLM_PORT=${VLLM_PORT:-8000}
VLLM_HOST=${VLLM_HOST:-0.0.0.0}
VLLM_PID_FILE=${VLLM_PID_FILE:-/tmp/vllm_oot.pid}
VLLM_LOG_FILE=${VLLM_LOG_FILE:-/tmp/vllm_oot.log}
RESULT_DIR=${RESULT_DIR:-/tmp/oot_accuracy_results}
ACCURACY_LOG_FILE=${ACCURACY_LOG_FILE:-/tmp/oot_accuracy_output.txt}

# Model tables.  Pipe-separated format:
# MODEL_NAME|MODEL_PATH|EXTRA_ARGS|THRESHOLD
CI_MODE_MODELS=(
  "Kimi-K2|amd/Kimi-K2-Thinking-MXFP4|--trust-remote-code --kv-cache-dtype fp8 --tensor-parallel-size 8 --enable-expert-parallel|0.90"
)

FULL_MODE_MODELS=(
  "Qwen3 Dense|Qwen/Qwen3-8B|--trust-remote-code --kv-cache-dtype fp8 --tensor-parallel-size 1|0.70"
  "Qwen3 MoE|Qwen/Qwen3-235B-A22B-Instruct-2507-FP8|--trust-remote-code --kv-cache-dtype fp8 --tensor-parallel-size 8 --enable-expert-parallel|0.87"
  "DeepSeek-V3 family|deepseek-ai/DeepSeek-R1-0528|--trust-remote-code --kv-cache-dtype fp8 --tensor-parallel-size 8|0.93"
  "GPT-OSS|openai/gpt-oss-120b|--trust-remote-code --kv-cache-dtype fp8 --tensor-parallel-size 2 --enable-dp-attention --enable-expert-parallel --gpu-memory-utilization 0.3|0.38"
  "Kimi-K2|amd/Kimi-K2-Thinking-MXFP4|--trust-remote-code --kv-cache-dtype fp8 --tensor-parallel-size 8 --enable-expert-parallel|0.90"
)

# Select the table this invocation will iterate over.
declare -a ACTIVE_MODELS=()
if [[ "$MODE" == "ci" ]]; then
  ACTIVE_MODELS=("${CI_MODE_MODELS[@]}")
else
  ACTIVE_MODELS=("${FULL_MODE_MODELS[@]}")
fi
# Prefer a locally mirrored copy under /models when one exists; otherwise
# fall back to the identifier as given (e.g. a Hugging Face repo id).
resolve_model_path() {
  local requested="$1"
  local local_copy="/models/${requested}"
  if [[ -f "${local_copy}/config.json" ]]; then
    printf '%s\n' "${local_copy}"
  else
    printf '%s\n' "${requested}"
  fi
}
# Poll the local vLLM HTTP endpoint until it answers, the recorded server
# process dies, or the retry budget is exhausted.
# Globals read: MAX_WAIT_RETRIES, WAIT_INTERVAL_SEC, VLLM_PORT,
#               VLLM_PID_FILE, VLLM_LOG_FILE
# Returns 0 once ready, 1 on early exit or timeout.
wait_server_ready() {
  local model_name="$1"
  local attempt=0
  local server_pid

  echo ""
  echo "========== Waiting for vLLM server (${model_name}) =========="

  while (( ++attempt <= MAX_WAIT_RETRIES )); do
    if curl -sS "http://127.0.0.1:${VLLM_PORT}/v1/models" >/dev/null; then
      echo "vLLM server is ready for ${model_name}."
      return 0
    fi

    # Bail out as soon as the server process is known to have died.
    if [[ -f "${VLLM_PID_FILE}" ]]; then
      server_pid=$(cat "${VLLM_PID_FILE}")
      if ! kill -0 "${server_pid}" 2>/dev/null; then
        echo "vLLM process exited early for ${model_name}."
        tail -n 200 "${VLLM_LOG_FILE}" || true
        return 1
      fi
    fi

    echo "Waiting for vLLM server... (${attempt}/${MAX_WAIT_RETRIES})"
    sleep "${WAIT_INTERVAL_SEC}"
  done

  echo "vLLM server did not become ready in time for ${model_name}."
  tail -n 200 "${VLLM_LOG_FILE}" || true
  return 1
}
# Terminate the recorded server process (best effort) and clear the PID
# file.  A missing PID file is not an error — there is nothing to stop.
stop_server() {
  [[ -f "${VLLM_PID_FILE}" ]] || return 0
  local recorded_pid
  recorded_pid=$(cat "${VLLM_PID_FILE}")
  kill "${recorded_pid}" 2>/dev/null || true
  rm -f "${VLLM_PID_FILE}" || true
}
# Start one background vLLM server for the given model and block until it
# is reachable or has failed (wait_server_ready makes that call).
# Arguments: display name, model path/repo id, extra CLI flags (one string).
launch_one_model() {
  local display_name="$1"
  local requested_path="$2"
  local serve_args="$3"
  local model_dir

  model_dir=$(resolve_model_path "${requested_path}")

  echo ""
  echo "========== Launching vLLM server =========="
  echo "Model name: ${display_name}"
  echo "Model path: ${model_dir}"
  echo "Extra args: ${serve_args}"

  # ROCm/AITER runtime knobs plus a generous RPC timeout; caches are
  # wiped so every launch starts from a clean slate.
  export SAFETENSORS_FAST_GPU=1
  export VLLM_ROCM_USE_AITER=1
  export VLLM_RPC_TIMEOUT=1800000
  export VLLM_CACHE_ROOT=/tmp/.cache/vllm
  export TORCHINDUCTOR_CACHE_DIR=/tmp/.cache/inductor
  rm -rf /tmp/.cache

  rm -f "${VLLM_PID_FILE}" || true

  # ${serve_args} is intentionally unquoted: it carries multiple flags
  # that must word-split into separate argv entries.
  # shellcheck disable=SC2086
  nohup vllm serve "${model_dir}" \
    --host "${VLLM_HOST}" \
    --port "${VLLM_PORT}" \
    --disable-log-requests \
    --async-scheduling \
    --load-format fastsafetensors \
    --max-model-len 16384 \
    ${serve_args} \
    > "${VLLM_LOG_FILE}" 2>&1 &
  echo $! > "${VLLM_PID_FILE}"
  echo "Server PID: $(cat "${VLLM_PID_FILE}")"

  wait_server_ready "${display_name}"
}
accuracy_one_model() {
148+
local model_name="$1"
149+
local model_path="$2"
150+
local extra_args="$3"
151+
local threshold="$4"
152+
153+
local resolved_model_path
154+
resolved_model_path=$(resolve_model_path "${model_path}")
155+
156+
if ! command -v lm_eval >/dev/null 2>&1; then
157+
echo "========== Installing lm-eval =========="
158+
pip install 'lm-eval[api]'
159+
fi
160+
161+
mkdir -p "${RESULT_DIR}"
162+
local result_file="${RESULT_DIR}/$(date +%Y%m%d%H%M%S)_${model_name// /_}.json"
163+
164+
echo ""
165+
echo "========== Running OOT gsm8k accuracy =========="
166+
echo "Model name: ${model_name}"
167+
echo "Threshold: ${threshold}"
168+
169+
lm_eval --model local-completions \
170+
--model_args model="${resolved_model_path}",base_url="http://127.0.0.1:${VLLM_PORT}/v1/completions",num_concurrent=65,max_retries=1,tokenized_requests=False \
171+
--tasks gsm8k \
172+
--num_fewshot 3 \
173+
--output_path "${result_file}" 2>&1 | tee -a "${ACCURACY_LOG_FILE}"
174+
175+
local value
176+
value=$(python - <<PY
177+
import json
178+
with open("${result_file}", "r", encoding="utf-8") as f:
179+
data = json.load(f)
180+
print(data["results"]["gsm8k"]["exact_match,flexible-extract"])
181+
PY
182+
)
183+
184+
echo "Result file: ${result_file}"
185+
echo "Flexible extract value: ${value}"
186+
echo "Accuracy threshold: ${threshold}"
187+
188+
python - <<PY
189+
value = float("${value}")
190+
threshold = float("${threshold}")
191+
assert value >= threshold, f"Accuracy failed: {value} < {threshold}"
192+
print(f"Accuracy passed: {value} >= {threshold}")
193+
PY
194+
}
195+
196+
# Apply ACTION ("launch" or "accuracy") to every ACTIVE_MODELS entry that
# matches SELECTED_MODEL (all entries when the filter is empty).  Exits 2
# when nothing matches.
run_for_models() {
  local action="$1"
  local matched=0
  local entry model_name model_path extra_args threshold

  for entry in "${ACTIVE_MODELS[@]}"; do
    IFS='|' read -r model_name model_path extra_args threshold <<< "${entry}"

    # Honour an explicit single-model filter, when given.
    if [[ -n "${SELECTED_MODEL}" && "${SELECTED_MODEL}" != "${model_name}" ]]; then
      continue
    fi
    matched=1

    case "${action}" in
      launch)
        # Launch mode brings up a single server and stops iterating.
        launch_one_model "${model_name}" "${model_path}" "${extra_args}"
        break
        ;;
      *)
        # Accuracy mode: fresh server per model, evaluate, then tear down.
        launch_one_model "${model_name}" "${model_path}" "${extra_args}"
        accuracy_one_model "${model_name}" "${model_path}" "${extra_args}" "${threshold}"
        stop_server
        ;;
    esac
  done

  if [[ "${matched}" -eq 0 ]]; then
    echo "No model matched MODE=${MODE}, SELECTED_MODEL=${SELECTED_MODEL}"
    exit 2
  fi
}
# Dispatch on TYPE (validated above).
#
# NOTE(review): the EXIT trap is installed only for accuracy runs.  The
# previous unconditional `trap 'stop_server' EXIT` killed the server as
# soon as the script returned — defeating "launch" mode, whose documented
# purpose is to leave a ready server running for later steps.
if [[ "${TYPE}" == "launch" ]]; then
  run_for_models "launch"
else
  # Safety net: tear the server down even if an accuracy step aborts
  # before its own stop_server call.
  trap 'stop_server' EXIT
  run_for_models "accuracy"
fi

.github/workflows/atom-test.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ concurrency:
2121
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
2222

2323
env:
24-
ATOM_BASE_NIGTHLY_IMAGE: rocm/atom-dev:latest
24+
ATOM_BASE_NIGHTLY_IMAGE: rocm/atom-dev:latest
2525
GITHUB_REPO_URL: ${{ github.event.pull_request.head.repo.clone_url || 'https://github.com/ROCm/ATOM.git' }}
2626
GITHUB_COMMIT_SHA: ${{ github.event.pull_request.head.sha || github.event.head_commit.id }}
2727

@@ -46,7 +46,7 @@ jobs:
4646
if: ${{ !github.event.pull_request.head.repo.fork }}
4747
run: |
4848
cat <<EOF > Dockerfile.mod
49-
FROM ${{ env.ATOM_BASE_NIGTHLY_IMAGE }}
49+
FROM ${{ env.ATOM_BASE_NIGHTLY_IMAGE }}
5050
RUN pip install -U lm-eval[api]
5151
RUN pip show lm-eval || true
5252
RUN pip install hf_transfer
@@ -229,7 +229,7 @@ jobs:
229229
if: (matrix.run_on_pr == true || github.event_name != 'pull_request') && github.event.pull_request.head.repo.fork
230230
run: |
231231
cat <<EOF > Dockerfile.mod
232-
FROM ${{ env.ATOM_BASE_NIGTHLY_IMAGE }}
232+
FROM ${{ env.ATOM_BASE_NIGHTLY_IMAGE }}
233233
RUN pip install -U lm-eval[api]
234234
RUN pip show lm-eval || true
235235
RUN pip install hf_transfer

.github/workflows/atom-vllm-oot-full-test.yaml

Lines changed: 8 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ jobs:
113113
run: |
114114
docker pull "${{ needs.build-oot-image.outputs.oot_image_tag }}"
115115
116-
- name: Run plugin unit tests
116+
- name: Run all plugin unit tests
117117
run: |
118118
docker run --rm \
119119
-v "${{ github.workspace }}:/workspace" \
@@ -221,71 +221,20 @@ jobs:
221221
env:
222222
GITHUB_WORKSPACE: ${{ github.workspace }}
223223

224-
- name: Resolve and download model
224+
- name: Pre-download model if /models exists
225225
run: |
226-
if [ -f "/models/${{ matrix.model_path }}/config.json" ]; then
227-
echo "MODEL_PATH=/models/${{ matrix.model_path }}" >> "$GITHUB_ENV"
226+
if [ -d "/models" ] && [ ! -f "/models/${{ matrix.model_path }}/config.json" ]; then
227+
docker exec -e HF_TOKEN=${{ secrets.AMD_HF_TOKEN }} "$CONTAINER_NAME" bash -lc "hf download ${{ matrix.model_path }} --local-dir /models/${{ matrix.model_path }}"
228228
else
229-
echo "MODEL_PATH=${{ matrix.model_path }}" >> "$GITHUB_ENV"
230-
if [ -d "/models" ]; then
231-
docker exec -e HF_TOKEN=${{ secrets.AMD_HF_TOKEN }} "$CONTAINER_NAME" bash -lc "hf download ${{ matrix.model_path }} --local-dir /models/${{ matrix.model_path }}"
232-
echo "MODEL_PATH=/models/${{ matrix.model_path }}" >> "$GITHUB_ENV"
233-
fi
229+
echo "Skip model pre-download"
234230
fi
235231
236-
- name: Launch vLLM server with ATOM OOT plugin
232+
- name: Run OOT launch and gsm8k accuracy via script (full mode)
233+
timeout-minutes: 120
237234
run: |
238235
docker exec "$CONTAINER_NAME" bash -lc "
239236
set -euo pipefail
240-
export SAFETENSORS_FAST_GPU=1
241-
export VLLM_ROCM_USE_AITER=1
242-
export VLLM_RPC_TIMEOUT=1800000
243-
export VLLM_CACHE_ROOT=/tmp/.cache/vllm
244-
export TORCHINDUCTOR_CACHE_DIR=/tmp/.cache/inductor
245-
rm -rf /tmp/.cache
246-
247-
nohup vllm serve \"$MODEL_PATH\" \
248-
--host 0.0.0.0 \
249-
--port 8000 \
250-
${{ matrix.extra_args }} \
251-
> /tmp/vllm_oot.log 2>&1 &
252-
echo \$! > /tmp/vllm_oot.pid
253-
"
254-
255-
- name: Wait for vLLM readiness
256-
timeout-minutes: 30
257-
run: |
258-
set -euo pipefail
259-
for i in $(seq 1 60); do
260-
if docker exec "$CONTAINER_NAME" bash -lc "curl -sS http://127.0.0.1:8000/v1/models >/dev/null"; then
261-
echo "vLLM server is ready."
262-
exit 0
263-
fi
264-
echo "Waiting for server... ($i/60)"
265-
sleep 30
266-
done
267-
docker exec "$CONTAINER_NAME" bash -lc "tail -n 200 /tmp/vllm_oot.log || true"
268-
exit 1
269-
270-
- name: Run gsm8k accuracy
271-
timeout-minutes: 60
272-
run: |
273-
docker exec "$CONTAINER_NAME" bash -lc "
274-
set -euo pipefail
275-
mkdir -p /tmp/oot_accuracy_results
276-
RESULT_FILE=/tmp/oot_accuracy_results/\$(date +%Y%m%d%H%M%S).json
277-
lm_eval --model local-completions \
278-
--model_args model=\"$MODEL_PATH\",base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=65,max_retries=1,tokenized_requests=False \
279-
--tasks gsm8k \
280-
--num_fewshot 3 \
281-
--output_path \"\$RESULT_FILE\" 2>&1 | tee /tmp/oot_accuracy_output.txt
282-
"
283-
284-
- name: Check accuracy threshold
285-
run: |
286-
docker exec "$CONTAINER_NAME" bash -lc "
287-
set -euo pipefail
288-
python -c \"import json, glob; files=sorted(glob.glob('/tmp/oot_accuracy_results/*.json')); assert files, 'No accuracy JSON found'; threshold=float('${{ matrix.accuracy_test_threshold }}'); result_file=files[-1]; data=json.load(open(result_file)); value=data['results']['gsm8k']['exact_match,flexible-extract']; print('RESULT_FILE:', result_file); print('value:', value, 'threshold:', threshold); assert value >= threshold, f'Accuracy failed: {value} < {threshold}'\"
237+
bash .github/scripts/atom_oot_test.sh accuracy full '${{ matrix.model_name }}'
289238
"
290239
291240
- name: Collect summary

0 commit comments

Comments (0)