|
| 1 | +#!/usr/bin/env python3 |
| 2 | +"""Validate Ming Omni thinker output consistency across TP configurations. |
| 3 | +
|
| 4 | +Usage: |
| 5 | + python scripts/test_ming_tp.py run --tp 1 --cpu-offload-gb 150 |
| 6 | + python scripts/test_ming_tp.py run --tp 2 --cpu-offload-gb 40 |
| 7 | + python scripts/test_ming_tp.py compare tp1_results.json tp2_results.json |
| 8 | +""" |
| 9 | +from __future__ import annotations |
| 10 | + |
| 11 | +import argparse |
| 12 | +import asyncio |
| 13 | +import json |
| 14 | +import logging |
| 15 | +import multiprocessing as mp |
| 16 | +import os |
| 17 | +import sys |
| 18 | + |
# Root logging config: level comes from the LOGLEVEL env var (default INFO).
logging.basicConfig(
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)

# Fixed prompts submitted with temperature=0.0 (greedy decoding) so that
# outputs are directly comparable across tensor-parallel configurations.
TEST_PROMPTS = [
    "What is 1+1?",
    "What is the capital of France?",
    "What is the capital of Japan?",
    "Explain quantum computing in one sentence.",
]
| 31 | + |
| 32 | + |
async def run_thinker(
    tp_size: int,
    cpu_offload_gb: int,
    mem_fraction: float,
    output_file: str,
    attention_backend: str | None = None,
):
    """Run the Ming Omni thinker over TEST_PROMPTS and save outputs as JSON.

    Starts a multi-process pipeline with the given tensor-parallel size,
    submits each test prompt with greedy decoding, and writes
    ``{"tp_size": ..., "results": [{"prompt": ..., "output": ...}, ...]}``
    to *output_file* for a later ``compare`` invocation.

    Args:
        tp_size: Tensor-parallel degree to launch the pipeline with.
        cpu_offload_gb: Amount of CPU offload memory (GB) to configure.
        mem_fraction: Static GPU memory fraction passed to the runtime.
        output_file: Path of the JSON results file to write.
        attention_backend: Optional attention backend override; omitted from
            the server overrides when None.

    Raises:
        RuntimeError: If a prompt yields an empty output.
        asyncio.TimeoutError: If a single request exceeds 120 seconds.
    """
    # Imported lazily so the lightweight `compare` subcommand does not
    # require the heavy pipeline dependencies.
    from sglang_omni.models.ming_omni.config import MingOmniPipelineConfig
    from sglang_omni.pipeline.mp_runner import MultiProcessPipelineRunner
    from sglang_omni.proto import OmniRequest

    overrides = {
        "tp_size": tp_size,
        "cpu_offload_gb": cpu_offload_gb,
        "mem_fraction_static": mem_fraction,
    }
    if attention_backend is not None:
        overrides["attention_backend"] = attention_backend

    config = MingOmniPipelineConfig(
        model_path="inclusionAI/Ming-flash-omni-2.0",
        relay_backend="shm",
        server_args_overrides=overrides,
    )

    runner = MultiProcessPipelineRunner(config)
    logger.info(
        "Starting pipeline with TP=%d, cpu_offload_gb=%d, attention_backend=%s ...",
        tp_size,
        cpu_offload_gb,
        attention_backend,
    )
    await runner.start(timeout=600)

    results = []
    try:
        for i, prompt in enumerate(TEST_PROMPTS):
            logger.info("[%d/%d] Prompt: %s", i + 1, len(TEST_PROMPTS), prompt)
            request = {
                "messages": [
                    {
                        "role": "system",
                        "content": "You are a friendly AI assistant. Please answer concisely.",
                    },
                    {"role": "user", "content": prompt},
                ],
                "audios": [],
            }
            result = await asyncio.wait_for(
                runner.coordinator.submit(
                    f"tp-test-{i}",
                    OmniRequest(
                        inputs=request,
                        # temperature=0.0 -> greedy decoding, so outputs are
                        # deterministic and comparable across TP sizes.
                        params={"max_new_tokens": 64, "temperature": 0.0},
                    ),
                ),
                timeout=120,
            )
            # Take the first stage payload that carries a "text" field; a
            # payload may be a plain dict or an object exposing `.data`.
            text = ""
            if isinstance(result, dict):
                for payload in result.values():
                    data = (
                        payload
                        if isinstance(payload, dict)
                        else getattr(payload, "data", {})
                    )
                    if isinstance(data, dict) and "text" in data:
                        text = data["text"]
                        break
            # Raise instead of assert: assert is stripped under `python -O`,
            # which would let empty outputs pass silently.
            if not text:
                raise RuntimeError(f"Empty output for prompt: {prompt}")
            results.append({"prompt": prompt, "output": text})
            logger.info("  Output: %s", text[:200])
    finally:
        await runner.stop()

    # Explicit UTF-8: ensure_ascii=False may emit non-ASCII characters, which
    # would break on platforms whose default encoding is not UTF-8.
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(
            {"tp_size": tp_size, "results": results}, f, indent=2, ensure_ascii=False
        )
    logger.info("Results saved to %s", output_file)
