|
#!/bin/bash
set -euo pipefail

# Launch and accuracy-test vLLM servers for out-of-tree (OOT) supported models.
#
# Usage:
#   .github/scripts/atom_oot_test.sh launch <mode> [model_name]
#   .github/scripts/atom_oot_test.sh accuracy <mode> [model_name]
#
# TYPE:
#   launch   - launch vLLM server and wait until ready
#   accuracy - run gsm8k accuracy test (and threshold check)
#
# MODE:
#   ci   - only Kimi-K2
#   full - all OOT-supported models
#
# Optional model_name can be used to run a single model in full mode.
# Tunables (MAX_WAIT_RETRIES, VLLM_PORT, ...) are env-overridable below.

# Positional arguments with defaults: TYPE defaults to "launch", MODE to
# "ci"; SELECTED_MODEL stays empty unless a single model is requested.
TYPE=${1:-launch}
MODE=${2:-ci}
SELECTED_MODEL=${3:-}
| 21 | + |
# Validate positional arguments early and fail fast with a distinct usage
# exit code (2). Diagnostics belong on stderr, not stdout, so they are not
# swallowed when the script's stdout is captured or piped.
if [[ "$TYPE" != "launch" && "$TYPE" != "accuracy" ]]; then
  echo "Invalid TYPE: $TYPE. Expected: launch or accuracy" >&2
  exit 2
fi

if [[ "$MODE" != "ci" && "$MODE" != "full" ]]; then
  echo "Invalid MODE: $MODE. Expected: ci or full" >&2
  exit 2
fi
| 31 | + |
# Env-overridable tunables; `:=` assigns the default only when the variable
# is unset or empty, so CI callers can override any of these via environment.
: "${MAX_WAIT_RETRIES:=60}"               # readiness poll attempts
: "${WAIT_INTERVAL_SEC:=30}"              # seconds between poll attempts
: "${VLLM_PORT:=8000}"
: "${VLLM_HOST:=0.0.0.0}"
: "${VLLM_PID_FILE:=/tmp/vllm_oot.pid}"
: "${VLLM_LOG_FILE:=/tmp/vllm_oot.log}"
: "${RESULT_DIR:=/tmp/oot_accuracy_results}"
: "${ACCURACY_LOG_FILE:=/tmp/oot_accuracy_output.txt}"
| 40 | + |
# Model tables. One pipe-delimited entry per model:
#   MODEL_NAME|MODEL_PATH|EXTRA_ARGS|THRESHOLD
CI_MODE_MODELS=(
  "Kimi-K2|amd/Kimi-K2-Thinking-MXFP4|--trust-remote-code --kv-cache-dtype fp8 --tensor-parallel-size 8 --enable-expert-parallel|0.90"
)

FULL_MODE_MODELS=(
  "Qwen3 Dense|Qwen/Qwen3-8B|--trust-remote-code --kv-cache-dtype fp8 --tensor-parallel-size 1|0.70"
  "Qwen3 MoE|Qwen/Qwen3-235B-A22B-Instruct-2507-FP8|--trust-remote-code --kv-cache-dtype fp8 --tensor-parallel-size 8 --enable-expert-parallel|0.87"
  "DeepSeek-V3 family|deepseek-ai/DeepSeek-R1-0528|--trust-remote-code --kv-cache-dtype fp8 --tensor-parallel-size 8|0.93"
  "GPT-OSS|openai/gpt-oss-120b|--trust-remote-code --kv-cache-dtype fp8 --tensor-parallel-size 2 --enable-dp-attention --enable-expert-parallel --gpu-memory-utilization 0.3|0.38"
  "Kimi-K2|amd/Kimi-K2-Thinking-MXFP4|--trust-remote-code --kv-cache-dtype fp8 --tensor-parallel-size 8 --enable-expert-parallel|0.90"
)

# Pick the working set for this run based on MODE.
declare -a ACTIVE_MODELS=()
case "$MODE" in
  ci) ACTIVE_MODELS=("${CI_MODE_MODELS[@]}") ;;
  *)  ACTIVE_MODELS=("${FULL_MODE_MODELS[@]}") ;;
