|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
| 14 | +import contextlib |
14 | 15 | import logging |
15 | 16 | import time |
16 | 17 | from typing import TYPE_CHECKING, NamedTuple, Union, cast |
17 | 18 |
|
18 | 19 | import numpy as np |
| 20 | +import rebel |
19 | 21 | import torch |
20 | 22 | import torch.distributed |
21 | 23 | import torch.nn as nn |
@@ -240,8 +242,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): |
240 | 242 | ) |
241 | 243 |
|
242 | 244 | if envs.VLLM_RBLN_METRICS: |
243 | | - self.performance_tracker = PerformanceTracker() |
244 | | - self.performance_tracker.register_cleanup() |
| 245 | + self.model_performance_tracker = PerformanceTracker("MODEL") |
| 246 | + self.sampler_performance_tracker = PerformanceTracker("SAMPLER") |
245 | 247 |
|
246 | 248 | # Ephemeral state transferred |
247 | 249 | # between execute_model() and sample_tokens(). |
@@ -309,22 +311,26 @@ def execute_model( |
309 | 311 | ) |
310 | 312 |
|
311 | 313 | with record_function_or_nullcontext("rbln_model_runner: forward"): |
312 | | - start_time = time.perf_counter() |
313 | | - # FIXME model_input must be modified to be padded |
314 | | - hidden_states = self.model(model_input) |
| 314 | + if hasattr(rebel, "capture_reports"): |
| 315 | + capture_ctx = rebel.capture_reports() |
| 316 | + else: |
| 317 | + # use a dummy context manager that does nothing |
| 318 | + capture_ctx = contextlib.nullcontext() |
| 319 | + model_start_time = time.perf_counter() |
| 320 | + with capture_ctx as model_reports: |
| 321 | + # FIXME model_input must be modified to be padded |
| 322 | + hidden_states = self.model(model_input) |
| 323 | + if envs.VLLM_RBLN_METRICS and self.model_performance_tracker is not None: |
| 324 | + self.collect_metrics( |
| 325 | + self.model_performance_tracker, |
| 326 | + model_input.is_prompt, |
| 327 | + start_time=model_start_time, |
| 328 | + end_time=time.perf_counter(), |
| 329 | + reports=model_reports, |
| 330 | + token_count=num_scheduled_tokens, |
| 331 | + # unlike the sampler, model throughput scales with token count |
| 332 | + ) |
315 | 333 | sample_hidden_states = hidden_states.clone() |
316 | | - end_time = time.perf_counter() |
317 | | - if envs.VLLM_RBLN_METRICS: |
318 | | - # Record performance metrics |
319 | | - execution_time = end_time - start_time |
320 | | - if model_input.is_prompt: |
321 | | - self.performance_tracker.record_prefill( |
322 | | - execution_time, num_scheduled_tokens |
323 | | - ) |
324 | | - else: |
325 | | - self.performance_tracker.record_decode( |
326 | | - execution_time, num_scheduled_tokens |
327 | | - ) |
328 | 334 |
|
329 | 335 | with record_function_or_nullcontext("rbln_model_runner: postprocess"): |
330 | 336 | if self.is_pooling_model: |
@@ -450,7 +456,7 @@ def _prepare_inputs( |
450 | 456 | finished_requests_ids=list(finished_requests_ids), |
451 | 457 | cached_block_tables=cached_block_tables, |
452 | 458 | cached_lengths=cached_lengths, |
453 | | - is_prompt=is_prefill, |
| 459 | + is_prompt=is_prefill, # FIXME unify the variable names is_prefill and is_prompt |
454 | 460 | dummy_block=scheduler_output.dummy_block, |
455 | 461 | ) |
456 | 462 | return model_input, num_scheduled_tokens |
@@ -1308,7 +1314,24 @@ def sample_tokens( |
1308 | 1314 | padded_logits = logits.reshape(1, -1) |
1309 | 1315 | else: |
1310 | 1316 | padded_logits = logits |
1311 | | - sampler_output = self._sample(padded_logits, spec_decode_metadata=None) |
| 1317 | + sampler_start_time = time.perf_counter() |
| 1318 | + if hasattr(rebel, "capture_reports"): |
| 1319 | + capture_ctx = rebel.capture_reports() |
| 1320 | + else: |
| 1321 | + # use a dummy context manager that does nothing |
| 1322 | + capture_ctx = contextlib.nullcontext() |
| 1323 | + with capture_ctx as sampler_reports: |
| 1324 | + sampler_output = self._sample(padded_logits, spec_decode_metadata=None) |
| 1325 | + if envs.VLLM_RBLN_METRICS and self.sampler_performance_tracker is not None: |
| 1326 | + self.collect_metrics( |
| 1327 | + self.sampler_performance_tracker, |
| 1328 | + is_prompt, |
| 1329 | + start_time=sampler_start_time, |
| 1330 | + end_time=time.perf_counter(), |
| 1331 | + reports=sampler_reports, |
| 1332 | + token_count=0, |
| 1333 | + # the sampler's performance doesn't depend on token count |
| 1334 | + ) |
1312 | 1335 | self.input_batch.prev_sampled_token_ids = None |
1313 | 1336 |
|
1314 | 1337 | with record_function_or_nullcontext("rbln_model_runner: bookkeep"): |
@@ -1481,3 +1504,37 @@ def postprocess_sampler_output( |
1481 | 1504 | logprobs_tensors = LogprobsTensors(**dict) |
1482 | 1505 |
|
1483 | 1506 | return num_sampled_tokens, sampled_token_ids, logprobs_tensors |
| 1507 | + |
| 1508 | + def collect_metrics( |
| 1509 | + self, |
| 1510 | + performance_tracker: PerformanceTracker, |
| 1511 | + is_prefill: bool, |
| 1512 | + start_time: float, |
| 1513 | + end_time: float, |
| 1514 | + reports: Union[list[dict], None],  # None when rebel.capture_reports is unavailable |
| 1515 | + token_count: int, |
| 1516 | + ) -> None: |
| 1517 | + execution_time = end_time - start_time |
| 1518 | + host_time = None |
| 1519 | + device_time = None |
| 1520 | + ccl_time = None |
| 1521 | + if reports is not None and len(reports) > 0: |
| 1522 | + host_time = reports[0].get("total_host", None) |
| 1523 | + device_time = reports[0].get("total_device", None) |
| 1524 | + ccl_time = reports[0].get("total_ccl", None) |
| 1525 | + if is_prefill: |
| 1526 | + performance_tracker.record_prefill( |
| 1527 | + execution_time, |
| 1528 | + token_count, |
| 1529 | + host_time=host_time, |
| 1530 | + device_time=device_time, |
| 1531 | + ccl_time=ccl_time, |
| 1532 | + ) |
| 1533 | + else: |
| 1534 | + performance_tracker.record_decode( |
| 1535 | + execution_time, |
| 1536 | + token_count, |
| 1537 | + host_time=host_time, |
| 1538 | + device_time=device_time, |
| 1539 | + ccl_time=ccl_time, |
| 1540 | + ) |
0 commit comments