-
Notifications
You must be signed in to change notification settings - Fork 395
[Feature] Opt metrics structure #891
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 23 commits
501c5ab
d261ba5
8538816
89493ac
f69ed2d
e3a44db
da0ad3d
cb85a3d
371beae
2b98563
134a901
ee12352
0af170c
cf7e2c0
eb51d12
38785ed
aeb5fd6
924f747
8fd556b
a78af4d
44c9635
7c08073
443022b
fbbce79
2b2edfc
a42a656
e7f3fae
2e87f70
93b53f8
d584e71
d359775
6ebe5ee
ab248bb
04230c8
fdfc9b5
b5d154a
274a784
43a266b
e4ff53e
13a87f2
bacd480
b339c38
935481c
2263dd1
f3b88b1
f0bdfaa
da7a271
bcd9ac4
abf941e
03afeaf
842af89
56ecac3
cbdac45
8ee59ce
48707f0
9d76475
2631578
e0ce96f
04da676
26e18b3
a8bcbc0
114a6a3
9126e68
0b905cf
3481dcc
7665b29
c78d420
2b37f16
78963fb
141d8f8
7c95eb9
6687f65
68074ac
ef34329
bff608c
a59c766
4918ab1
4976551
d646401
5efbd55
e83a338
55c11c1
a94349b
6e63657
9a31bae
232da73
13b0050
db0d866
21de7db
dd051b2
dd73daf
c1c48f9
4e6acbe
d3c6f54
bd6d8cd
654073f
c9068a7
6626d62
fe0e4b9
89f3944
3b311f4
4b39808
3fff139
b9c2d46
0bb732e
7c91e96
da335c7
5abc397
fb3bacf
51f5e0a
41db219
3a95be0
571f297
ca2cb26
48a519c
1bd59d8
24f8bc8
ef2d5d6
23a24ee
dd5d7b7
41c58d4
9145181
f1195f8
4383b01
41482ff
f07d070
764151d
f1b41d3
382327e
75be00c
3ffa4cd
e352716
7faa2e2
42f6f0f
e7c502f
a71fa64
fd9d3d4
00e7b78
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,121 @@ | ||
| from __future__ import annotations | ||
| from vllm_omni.metrics import OrchestratorAggregator | ||
| from vllm_omni.metrics.stats import RequestE2EStats | ||
|
|
||
|
|
||
| def _get_request_entry(table: list[dict], request_id: str) -> dict: | ||
| for entry in table: | ||
| if entry.get("request_id") == request_id: | ||
| return entry | ||
| raise AssertionError(f"request_id={request_id} not found") | ||
|
|
||
|
|
||
def test_orchestrator_aggregator_builds_summary() -> None:
    """Exercise a two-stage pipeline end to end and check every summary table.

    One request "r1" flows through stage 0, is forwarded over edge 0->1, and
    finishes at stage 1; the summary must reflect that in the overall counts,
    the per-stage table, the transfer table, and the e2e table.
    """
    aggregator = OrchestratorAggregator(num_stages=2, enable_stats=False, wall_start_ts=0.0)
    # Seed per-stage activity windows directly (no real clock involved).
    aggregator.stage_first_ts[0] = 0.0
    aggregator.stage_last_ts[0] = 0.03
    aggregator.stage_first_ts[1] = 0.05
    aggregator.stage_last_ts[1] = 0.07

    # Record the 0->1 transfer, then metrics for each stage of "r1".
    aggregator.on_forward(0, 1, "r1", size_bytes=1024, tx_ms=5.0, used_shm=False)
    stage0_metrics = {
        "num_tokens_in": 3,
        "num_tokens_out": 3,
        "stage_gen_time_ms": 30.0,
        "batch_id": 1,
        "batch_size": 1,
        "rx_transfer_bytes": 0,
        "rx_decode_time_ms": 0.0,
    }
    stage1_metrics = {
        "num_tokens_out": 4,
        "stage_gen_time_ms": 20.0,
        "batch_id": 1,
        "batch_size": 1,
        "rx_transfer_bytes": 1024,
        "rx_decode_time_ms": 5.0,
        "rx_in_flight_time_ms": 2.0,
    }
    aggregator.on_stage_metrics(0, "r1", stage0_metrics)
    aggregator.on_stage_metrics(1, "r1", stage1_metrics)
    aggregator.on_finalize_request(1, "r1", req_start_ts=0.0)

    summary = aggregator.build_and_log_summary(final_stage_id_to_prompt={"r1": 1})
    assert summary["overall_summary"]["e2e_requests"] == 1

    # Both stages should appear, in order, for the request.
    stage_entry = _get_request_entry(summary["stage_table"], "r1")
    assert [row["stage_id"] for row in stage_entry["stages"]] == [0, 1]

    # The forwarded 1024 bytes show up as a 1.0 kB transfer on edge "0->1".
    transfers = _get_request_entry(summary["trans_table"], "r1")["transfers"]
    assert transfers[0]["edge"] == "0->1"
    assert transfers[0]["size_kbytes"] == 1.0

    # Token totals across both stages: 3 in + 3 out + 4 out = 10 — TODO confirm
    # exact accumulation rule against OrchestratorAggregator.
    e2e_entry = _get_request_entry(summary["e2e_table"], "r1")
    assert e2e_entry["e2e_total_tokens"] == 10