esac
| 61 | + |
resolve_model_path() {
  # Prefer a locally mirrored copy under /models when one is present
  # (detected via its config.json); otherwise return the identifier
  # unchanged so vLLM resolves it remotely.
  local candidate="/models/$1"
  if [[ -f "${candidate}/config.json" ]]; then
    printf '%s\n' "${candidate}"
  else
    printf '%s\n' "$1"
  fi
}
| 70 | + |
wait_server_ready() {
  # Poll the OpenAI-compatible /v1/models endpoint until the server answers
  # with an HTTP success, the server process dies, or retries are exhausted.
  #
  # Arguments:
  #   $1 - human-readable model name (logging only)
  # Globals read: MAX_WAIT_RETRIES, WAIT_INTERVAL_SEC, VLLM_PORT,
  #   VLLM_PID_FILE, VLLM_LOG_FILE
  # Returns: 0 when ready; 1 on early process exit or timeout.
  local model_name="$1"
  echo ""
  echo "========== Waiting for vLLM server (${model_name}) =========="
  for ((i=1; i<=MAX_WAIT_RETRIES; i++)); do
    # -f: treat HTTP 4xx/5xx as failure so a server that is up but still
    # loading (e.g. answering 503) is not reported as ready.
    # --max-time: bound each probe so a hung connection cannot stall the
    # loop. Probe noise is silenced; progress is logged below instead.
    if curl -fsS --max-time 10 "http://127.0.0.1:${VLLM_PORT}/v1/models" >/dev/null 2>&1; then
      echo "vLLM server is ready for ${model_name}."
      return 0
    fi

    # Bail out early if the launched process already died, dumping the log
    # tail for debugging instead of waiting out the full retry budget.
    if [[ -f "${VLLM_PID_FILE}" ]]; then
      local pid
      pid=$(cat "${VLLM_PID_FILE}")
      if ! kill -0 "${pid}" 2>/dev/null; then
        echo "vLLM process exited early for ${model_name}."
        tail -n 200 "${VLLM_LOG_FILE}" || true
        return 1
      fi
    fi

    echo "Waiting for vLLM server... (${i}/${MAX_WAIT_RETRIES})"
    sleep "${WAIT_INTERVAL_SEC}"
  done

  echo "vLLM server did not become ready in time for ${model_name}."
  tail -n 200 "${VLLM_LOG_FILE}" || true
  return 1
}
| 99 | + |
stop_server() {
  # Best-effort teardown of the background server recorded in the PID file.
  # Silently a no-op when no PID file exists (nothing was launched, or it
  # was already stopped); never fails, so it is safe inside an EXIT trap.
  local pid
  [[ -f "${VLLM_PID_FILE}" ]] || return 0
  pid=$(<"${VLLM_PID_FILE}")
  kill "${pid}" 2>/dev/null || true
  rm -f "${VLLM_PID_FILE}" || true
}
| 108 | + |
launch_one_model() {
  # Start one vLLM server in the background and block until it is ready.
  #
  # Arguments:
  #   $1 - human-readable model name (logging only)
  #   $2 - model path or repo id (resolved against /models first)
  #   $3 - extra `vllm serve` CLI flags as one space-separated string
  # Returns: wait_server_ready's status (0 ready, 1 died/timed out).
  local model_name="$1"
  local model_path="$2"
  local extra_args="$3"

  local resolved_model_path
  resolved_model_path=$(resolve_model_path "${model_path}")

  echo ""
  echo "========== Launching vLLM server =========="
  echo "Model name: ${model_name}"
  echo "Model path: ${resolved_model_path}"
  echo "Extra args: ${extra_args}"

  # ROCm/vLLM tuning knobs. NOTE(review): the exact semantics of these
  # values (e.g. RPC timeout units) are assumed from upstream vLLM docs —
  # confirm against the CI image before changing any of them.
  export SAFETENSORS_FAST_GPU=1
  export VLLM_ROCM_USE_AITER=1
  export VLLM_RPC_TIMEOUT=1800000
  export VLLM_CACHE_ROOT=/tmp/.cache/vllm
  export TORCHINDUCTOR_CACHE_DIR=/tmp/.cache/inductor
  # Wipe the whole cache root so the run starts from a clean compile/cache
  # state (both vLLM and Inductor caches live under /tmp/.cache above).
  rm -rf /tmp/.cache

  # Drop any stale PID file from a previous launch before forking.
  rm -f "${VLLM_PID_FILE}" || true

  # ${extra_args} is intentionally unquoted: it carries multiple CLI flags
  # that must word-split into separate arguments.
  nohup vllm serve "${resolved_model_path}" \
    --host "${VLLM_HOST}" \
    --port "${VLLM_PORT}" \
    --disable-log-requests \
    --async-scheduling \
    --load-format fastsafetensors \
    --max-model-len 16384 \
    ${extra_args} \
    > "${VLLM_LOG_FILE}" 2>&1 &
  echo $! > "${VLLM_PID_FILE}"
  echo "Server PID: $(cat "${VLLM_PID_FILE}")"

  # Function's exit status is the readiness result.
  wait_server_ready "${model_name}"
}
| 146 | + |
accuracy_one_model() {
  # Run the lm-eval gsm8k benchmark against the already-running vLLM server
  # and fail the script (assert exit status under `set -e`) when accuracy
  # falls below the threshold.
  #
  # Arguments:
  #   $1 - human-readable model name (logs + result filename)
  #   $2 - model path or repo id
  #   $3 - extra server CLI flags (unused here; kept so the signature
  #        mirrors launch_one_model)
  #   $4 - minimum acceptable gsm8k exact_match,flexible-extract score
  local model_name="$1"
  local model_path="$2"
  local extra_args="$3"
  local threshold="$4"

  local resolved_model_path
  resolved_model_path=$(resolve_model_path "${model_path}")

  # Install the evaluation harness on demand so this also works on images
  # that do not bundle it.
  if ! command -v lm_eval >/dev/null 2>&1; then
    echo "========== Installing lm-eval =========="
    pip install 'lm-eval[api]'
  fi

  mkdir -p "${RESULT_DIR}"
  # Declare and assign separately so a failing $(date) is not masked by
  # `local` (SC2155). Spaces in the model name become underscores.
  local result_file
  result_file="${RESULT_DIR}/$(date +%Y%m%d%H%M%S)_${model_name// /_}.json"

  echo ""
  echo "========== Running OOT gsm8k accuracy =========="
  echo "Model name: ${model_name}"
  echo "Threshold: ${threshold}"

  lm_eval --model local-completions \
    --model_args model="${resolved_model_path}",base_url="http://127.0.0.1:${VLLM_PORT}/v1/completions",num_concurrent=65,max_retries=1,tokenized_requests=False \
    --tasks gsm8k \
    --num_fewshot 3 \
    --output_path "${result_file}" 2>&1 | tee -a "${ACCURACY_LOG_FILE}"

  # Pass the result path via argv into a quoted heredoc instead of
  # interpolating shell variables into Python source: interpolation breaks
  # on quotes/backslashes in the path and is an injection hazard.
  local value
  value=$(python - "${result_file}" <<'PY'
import json
import sys

with open(sys.argv[1], "r", encoding="utf-8") as f:
    data = json.load(f)
print(data["results"]["gsm8k"]["exact_match,flexible-extract"])
PY
)

  echo "Result file: ${result_file}"
  echo "Flexible extract value: ${value}"
  echo "Accuracy threshold: ${threshold}"

  # Same argv-over-interpolation pattern for the threshold check; a failed
  # assert exits non-zero, which aborts the script under `set -e`.
  python - "${value}" "${threshold}" <<'PY'
import sys

value = float(sys.argv[1])
threshold = float(sys.argv[2])
assert value >= threshold, f"Accuracy failed: {value} < {threshold}"
print(f"Accuracy passed: {value} >= {threshold}")
PY
}
| 195 | + |
run_for_models() {
  # Drive the requested action over ACTIVE_MODELS, honoring the optional
  # SELECTED_MODEL filter.
  #
  # Arguments:
  #   $1 - "launch" (start the first matching server and stop iterating) or
  #        anything else (launch + accuracy + teardown for each match)
  # Exits 2 when no table entry matches the filter.
  local action="$1"
  local found=0
  local entry model_name model_path extra_args threshold

  for entry in "${ACTIVE_MODELS[@]}"; do
    IFS='|' read -r model_name model_path extra_args threshold <<< "${entry}"

    # Skip entries that do not match the single-model filter, if one is set.
    if [[ -n "${SELECTED_MODEL}" && "${SELECTED_MODEL}" != "${model_name}" ]]; then
      continue
    fi
    found=1

    case "${action}" in
      launch)
        # Launch mode brings up exactly one server and leaves it running.
        launch_one_model "${model_name}" "${model_path}" "${extra_args}"
        break
        ;;
      *)
        # Accuracy mode: launch + evaluate each selected model, then stop
        # the server before moving on to the next entry.
        launch_one_model "${model_name}" "${model_path}" "${extra_args}"
        accuracy_one_model "${model_name}" "${model_path}" "${extra_args}" "${threshold}"
        stop_server
        ;;
    esac
  done

  if (( found == 0 )); then
    echo "No model matched MODE=${MODE}, SELECTED_MODEL=${SELECTED_MODEL}"
    exit 2
  fi
}
| 224 | + |
# Tear the background server down on every exit path (success, failure,
# or signal-driven exit) so CI runners are left clean.
trap 'stop_server' EXIT

# TYPE was validated above to be exactly "launch" or "accuracy", so the
# catch-all arm can only ever mean "accuracy".
case "${TYPE}" in
  launch) run_for_models "launch" ;;
  *)      run_for_models "accuracy" ;;
esac
| 232 | + |
0 commit comments