Skip to content

Commit 6dff6d3

Browse files
authored
fix: performance tracker (#503)
1 parent 6d58d6a commit 6dff6d3

4 files changed

Lines changed: 82 additions & 34 deletions

File tree

vllm_rbln/v1/worker/metrics.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import atexit
1615
from collections import defaultdict
1716
from dataclasses import dataclass, field
1817

@@ -193,13 +192,6 @@ def __init__(self, name: str | None = None):
193192
self.decode_metrics = StepMetrics()
194193
self.prefill_metrics_by_request_id = PrefillMetricsByRequestID()
195194
self.padded_decode_metrics = StepMetrics()
196-
self._registered_cleanup = False
197-
198-
def register_cleanup(self):
199-
"""Register cleanup function to print stats on exit."""
200-
if not self._registered_cleanup:
201-
atexit.register(self.print_final_stats)
202-
self._registered_cleanup = True
203195

204196
def check_dummy_request(self, request_ids: list[str] | None) -> bool:
205197
if request_ids:
@@ -227,7 +219,7 @@ def record_prefill(
227219
f"got {len(request_ids)}: {request_ids}"
228220
)
229221
request_id = request_ids[0]
230-
self.prefill_metrics.add_measurement(latency, token_count)
222+
self.prefill_metrics.add_measurement(latency, token_count, host_time, device_time, ccl_time)
231223
if request_id:
232224
self.prefill_metrics_by_request_id.add_measurement(
233225
request_id, latency, token_count, host_time, device_time, ccl_time

vllm_rbln/v1/worker/optimum_model_runner.py

Lines changed: 76 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import contextlib
1415
import logging
1516
import time
1617
from typing import TYPE_CHECKING, NamedTuple, Union, cast
1718

1819
import numpy as np
20+
import rebel
1921
import torch
2022
import torch.distributed
2123
import torch.nn as nn
@@ -240,8 +242,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
240242
)
241243

242244
if envs.VLLM_RBLN_METRICS:
243-
self.performance_tracker = PerformanceTracker()
244-
self.performance_tracker.register_cleanup()
245+
self.model_performance_tracker = PerformanceTracker("MODEL")
246+
self.sampler_performance_tracker = PerformanceTracker("SAMPLER")
245247

246248
# Ephemeral state transferred
247249
# between execute_model() and sample_tokens().
@@ -309,22 +311,26 @@ def execute_model(
309311
)
310312

311313
with record_function_or_nullcontext("rbln_model_runner: forward"):
312-
start_time = time.perf_counter()
313-
# FIXME model_input must be modified to be padded
314-
hidden_states = self.model(model_input)
314+
if hasattr(rebel, "capture_reports"):
315+
capture_ctx = rebel.capture_reports()
316+
else:
317+
# use a dummy context manager that does nothing
318+
capture_ctx = contextlib.nullcontext()
319+
model_start_time = time.perf_counter()
320+
with capture_ctx as model_reports:
321+
# FIXME model_input must be modified to be padded
322+
hidden_states = self.model(model_input)
323+
if envs.VLLM_RBLN_METRICS and self.model_performance_tracker is not None:
324+
self.collect_metrics(
325+
self.model_performance_tracker,
326+
model_input.is_prompt,
327+
start_time=model_start_time,
328+
end_time=time.perf_counter(),
329+
reports=model_reports,
330+
token_count=0,
331+
# NOTE(review): comment copy-pasted from the sampler path — this is the MODEL forward call, whose performance does depend on token count; confirm whether token_count=0 is intentional here
332+
)
315333
sample_hidden_states = hidden_states.clone()
316-
end_time = time.perf_counter()
317-
if envs.VLLM_RBLN_METRICS:
318-
# Record performance metrics
319-
execution_time = end_time - start_time
320-
if model_input.is_prompt:
321-
self.performance_tracker.record_prefill(
322-
execution_time, num_scheduled_tokens
323-
)
324-
else:
325-
self.performance_tracker.record_decode(
326-
execution_time, num_scheduled_tokens
327-
)
328334

329335
with record_function_or_nullcontext("rbln_model_runner: postprocess"):
330336
if self.is_pooling_model:
@@ -450,7 +456,7 @@ def _prepare_inputs(
450456
finished_requests_ids=list(finished_requests_ids),
451457
cached_block_tables=cached_block_tables,
452458
cached_lengths=cached_lengths,
453-
is_prompt=is_prefill,
459+
is_prompt=is_prefill, # FIXME unify the variable name is_prefill and is_prompt
454460
dummy_block=scheduler_output.dummy_block,
455461
)
456462
return model_input, num_scheduled_tokens
@@ -1308,7 +1314,24 @@ def sample_tokens(
13081314
padded_logits = logits.reshape(1, -1)
13091315
else:
13101316
padded_logits = logits
1311-
sampler_output = self._sample(padded_logits, spec_decode_metadata=None)
1317+
sampler_start_time = time.perf_counter()
1318+
if hasattr(rebel, "capture_reports"):
1319+
capture_ctx = rebel.capture_reports()
1320+
else:
1321+
# use a dummy context manager that does nothing
1322+
capture_ctx = contextlib.nullcontext()
1323+
with capture_ctx as sampler_reports:
1324+
sampler_output = self._sample(padded_logits, spec_decode_metadata=None)
1325+
if envs.VLLM_RBLN_METRICS and self.sampler_performance_tracker is not None:
1326+
self.collect_metrics(
1327+
self.sampler_performance_tracker,
1328+
is_prompt,
1329+
start_time=sampler_start_time,
1330+
end_time=time.perf_counter(),
1331+
reports=sampler_reports,
1332+
token_count=0,
1333+
# the performance of sampler doesn't depend on token count
1334+
)
13121335
self.input_batch.prev_sampled_token_ids = None
13131336

13141337
with record_function_or_nullcontext("rbln_model_runner: bookkeep"):
@@ -1481,3 +1504,37 @@ def postprocess_sampler_output(
14811504
logprobs_tensors = LogprobsTensors(**dict)
14821505

14831506
return num_sampled_tokens, sampled_token_ids, logprobs_tensors
1507+
1508+
def collect_metrics(
1509+
self,
1510+
performance_tracker: PerformanceTracker,
1511+
is_prefill: bool,
1512+
start_time: float,
1513+
end_time: float,
1514+
reports: list[dict],
1515+
token_count: int,
1516+
) -> None:
1517+
execution_time = end_time - start_time
1518+
host_time = None
1519+
device_time = None
1520+
ccl_time = None
1521+
if reports is not None and len(reports) > 0:
1522+
host_time = reports[0].get("total_host", None)
1523+
device_time = reports[0].get("total_device", None)
1524+
ccl_time = reports[0].get("total_ccl", None)
1525+
if is_prefill:
1526+
performance_tracker.record_prefill(
1527+
execution_time,
1528+
token_count,
1529+
host_time=host_time,
1530+
device_time=device_time,
1531+
ccl_time=ccl_time,
1532+
)
1533+
else:
1534+
performance_tracker.record_decode(
1535+
execution_time,
1536+
token_count,
1537+
host_time=host_time,
1538+
device_time=device_time,
1539+
ccl_time=ccl_time,
1540+
)

vllm_rbln/v1/worker/optimum_worker.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,9 +243,11 @@ def pin_lora(self, lora_id: int) -> bool:
243243

244244
def shutdown(self) -> None:
245245
logger.info("v1 optimum_worker shutdown called")
246-
if envs.VLLM_RBLN_METRICS and self.model_runner.performance_tracker:
247-
# FIXME - performance tracker atexit is not called
248-
self.model_runner.performance_tracker.print_final_stats()
246+
if envs.VLLM_RBLN_METRICS:
247+
if self.model_runner.model_performance_tracker:
248+
self.model_runner.model_performance_tracker.print_final_stats()
249+
if self.model_runner.sampler_performance_tracker:
250+
self.model_runner.sampler_performance_tracker.print_final_stats()
249251

250252

251253
def init_worker_distributed_environment(

vllm_rbln/v1/worker/rbln_model_runner.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -544,11 +544,8 @@ def __init__(
544544
def _enable_performance_tracker(self):
545545
if envs.VLLM_RBLN_METRICS:
546546
self.performance_tracker = PerformanceTracker("MODEL")
547-
self.performance_tracker.register_cleanup()
548547
self.sampler_performance_tracker = PerformanceTracker("SAMPLER")
549-
self.sampler_performance_tracker.register_cleanup()
550548
self.e2e_performance_tracker = PerformanceTracker("E2E")
551-
self.e2e_performance_tracker.register_cleanup()
552549

553550
def _get_positions(self, num_tokens: Any):
554551
if isinstance(num_tokens, int):

0 commit comments

Comments
 (0)