|
|
||
|
|
||
def test_build_and_log_summary_e2e_only() -> None:
    """Summary must render when only an e2e event exists and no stage metrics.

    The request appears in the e2e table with its token count, and its
    stage-table entry is present but carries an empty stage list.
    """
    aggregator = OrchestratorAggregator(num_stages=1, enable_stats=False, wall_start_ts=0.0)
    event = RequestE2EStats(
        request_id="r",
        e2e_total_ms=10.0,
        e2e_total_tokens=5,
        transfers_total_time_ms=0.0,
        transfers_total_bytes=0,
    )
    aggregator.e2e_events.append(event)

    summary = aggregator.build_and_log_summary(final_stage_id_to_prompt=0)
    assert _get_request_entry(summary["e2e_table"], "r")["e2e_total_tokens"] == 5
    assert _get_request_entry(summary["stage_table"], "r")["stages"] == []
|
|
||
|
|
||
def test_build_and_log_summary_multiple_requests() -> None:
    """Two independently finalized requests each get their own summary rows."""
    aggregator = OrchestratorAggregator(num_stages=1, enable_stats=False, wall_start_ts=0.0)

    first_metrics = {
        "num_tokens_in": 2,
        "num_tokens_out": 4,
        "batch_id": 1,
        "batch_size": 1,
        "stage_gen_time_ms": 10.0,
        "rx_transfer_bytes": 0,
        "rx_decode_time_ms": 0.0,
        "rx_in_flight_time_ms": 0.0,
    }
    aggregator.on_stage_metrics(0, "r1", first_metrics)
    aggregator.on_finalize_request(0, "r1", req_start_ts=0.0)

    second_metrics = {
        "num_tokens_in": 1,
        "num_tokens_out": 2,
        "batch_id": 2,
        "batch_size": 1,
        "stage_gen_time_ms": 12.0,
        "rx_transfer_bytes": 0,
        "rx_decode_time_ms": 0.0,
        "rx_in_flight_time_ms": 0.0,
    }
    aggregator.on_stage_metrics(0, "r2", second_metrics)
    aggregator.on_finalize_request(0, "r2", req_start_ts=0.0)

    summary = aggregator.build_and_log_summary(final_stage_id_to_prompt=0)
    assert len(summary["stage_table"]) == 2
    assert {row["request_id"] for row in summary["e2e_table"]} == {"r1", "r2"}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,7 @@ | |
