
Commit 6b55b0c

Add tracing support in deepseek for vLLM flow (#36605)
### Ticket
#36604

### What's changed
DeepSeek supports tracing only in the decode path for now. This patch implements tracing in the vLLM workflow for DeepSeek.

### Checklist
- [ ] [![All post-commit tests](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml/badge.svg?branch=pprajapati/vllm_tracing)](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml?query=branch:pprajapati/vllm_tracing)
- [ ] [![Blackhole Post commit](https://github.com/tenstorrent/tt-metal/actions/workflows/blackhole-post-commit.yaml/badge.svg?branch=pprajapati/vllm_tracing)](https://github.com/tenstorrent/tt-metal/actions/workflows/blackhole-post-commit.yaml?query=branch:pprajapati/vllm_tracing)
- [ ] [![cpp-unit-tests](https://github.com/tenstorrent/tt-metal/actions/workflows/tt-metal-l2-nightly.yaml/badge.svg?branch=pprajapati/vllm_tracing)](https://github.com/tenstorrent/tt-metal/actions/workflows/tt-metal-l2-nightly.yaml?query=branch:pprajapati/vllm_tracing)
- [ ] New/Existing tests provide coverage for changes

#### Model tests
If your changes cover model-related code, you should run tests corresponding to affected models and platforms (Single card, T3K, Galaxy). "Choose your pipeline" workflows facilitate running multiple kinds of tests in a single run. Each offers `models-mandatory` and `models-extended` presets. The former includes a minimal set of tests, to be run always. The latter extends that with additional ones; use your best judgement in deciding which is most appropriate for your PR.

- [ ] [![(Single) Choose your pipeline](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select.yaml/badge.svg?branch=pprajapati/vllm_tracing)](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select.yaml?query=branch:pprajapati/vllm_tracing)
  - [ ] `models-mandatory` preset (runs: [Device perf regressions](https://github.com/tenstorrent/tt-metal/actions/workflows/perf-device-models.yaml) and [Frequent model and ttnn tests](https://github.com/tenstorrent/tt-metal/actions/workflows/fast-dispatch-full-regressions-and-models.yaml))
  - [ ] `models-extended` preset (runs: the mandatory tests, plus [Demo](https://github.com/tenstorrent/tt-metal/actions/workflows/single-card-demo-tests.yaml) and [Model perf](https://github.com/tenstorrent/tt-metal/actions/workflows/perf-models.yaml) tests)
  - [ ] other selection - specify runs
- [ ] [![(T3K) Choose your pipeline](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select-t3k.yaml/badge.svg?branch=pprajapati/vllm_tracing)](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select-t3k.yaml?query=branch:pprajapati/vllm_tracing)
  - [ ] `models-mandatory` preset (runs: [Unit tests](https://github.com/tenstorrent/tt-metal/actions/workflows/t3000-unit-tests.yaml))
  - [ ] `models-extended` preset (runs: the mandatory tests, plus [Demo](https://github.com/tenstorrent/tt-metal/actions/workflows/t3000-demo-tests.yaml) and [Model perf](https://github.com/tenstorrent/tt-metal/actions/workflows/t3000-model-perf-tests.yaml) tests)
  - [ ] other selection - specify runs
- [ ] [![(Galaxy) Choose your pipeline](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select-galaxy.yaml/badge.svg?branch=pprajapati/vllm_tracing)](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select-galaxy.yaml?query=branch:pprajapati/vllm_tracing)
  - [ ] `models-mandatory` preset (runs: [Quick tests](https://github.com/tenstorrent/tt-metal/actions/workflows/galaxy-quick.yaml))
  - [ ] `models-extended` preset (runs: the mandatory tests, plus [Demo](https://github.com/tenstorrent/tt-metal/actions/workflows/galaxy-demo-tests.yaml) and [Model perf](https://github.com/tenstorrent/tt-metal/actions/workflows/galaxy-model-perf-tests.yaml) tests)
  - [ ] other selection - specify runs

---------

Signed-off-by: Pratikkumar Prajapati <pprajapati@tenstorrent.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 7ca802b commit 6b55b0c

File tree

3 files changed: +115 -50 lines


models/demos/deepseek_v3/demo/demo.py

Lines changed: 12 additions & 1 deletion
@@ -277,7 +277,18 @@ def run_demo(
     logger.info(f"Opening mesh device with shape {mesh_shape}")
     if enable_trace:
         logger.info("Enabling trace for decode forward pass")
-        trace_region_size = 4880384 + int(0.20 * 4880384)  # 20% additional
+        # NOTE:
+        # The base trace region size below (~36.3 MiB) was empirically determined from
+        # vLLM decode workloads to be sufficient to keep the trace buffer from
+        # overflowing under typical DeepSeek-V3 demo settings (batch size, sequence
+        # length, and mesh configuration). We add 20% headroom as a conservative
+        # safety margin to accommodate variability across models / prompts without
+        # repeatedly re-tuning this value.
+        #
+        # If you are optimizing memory usage, this can be reduced after verifying
+        # that tracing completes without buffer exhaustion for your target workload.
+        BASE_TRACE_REGION_BYTES = 38_070_272
+        trace_region_size = BASE_TRACE_REGION_BYTES + int(0.20 * BASE_TRACE_REGION_BYTES)
         logger.info(f"Trace region size set to {trace_region_size}")
         mesh_device = ttnn.open_mesh_device(mesh_shape=mesh_shape, trace_region_size=trace_region_size)
     else:
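
To make the sizing concrete, here is a minimal, hardware-free sketch of the headroom arithmetic above. The constant and the 20% factor come from the diff; the MiB conversion and printout are illustrative only.

```python
# Sketch of the trace-region sizing logic from demo.py above.
BASE_TRACE_REGION_BYTES = 38_070_272  # ~36.3 MiB, empirically determined
HEADROOM = 0.20  # safety margin for variability across models / prompts

trace_region_size = BASE_TRACE_REGION_BYTES + int(HEADROOM * BASE_TRACE_REGION_BYTES)
print(f"{trace_region_size} bytes (~{trace_region_size / 2**20:.1f} MiB)")  # ~43.6 MiB
```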

models/demos/deepseek_v3/tt/generator.py

Lines changed: 94 additions & 42 deletions
@@ -132,12 +132,21 @@ def __init__(
         self.random_weights = random_weights
         self.single_layer = single_layer

+        # Model runtime state
+        self.model_state = None
+        self.model_shared_state = None
+        self.model_prefill_cfg = None
+        self.model_decode_cfg = None
+        self.model_weight_config = None
+        self.page_tables_tt = None
+
         # Trace state (decode)
         self._trace_id: int | None = None
         self._trace_tokens: ttnn.Tensor | None = None
         self._trace_positions: ttnn.Tensor | None = None
         self._trace_rot_idxs: ttnn.Tensor | None = None
         self._trace_output: ttnn.Tensor | None = None
+        self._trace_page_tables_to_use: tuple[ttnn.Tensor, ...] | None = None
         self.enable_trace = enable_trace
         self.signpost = signpost
         self.prefill_max_tokens = prefill_max_tokens
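
A side note on the initialization added above: declaring every piece of runtime state in `__init__` lets the cleanup path below use plain `is not None` checks instead of `hasattr` probing. A minimal sketch of the pattern (names abbreviated from the diff; the bodies are illustrative stand-ins):

```python
class GeneratorSketch:
    """Illustrates the None-initialization pattern used by the generator."""

    def __init__(self):
        # Every runtime attribute exists from construction onward,
        # so cleanup never needs hasattr().
        self.model_state = None
        self._trace_id = None

    def cleanup_all(self):
        if self._trace_id is not None:  # attribute is guaranteed to exist
            self._trace_id = None  # stand-in for releasing the trace
        if self.model_state is not None:
            self.model_state = None
```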
@@ -279,20 +288,47 @@ def cleanup_all(self) -> None:

         # Clean up model states
         try:
-            if hasattr(self, "model_state") and self.model_state is not None:
+            if self.model_state is not None:
                 del self.model_state
         except Exception as e:
             logger.warning(f"Failed to cleanup model state: {e}")

         try:
-            if hasattr(self, "model_shared_state") and self.model_shared_state is not None:
+            if self.model_shared_state is not None:
                 del self.model_shared_state
         except Exception as e:
             logger.warning(f"Failed to cleanup model shared state: {e}")

+        # Clean up trace state
+        try:
+            if self._trace_id is not None:
+                ttnn.release_trace(self.mesh_device, self._trace_id)
+                del self._trace_id
+            if self._trace_tokens is not None:
+                ttnn.deallocate(self._trace_tokens)
+                del self._trace_tokens
+            if self._trace_positions is not None:
+                ttnn.deallocate(self._trace_positions)
+                del self._trace_positions
+            if self._trace_rot_idxs is not None:
+                ttnn.deallocate(self._trace_rot_idxs)
+                del self._trace_rot_idxs
+            if self._trace_output is not None:
+                ttnn.deallocate(self._trace_output)
+                del self._trace_output
+            if self._trace_page_tables_to_use is not None and self._trace_page_tables_to_use is not self.page_tables_tt:
+                for i, page_table in enumerate(self._trace_page_tables_to_use):
+                    try:
+                        ttnn.deallocate(page_table)
+                    except Exception as e:
+                        logger.warning(f"Failed to deallocate trace page table {i}: {e}")
+            del self._trace_page_tables_to_use
+        except Exception as e:
+            logger.warning(f"Failed to cleanup trace state: {e}")
+
         # Clean up page tables (TTNN tensors)
         try:
-            if hasattr(self, "page_tables_tt") and self.page_tables_tt is not None:
+            if self.page_tables_tt is not None:
                 for i, page_table in enumerate(self.page_tables_tt):
                     try:
                         ttnn.deallocate(page_table)
@@ -304,45 +340,37 @@ def cleanup_all(self) -> None:

         # Clean up RoPE setup
         try:
-            if hasattr(self, "rope_setup") and self.rope_setup is not None:
+            if self.rope_setup is not None:
                 del self.rope_setup
         except Exception as e:
             logger.warning(f"Failed to cleanup RoPE setup: {e}")

         # Clean up CCL
         try:
-            if hasattr(self, "ccl") and self.ccl is not None:
+            if self.ccl is not None:
                 del self.ccl
         except Exception as e:
             logger.warning(f"Failed to cleanup CCL: {e}")

         # Clean up configs
         try:
-            if hasattr(self, "model_prefill_cfg") and self.model_prefill_cfg is not None:
+            if self.model_prefill_cfg is not None:
                 del self.model_prefill_cfg
-            if hasattr(self, "model_decode_cfg") and self.model_decode_cfg is not None:
+            if self.model_decode_cfg is not None:
                 del self.model_decode_cfg
-            if hasattr(self, "model_weight_config") and self.model_weight_config is not None:
+            if self.model_weight_config is not None:
                 del self.model_weight_config

         except Exception as e:
             logger.warning(f"Failed to cleanup model configs: {e}")

         # Clean up paged config
         try:
-            if hasattr(self, "paged_config") and self.paged_config is not None:
+            if self.paged_config is not None:
                 del self.paged_config
         except Exception as e:
             logger.warning(f"Failed to cleanup paged config: {e}")

-        # Clean up trace state
-        if self.enable_trace:
-            try:
-                if hasattr(self, "_trace_id") and self._trace_id is not None:
-                    ttnn.release_trace(self.mesh_device, self._trace_id)
-            except Exception as e:
-                logger.warning(f"Failed to release trace: {e}")
-
     def __enter__(self):
         """Context manager entry."""
         return self
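
The cleanup refactor above isolates each resource in its own `try`/`except` so that one failed deallocation cannot short-circuit the rest. A condensed, runnable sketch of that pattern, with a method call standing in for `ttnn.deallocate`:

```python
import logging

logger = logging.getLogger(__name__)

def cleanup_tensors(tensors) -> None:
    """Deallocate each tensor independently; log and continue on failure."""
    for i, tensor in enumerate(tensors or ()):
        try:
            tensor.deallocate()  # stand-in for ttnn.deallocate(tensor)
        except Exception as e:
            logger.warning(f"Failed to deallocate tensor {i}: {e}")
```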
@@ -414,7 +442,7 @@ def _decode_step(
         tokens_step: torch.Tensor,
         positions: torch.Tensor,
         batch_size_per_row: int,
-        page_table: torch.Tensor | None = None,
+        page_tables: torch.Tensor | None = None,
         return_rot_idxs: bool = False,
     ) -> torch.Tensor | Tuple[torch.Tensor, ttnn.Tensor]:
         """Run a single decode step and return logits on host as torch tensor [1, 1, B, V].

@@ -444,8 +472,8 @@ def _decode_step(
             dtype=ttnn.int32,
         )

-        if page_table is not None:
-            page_tables_to_use = self._convert_vllm_page_table_for_batch(page_table)
+        if page_tables is not None:
+            page_tables_to_use = self._convert_vllm_page_table_for_batch(page_tables, device=self.mesh_device)
         else:
             page_tables_to_use = self._get_page_tables()
         # RowBatchedModel forward
@@ -637,11 +665,11 @@ def generate(
             logger.info(f"Decoding step {gen_idx} for {num_of_prompts} user(s)...")
             profiler.start(f"decode_time_{gen_idx}")
             logits = self.decode_forward(
-                next_tokens,
-                positions,
-                self.batch_size_per_row,
-                profiler,
-                gen_idx,
+                tokens=next_tokens,
+                positions=positions,
+                batch_size_per_row=self.batch_size_per_row,
+                profiler=profiler,
+                gen_idx=gen_idx,
                 enable_trace=self.enable_trace,
             )
             profiler.end(f"decode_time_{gen_idx}")
@@ -818,14 +846,18 @@ def _prefill(
         return logits  # [1, 1, seq_len, V]

     def _capture_decode_trace(
-        self, init_tokens: torch.Tensor, positions: torch.Tensor, batch_size_per_row: int
+        self,
+        init_tokens: torch.Tensor,
+        positions: torch.Tensor,
+        batch_size_per_row: int,
+        page_tables: torch.Tensor | None = None,
     ) -> None:
         """Allocate persistent inputs, capture trace for one decode iteration, and store trace state."""
         assert self._trace_id is None, "Trace already captured"

         # 1) Warm-up compile run (no trace) to keep compilation out of capture
         logger.info("Running warm-up decode step (no trace)...")
-        _ = self._decode_step(init_tokens, positions, batch_size_per_row=batch_size_per_row)
+        _ = self._decode_step(init_tokens, positions, batch_size_per_row=batch_size_per_row, page_tables=page_tables)
         ttnn.synchronize_device(self.mesh_device)

         # 2) Allocate persistent device inputs

@@ -838,6 +870,13 @@ def _capture_decode_trace(
         )

         self._trace_rot_idxs = self.rope_setup.get_rot_idxs(positions)
+
+        if page_tables is not None:
+            self._trace_page_tables_to_use = self._convert_vllm_page_table_for_batch(
+                page_tables, device=self.mesh_device
+            )
+        else:
+            self._trace_page_tables_to_use = self._get_page_tables()
         ttnn.synchronize_device(self.mesh_device)

         # 3) Capture decode graph

@@ -847,15 +886,12 @@ def _capture_decode_trace(

         # Only capture the rot_mats generation from rot_idxs (all ttnn ops, no from_torch)
         rope_tensors = self.rope_setup.get_rot_mats_from_rot_idxs(self._trace_rot_idxs)
-        logger.info(f"Rope tensors done")
-
-        # TODO: Fix this for vLLM
         self._trace_output = RowBatchedModel.forward_decode(
             x=self._trace_tokens,
             position_idxs=self._trace_positions,
             cfg=self.model_run_config_decode,
             rope_tensors=rope_tensors,
-            page_tables=self.page_tables_tt,
+            page_tables=self._trace_page_tables_to_use,
         )
         ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0)
         logger.info("Decode trace capture complete.")
@@ -866,16 +902,20 @@ def decode_forward(
         tokens: torch.Tensor,
         positions: torch.Tensor,
         batch_size_per_row: int,
-        profiler: BenchmarkProfiler,
-        gen_idx: int,
+        gen_idx: int = 0,
+        profiler: BenchmarkProfiler | None = None,
         enable_trace: bool = False,
+        page_tables: torch.Tensor | None = None,
     ) -> torch.Tensor:
+        # vLLM does not pass enable_trace when initializing the model; it sets it
+        # per decode/prefill call only, so we update it here as well.
+        self.enable_trace = enable_trace
         if not enable_trace:
-            return self._decode_step(tokens, positions, batch_size_per_row).squeeze(0).squeeze(0)
+            return self._decode_step(tokens, positions, batch_size_per_row, page_tables).squeeze(0).squeeze(0)
         else:
             # Capture trace and return trace output
             if self._trace_id is None:
-                self._capture_decode_trace(tokens, positions, batch_size_per_row)
+                self._capture_decode_trace(tokens, positions, batch_size_per_row, page_tables)
                 # First call: return the captured run's output
                 assert self._trace_output is not None
                 logits = ttnn.to_torch(

@@ -892,6 +932,7 @@ def decode_forward(
             and self._trace_positions is not None
             and self._trace_rot_idxs is not None
             and self._trace_id is not None
+            and self._trace_page_tables_to_use is not None
         )
         torch_input = tokens.view(1, 1, -1).to(torch.int32)
@@ -921,13 +962,20 @@ def decode_forward(
         host_rot_idxs = self.rope_setup.get_rot_idxs(positions, on_host=True)
         ttnn.copy_host_to_device_tensor(host_rot_idxs, self._trace_rot_idxs)

+        if page_tables is not None:
+            page_tables_to_use = self._convert_vllm_page_table_for_batch(page_tables, device=None)
+            for i, page_table in enumerate(page_tables_to_use):
+                ttnn.copy_host_to_device_tensor(page_table, self._trace_page_tables_to_use[i])
+
         self.ccl.reset_sem_counters()
-        profiler.start(f"trace_execution_{gen_idx}")
+        if profiler is not None:
+            profiler.start(f"trace_execution_{gen_idx}")
         ttnn.execute_trace(self.mesh_device, self._trace_id, cq_id=0, blocking=True)
-        profiler.end(f"trace_execution_{gen_idx}")
-        logger.info(
-            f"Trace execution t/s/user @ {gen_idx}th token: {1/profiler.get_duration(f'trace_execution_{gen_idx}')}"
-        )
+        if profiler is not None:
+            profiler.end(f"trace_execution_{gen_idx}")
+            logger.info(
+                f"Trace execution t/s/user @ {gen_idx}th token: {1/profiler.get_duration(f'trace_execution_{gen_idx}')}"
+            )
         assert self._trace_output is not None
         logits = ttnn.to_torch(
@@ -1034,13 +1082,17 @@ def _convert_vllm_page_table_for_user(
         num_layers = self.hf_config.num_hidden_layers
         return tuple(ttnn.clone(page_table_tt) for _ in range(num_layers))

-    def _convert_vllm_page_table_for_batch(self, page_table: torch.Tensor) -> tuple[ttnn.Tensor, ...]:
+    def _convert_vllm_page_table_for_batch(
+        self, page_table: torch.Tensor, device: ttnn.Device | ttnn.MeshDevice | None
+    ) -> tuple[ttnn.Tensor, ...]:
         """
         Convert vLLM's block_tables (page_table) to TTNN tensor format for the entire batch.
         Creates one page table per layer as expected by the model.

         Args:
             page_table: torch.Tensor of shape [batch_size, max_num_blocks_per_req] from vLLM
+            device: ttnn.Device, ttnn.MeshDevice, or None. If provided, creates device tensors on the specified device.
+                If None, creates host tensors instead of device tensors.

         Returns:
             Tuple of TTNN tensors, one per layer

@@ -1051,7 +1103,7 @@ def _convert_vllm_page_table_for_batch(

         page_table_tt = ttnn.from_torch(
             page_table,
-            device=self.mesh_device,
+            device=device,
             dtype=ttnn.int32,
             layout=ttnn.ROW_MAJOR_LAYOUT,
             mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=0),
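
The new `device` parameter makes the converter dual-use: `device=self.mesh_device` yields device tensors for direct execution, while `device=None` yields host tensors that can later be copied into the trace's persistent page tables. A simplified sketch (the per-layer clone mirrors the diff; the `mesh_mapper` argument is omitted for brevity):

```python
import torch
import ttnn  # assumes a tt-metal environment

def convert_page_table(page_table: torch.Tensor, device, num_layers: int):
    # device=None produces a host tensor; a device/mesh handle produces
    # device tensors ready for the model's forward pass.
    page_table_tt = ttnn.from_torch(
        page_table,
        device=device,
        dtype=ttnn.int32,
        layout=ttnn.ROW_MAJOR_LAYOUT,
    )
    # The model expects one page table per layer.
    return tuple(ttnn.clone(page_table_tt) for _ in range(num_layers))
```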

models/demos/deepseek_v3/tt/generator_vllm.py

Lines changed: 9 additions & 7 deletions
@@ -132,24 +132,26 @@ def prefill_forward(self, *args, **kwargs):
     def decode_forward(self, *args, **kwargs):
         assert self.model_run_config_decode is not None, "Model run config decode is not initialized"

-        page_table = kwargs.get("page_table", None)
+        page_tables = kwargs.get("page_table", None)
         kv_cache = kwargs.get("kv_cache", None)
+        enable_trace = kwargs.get("enable_trace", False)
         # Set kv_cache if provided and all entries are valid
         if kv_cache is not None and not any(entry is None for entry in kv_cache):
             self.set_kv_cache(kv_cache)

         tokens_step = kwargs["tokens"].squeeze(1)
+
         return_value = (
-            self._decode_step(
-                tokens_step=tokens_step,
+            super()
+            .decode_forward(
+                tokens=tokens_step,
                 positions=kwargs["start_pos"],
                 batch_size_per_row=USERS_PER_ROW,
-                page_table=page_table,
+                enable_trace=enable_trace,
+                page_tables=page_tables,
             )
-            .squeeze(0)
-            .squeeze(0)
             .unsqueeze(1)
-        )  # [1,1,B,V] -> [B, 1, V]
+        )  # [B, V] -> [B, 1, V]
         return return_value

     def allocate_kv_cache(self, kv_cache_shape, dtype, num_layers):
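
In effect, the vLLM adapter now defers the decode (traced or not) to the base generator's `decode_forward`, which returns `[B, V]` logits, and only reshapes to the `[B, 1, V]` layout vLLM expects. A condensed, runnable sketch of the resulting call pattern; the class names and the stub value of `USERS_PER_ROW` are illustrative stand-ins, not the diff's actual definitions:

```python
import torch

USERS_PER_ROW = 32  # placeholder; the real constant lives in the diff's module

class BaseGenerator:
    def decode_forward(self, *, tokens, positions, batch_size_per_row,
                       enable_trace=False, page_tables=None):
        # Stand-in for the base class's decode_forward, which returns [B, V] logits.
        return torch.zeros(tokens.shape[-1], 8)

class VLLMGenerator(BaseGenerator):
    def decode_forward(self, *args, **kwargs):
        logits = super().decode_forward(
            tokens=kwargs["tokens"].squeeze(1),
            positions=kwargs["start_pos"],
            batch_size_per_row=USERS_PER_ROW,
            enable_trace=kwargs.get("enable_trace", False),
            page_tables=kwargs.get("page_table", None),
        )  # [B, V]
        return logits.unsqueeze(1)  # [B, 1, V], as vLLM expects
```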