| 113 | + |
| 114 | + |
def compare_outputs(file1: str, file2: str) -> bool:
    """Compare two result files produced by the ``run`` subcommand.

    Prints a per-prompt MATCH/MISMATCH report and an overall verdict.

    Args:
        file1: Path to the first JSON results file.
        file2: Path to the second JSON results file.

    Returns:
        True if both files contain the same number of results and every
        prompt's output matches after stripping surrounding whitespace.
    """
    with open(file1, encoding="utf-8") as f:
        data1 = json.load(f)
    with open(file2, encoding="utf-8") as f:
        data2 = json.load(f)

    print(f"\n{'='*60}")
    print(f"Comparing TP={data1['tp_size']} vs TP={data2['tp_size']}")
    print(f"{'='*60}")

    all_match = True
    # zip() silently truncates on unequal lengths, which would hide missing
    # results entirely — compare the counts explicitly first.
    if len(data1["results"]) != len(data2["results"]):
        all_match = False
        print(
            f"\n[MISMATCH] Result counts differ: "
            f"{len(data1['results'])} vs {len(data2['results'])}"
        )
    for r1, r2 in zip(data1["results"], data2["results"]):
        match = r1["output"].strip() == r2["output"].strip()
        status = "MATCH" if match else "MISMATCH"
        if not match:
            all_match = False
        print(f"\n[{status}] Prompt: {r1['prompt']}")
        print(f"  TP={data1['tp_size']}: {r1['output'][:120]}")
        print(f"  TP={data2['tp_size']}: {r2['output'][:120]}")

    print(f"\n{'='*60}")
    if all_match:
        print("ALL OUTPUTS MATCH - TP validation PASSED")
    else:
        print("OUTPUTS DIFFER - TP validation FAILED, needs investigation")
    print(f"{'='*60}")
    return all_match
| 142 | + |
| 143 | + |
def main():
    """CLI entry point: dispatch to the ``run`` or ``compare`` subcommand."""
    # Use 'spawn' so worker processes start with a clean interpreter state.
    mp.set_start_method("spawn", force=True)

    parser = argparse.ArgumentParser(description=__doc__)
    subcommands = parser.add_subparsers(dest="cmd")

    run_cmd = subcommands.add_parser("run")
    run_cmd.add_argument("--tp", type=int, required=True)
    run_cmd.add_argument("--cpu-offload-gb", type=int, default=80)
    run_cmd.add_argument("--mem-fraction", type=float, default=0.80)
    run_cmd.add_argument("--attention-backend", type=str, default=None)
    run_cmd.add_argument("--output", type=str, default=None)

    compare_cmd = subcommands.add_parser("compare")
    compare_cmd.add_argument("file1")
    compare_cmd.add_argument("file2")

    args = parser.parse_args()

    if args.cmd == "compare":
        # Exit code mirrors the comparison verdict for CI consumption.
        sys.exit(0 if compare_outputs(args.file1, args.file2) else 1)
    elif args.cmd == "run":
        result_path = args.output or f"tp{args.tp}_results.json"
        asyncio.run(
            run_thinker(
                args.tp,
                args.cpu_offload_gb,
                args.mem_fraction,
                result_path,
                args.attention_backend,
            )
        )
    else:
        # No subcommand given: show usage instead of failing.
        parser.print_help()


if __name__ == "__main__":
    main()
0 commit comments