diff --git a/tests/unittests/test_dump_metrics.py b/tests/unittests/test_dump_metrics.py
new file mode 100644
index 000000000..708d1de5f
--- /dev/null
+++ b/tests/unittests/test_dump_metrics.py
@@ -0,0 +1,81 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Verify that metrics dump files are written when dumping is enabled.
+
+Requires RBLN NPU, network (HF model), and opt-in:
+    VLLM_RBLN_DUMP_METRICS=1 pytest ...
+"""
+
+from __future__ import annotations
+
+import os
+import subprocess
+import sys
+from pathlib import Path
+
+import pytest
+
+# Smallest public Qwen3 causal LM (~0.6B params).
+QWEN3_SMALL_MODEL = "Qwen/Qwen3-0.6B"
+
+REPO_ROOT = Path(__file__).resolve().parents[2]
+OFFLINE_INFERENCE_BASIC = REPO_ROOT / "examples" / "experimental" / "offline_inference_basic.py"
+
+
+pytestmark = pytest.mark.skipif(
+    os.environ.get("VLLM_RBLN_DUMP_METRICS", "").lower()
+    not in ("1", "true", "yes"),
+    reason="Set VLLM_RBLN_DUMP_METRICS=1 to run (needs RBLN + HF).",
+)
+
+
+def test_dump_metrics(tmp_path: Path) -> None:
+    """Run the offline-inference example; expect *_metrics.txt in its cwd."""
+    env = os.environ.copy()
+    env["RBLN_USE_CUSTOM_KERNEL"] = "0"
+    env["VLLM_RBLN_COMPILE_STRICT_MODE"] = "1"
+    env["VLLM_DISABLE_COMPILE_CACHE"] = "1"
+    env["VLLM_USE_V1"] = "1"
+    env["VLLM_RBLN_METRICS"] = "1"
+    env["VLLM_RBLN_DUMP_METRICS"] = "1"
+
+    cmd = [
+        sys.executable,
+        str(OFFLINE_INFERENCE_BASIC),
+        "--model",
+        QWEN3_SMALL_MODEL,
+        "--max-num-seqs",
+        "1",
+        "--block-size",
+        "4096",
+        "--tensor-parallel-size",
+        "1",
+    ]
+    subprocess.run(
+        cmd,
+        cwd=tmp_path,
+        env=env,
+        check=True,
+        timeout=7200,
+    )
+
+    dumped = list(tmp_path.glob("*_metrics.txt"))
+    assert dumped, (
+        f"Expected at least one *_metrics.txt under {tmp_path}, "
+        f"got {list(tmp_path.iterdir())}"
+    )
+    for path in dumped:
+        text = path.read_text()
+        assert "METRICS" in text, f"{path} should contain METRICS section, got {text!r}"
diff --git a/vllm_rbln/v1/worker/metrics.py b/vllm_rbln/v1/worker/metrics.py
index d895c8de0..63b0a60ae 100644
--- a/vllm_rbln/v1/worker/metrics.py
+++ b/vllm_rbln/v1/worker/metrics.py
@@ -13,9 +13,11 @@
 # limitations under the License.
 
import atexit +import os from collections import defaultdict from dataclasses import dataclass, field +import vllm_rbln.rbln_envs as envs from vllm_rbln.logger import init_logger logger = init_logger(__name__) @@ -130,27 +132,37 @@ def get_call_counts(self) -> int: """Get total number of requests processed.""" return len(self.latencies) - def show_stats(self, stat_type: str): + def gen_stats(self, stat_type: str) -> str: + stats = f"" if self.get_call_counts() > 0: - logger.info("%s METRICS:", stat_type) - logger.info(" Total call counts: %d", self.get_call_counts()) - logger.info(" Average latency: %.2f ms", self.get_avg_latency()) + stats += f"{stat_type} METRICS:\n" + stats += f" Total call counts: {self.get_call_counts()}\n" + stats += f" Average latency: {self.get_avg_latency()} ms\n" if sum(self.token_counts) > 0: - logger.info(" Total tokens processed: %d", sum(self.token_counts)) - logger.info( - " Average throughput: %.2f tokens/sec", self.get_avg_throughput() - ) + stats += f" Total tokens processed: {sum(self.token_counts)}\n" + stats += f" Average throughput: {self.get_avg_throughput()} tokens/sec\n" if self.host_times: - logger.info(" Average host time: %.2f us", self.get_avg_host_time()) + stats += f" Average host time: {self.get_avg_host_time()} us\n" if self.device_times: - logger.info( - " Average device time: %.2f us", self.get_avg_device_time() - ) + stats += f" Average device time: {self.get_avg_device_time()} us\n" if self.ccl_times: - logger.info(" Average ccl time: %.2f us", self.get_avg_ccl_time()) + stats += f" Average ccl time: {self.get_avg_ccl_time()} us\n" else: - logger.info("%s METRICS: No data recorded", stat_type) + stats += f"{stat_type} METRICS: No data recorded\n" + return stats + + def dump_stats(self, stat_type: str, stats: str): + filename = f"{stat_type}_metrics.txt" + if os.path.exists(filename): + os.remove(filename) + with open(filename, "w") as f: + f.write(stats) + def show_stats(self, stat_type: str): + stats = 
self.gen_stats(stat_type) + logger.info(stats) + if envs.VLLM_RBLN_DUMP_METRICS: + self.dump_stats(stat_type, stats) class PrefillMetricsByRequestID: """Metrics for prefill step by request id."""