| from vllm.sampling_params import SamplingParams | ||
| from vllm.tokenizers import TokenizerLike | ||
| from vllm.v1.engine.exceptions import EngineDeadError | ||
| import vllm.envs as envs | ||
|
|
||
| # Internal imports (our code) | ||
| from vllm_omni.config import OmniModelConfig | ||
|
|
@@ -24,16 +25,14 @@ | |
| from vllm_omni.distributed.ray_utils.utils import try_close_ray | ||
| from vllm_omni.engine.input_processor import OmniInputProcessor | ||
| from vllm_omni.entrypoints.client_request_state import ClientRequestState | ||
| from vllm_omni.entrypoints.log_utils import ( | ||
| OrchestratorMetrics, | ||
| ) | ||
| from vllm_omni.entrypoints.omni import OmniBase | ||
| from vllm_omni.entrypoints.omni_stage import OmniStage | ||
| from vllm_omni.entrypoints.stage_utils import SHUTDOWN_TASK, OmniStageTaskType | ||
| from vllm_omni.entrypoints.stage_utils import maybe_load_from_ipc as _load | ||
| from vllm_omni.entrypoints.utils import ( | ||
| get_final_stage_id_for_e2e, | ||
| ) | ||
| from vllm_omni.metrics import OrchestratorAggregator | ||
| from vllm_omni.outputs import OmniRequestOutput | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
@@ -57,7 +56,6 @@ def _weak_close_cleanup_async(stage_list, stage_in_queues, ray_pg, output_handle | |
| if output_handler is not None: | ||
| output_handler.cancel() | ||
|
|
||
|
|
||
| class AsyncOmni(OmniBase): | ||
| """Asynchronous unified entry point supporting multi-stage pipelines for LLM and Diffusion models. | ||
|
|
||
|
|
@@ -320,27 +318,27 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator | |
| ) | ||
|
|
||
| # Metrics/aggregation helper | ||
| metrics = OrchestratorMetrics( | ||
| num_stages, | ||
| self._enable_stats, | ||
| _wall_start_ts, | ||
| metrics = OrchestratorAggregator( | ||
| num_stages=num_stages, | ||
| enable_stats=self._enable_stats, | ||
| wall_start_ts=_wall_start_ts, # will be reset at generate() time, just a placeholder here | ||
| ) | ||
| # Seed stage-0 queue with all requests | ||
| logger.debug(f"[{self._name}] Seeding request into stage-0") | ||
| req_state = ClientRequestState(request_id) | ||
| req_state.metrics = metrics | ||
| self.request_states[request_id] = req_state | ||
|
|
||
| _req_start_ts[request_id] = time.time() | ||
| # Mark first input time for stage-0 | ||
| metrics.stage_first_ts[0] = metrics.stage_first_ts[0] or time.time() | ||
|
|
||
| sp0: SamplingParams = sampling_params_list[0] # type: ignore[index] | ||
| task = { | ||
| "request_id": request_id, | ||
| "engine_inputs": prompt, | ||
| "sampling_params": sp0, | ||
| } | ||
| self.stage_list[0].submit(task) | ||
| _req_start_ts[request_id] = time.time() | ||
| logger.debug(f"[{self._name}] Enqueued request {request_id} to stage-0") | ||
|
|
||
| logger.debug(f"[{self._name}] Entering scheduling loop: stages={num_stages}") | ||
|
|
@@ -366,6 +364,8 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator | |
| metrics.stage_last_ts[stage_id] = max(metrics.stage_last_ts[stage_id] or 0.0, time.time()) | ||
| try: | ||
| _m = asdict(result.get("metrics")) | ||
| # stage_gen_time_ms is the time of generating every chunk in this stage | ||
| metrics.accumulated_gen_time_ms[req_id] += _m.get("stage_gen_time_ms", 0.0) | ||
| if _m is not None and finished: | ||
| metrics.on_stage_metrics(stage_id, req_id, _m) | ||
| except Exception as e: | ||
|
|
@@ -423,7 +423,11 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator | |
| next_stage_id = stage_id + 1 | ||
| if next_stage_id <= final_stage_id_for_e2e and finished: | ||
| next_stage: OmniStage = self.stage_list[next_stage_id] | ||
| # Derive inputs for the next stage, record preprocess time | ||
| _prep_t0 = time.perf_counter() | ||
| next_inputs = next_stage.process_engine_inputs(self.stage_list, prompt) | ||
| _prep_ms = (time.perf_counter() - _prep_t0) * 1000.0 | ||
| metrics.record_stage_preprocess_time(next_stage_id, req_id, _prep_ms) | ||
|
||
| sp_next: SamplingParams = sampling_params_list[next_stage_id] | ||
|
|
||
| # Check if we have a connector for this edge | ||
|
|
@@ -460,11 +464,9 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator | |
| logger.debug(f"[{self._name}] Request {req_id} fully completed") | ||
|
|
||
| logger.debug(f"[{self._name}] All requests completed") | ||
|
|
||
| # Summarize and print stats | ||
| try: | ||
| summary = metrics.build_and_log_summary(final_stage_id_for_e2e) | ||
| logger.info("[Summary] %s", pformat(summary, sort_dicts=False)) | ||
| metrics.build_and_log_summary(final_stage_id_for_e2e) | ||
LJH-LBJ marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| except Exception as e: | ||
| logger.exception(f"[{self._name}] Failed to build/log summary: {e}") | ||
| finally: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.