Commit 77811cb: Fix profiling summary pipeline on GPU traces (marin-community#3362)
## Summary
Fixes marin-community#3345. The profiling summary pipeline produced empty `hot_ops`,
`communication_ops`, `gap_before_ops`, and `step_time` on GPU traces
because of three independent mismatches between the ingestion code
(written for TPU traces) and the actual GPU/NCCL trace format.
| Cause | Before | After |
|-------|--------|-------|
| Thread filter mismatch | 6 functions gate on `"XLA Ops"` — GPU uses `Stream #N(...)` | New `_is_device_op_event()` matches both |
| Comm op naming | `_COMM_PATTERNS` misses `ncclDevKernel_AllGather_RING_LL` etc. | Added `nccl`, `allgather`, `allreduce`, `reducescatter` |
| No step markers | TPU uses `"Steps"` thread with numeric names — GPU has neither | `StepTraceAnnotation` in trainer + host-side `step_num` fallback in ingest |
## Changes
### `ingest.py` — thread and op recognition
<details><summary>Cause 1: Replace 6 hardcoded thread checks with
<code>_is_device_op_event()</code></summary>
The old code filtered device ops with:
```python
if event.thread_name not in {"XLA Ops", "Async XLA Ops"}:
    continue
```
GPU traces use stream-based thread names like `Stream #0(compute)`,
`Stream #1(nccl)`, so every device op was silently dropped.
New predicate:
```python
_DEVICE_OP_THREAD_NAMES = frozenset({"XLA Ops", "Async XLA Ops"})


def _is_device_op_thread(thread_name: str | None) -> bool:
    if thread_name is None:
        return False
    if thread_name in _DEVICE_OP_THREAD_NAMES:
        return True
    if thread_name.startswith("Stream #"):
        return True
    return False


def _is_device_op_event(event: _CompleteTraceEvent) -> bool:
    return _is_device_event(event) and _is_device_op_thread(event.thread_name)
```
Updated call sites: `_summarize_hot_ops`, `_summarize_communication`,
`_summarize_pre_op_gaps`, `_summarize_hierarchical_regions`,
`_summarize_gap_region_contexts`, `_preferred_region_path_by_op`.
</details>
<details><summary>Cause 2: NCCL collective classification</summary>
Added to `_COMM_PATTERNS`: `"nccl"`, `"allgather"`, `"allreduce"`,
`"reducescatter"`.
Updated `_collective_kind()` to normalize unseparated NCCL names (e.g.
`ncclDevKernel_AllGather_RING_LL` → `"all-gather"`).
Updated `semantics.py` `collective` regex to match the same patterns for
family classification.
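The normalization step can be sketched as follows. This is a hypothetical approximation — the real `_collective_kind()` in `ingest.py` may differ in detail — but it shows the idea: flatten the kernel name to lowercase letters, then look for the unseparated collective keyword:

```python
import re

# Hypothetical mapping from unseparated NCCL keywords to canonical
# collective kinds; the real table in ingest.py may differ.
_NCCL_KINDS = {
    "allgather": "all-gather",
    "allreduce": "all-reduce",
    "reducescatter": "reduce-scatter",
}


def collective_kind(op_name):
    # Flatten "ncclDevKernel_AllGather_RING_LL" -> "nccldevkernelallgatherringll"
    flat = re.sub(r"[^a-z]", "", op_name.lower())
    for key, kind in _NCCL_KINDS.items():
        if key in flat:
            return kind
    return None


print(collective_kind("ncclDevKernel_AllGather_RING_LL"))  # all-gather
```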
</details>
<details><summary>Cause 3: Host-side step markers</summary>
**Capture side** (`trainer.py`): Wrapped the compiled step body (the
`_maybe_save_jaxpr` call) in `jax.profiler.StepTraceAnnotation("train",
step_num=int(state.step))`. The annotation is scoped to the compiled
step only — hooks, logging, and tracker calls happen outside — so the
measured interval matches TPU device-side `"Steps"` semantics.
**Ingest side** (`ingest.py`): Added `step_num: int | None` to
`_CompleteTraceEvent`. When the TPU-style `"Steps"` thread produces no
results, the summarizer falls back to host-side events filtered to
`name == "train"` on `/host:*` processes:
```python
if not per_step:
    for event in events:
        if event.step_num is None:
            continue
        if event.name != "train":
            continue
        if not event.process_name or not event.process_name.startswith("/host:"):
            continue
        per_step[event.step_num].append(event.dur)
```
The fallback only fires when no TPU-style step markers exist, so
existing TPU behavior is unchanged.
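The fallback's grouping logic can be demonstrated end to end with a minimal stand-in for `_CompleteTraceEvent` (the real dataclass in `ingest.py` carries more fields; the `Event` class and `host_step_durations` helper here are illustrative only):

```python
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass


@dataclass
class Event:
    # Minimal stand-in for _CompleteTraceEvent; field names are illustrative.
    name: str
    process_name: str | None
    step_num: int | None
    dur: float


def host_step_durations(events):
    # Same filtering as the fallback above: host-side "train" events
    # carrying a step_num, grouped by step.
    per_step = defaultdict(list)
    for event in events:
        if event.step_num is None or event.name != "train":
            continue
        if not event.process_name or not event.process_name.startswith("/host:"):
            continue
        per_step[event.step_num].append(event.dur)
    return dict(per_step)


events = [
    Event("train", "/host:CPU", 0, 301_000.0),
    Event("train", "/host:CPU", 1, 305_000.0),
    Event("fusion.1", "/device:GPU:0", None, 12.0),  # device op: ignored
]
print(host_step_durations(events))  # {0: [301000.0], 1: [305000.0]}
```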
</details>
### `semantics.py` — collective family regex
Extended the `collective` family pattern to match NCCL naming (`nccl`,
`allgather`, `allreduce`, `reducescatter`).
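For illustration, a pattern covering these names might look like the following. This is a hypothetical approximation — the exact `collective` family regex in `semantics.py` is not shown in this PR description:

```python
import re

# Hypothetical sketch of an extended collective-family pattern: matches
# both separated XLA-style names ("all-reduce") and unseparated NCCL
# kernel names ("ncclDevKernel_AllGather_..."). Not the actual regex.
COLLECTIVE_RE = re.compile(
    r"nccl|all-?gather|all-?reduce|reduce-?scatter",
    re.IGNORECASE,
)

for name in ["ncclDevKernel_AllGather_RING_LL", "all-reduce.17", "fusion.42"]:
    print(name, bool(COLLECTIVE_RE.search(name)))
```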
### `trainer.py` — step annotation
Wrapped the compiled step body in
`jax.profiler.StepTraceAnnotation("train", step_num=...)` inside
`train_step()`. Scoped narrowly to exclude hooks/logging so GPU step
timing is comparable to TPU device-side timing.
## Test plan
### Unit tests
All 12 tests pass (11 existing TPU + 1 new GPU):
```
tests/profiling/test_profile_summary.py 12 passed
```
**New: `test_gpu_stream_threads_and_nccl_ops`** — synthetic GPU trace
with `Stream #N` threads, NCCL kernel names, and host-side `step_num`
events. Asserts:
- `step_time.all_steps.count == 3` (host-side fallback works)
- `hot_ops` contains `fusion.1` (stream threads recognized)
- `communication_ops` contains `all-gather` and `reduce-scatter` (NCCL
classified)
- `gap_before_ops` is non-empty (gap analysis on stream threads)
### Pre-merge canary runs (end-to-end)
Both canary workflows triggered on this branch via `workflow_dispatch`.
Both passed.
| Canary | Workflow run | Result |
|--------|--------------|--------|
| **TPU** (v5p-8, Qwen3 30M, 1B tokens) | [#22788717179](https://github.com/marin-community/marin/actions/runs/22788717179) | **Passed** — no regression |
| **GPU** (8xH100 CW, Llama 150M, 1B tokens) | [#22788717705](https://github.com/marin-community/marin/actions/runs/22788717705) | **Passed** — all fields now populated |
W&B runs:
[`canary-tpu-22788717179-1`](https://wandb.ai/marin-community/marin/runs/canary-tpu-22788717179-1),
[`canary-gpu-22788717705-1`](https://wandb.ai/marin-community/marin/runs/canary-gpu-22788717705-1)
<details><summary>GPU canary results — before vs after</summary>
Profile summaries downloaded from W&B and re-summarized locally with the
branch code.
#### GPU: all three causes fixed
| Metric | Before (marin-community#3345) | After (this PR) |
|--------|---------------|-----------------|
| `hot_ops` | **0** | **25** |
| `communication_ops` | **0** | **4 collective types, 1,208 events** |
| `gap_before_ops` | **0** | **238** |
| `step_time.all_steps.count` | **0** | **6** (median 303,642 us) |
| `time_breakdown.communication` | 0.04% (misclassified) | **1.08%** |
| `time_breakdown.compute` | 22% (inflated by NCCL) | **16.0%** |
Top 5 hot ops (GPU):
| Op | Exclusive duration (us) | Count |
|----|------------------------|-------|
| `sm90_xmma_gemm_f32f32_tf32f32_f32_nt_n_...cublas` | 1,730,500 | 768 |
| `input_scatter_fusion_1` | 1,576,241 | 48 |
| `loop_multiply_fusion_6` | 1,095,015 | 768 |
| `sm90_xmma_gemm_f32f32_tf32f32_f32_tn_n_...cublas` | 1,020,830 | 768 |
| `nvjet_tss_192x192_64x3_1x2_h_bz_coopB_NNN` | 965,529 | 768 |
Communication ops (GPU):
| Collective | Count | Total duration (us) |
|-----------|-------|-------------------|
| `all-reduce` | 104 | 594,030 |
| `reduce-scatter` | 336 | 115,259 |
| `all-gather` | 672 | 63,331 |
| `send-recv` | 96 | 22,926 |
Top 3 pre-op gaps (GPU):
| Op | Total gap (us) | Count |
|----|---------------|-------|
| `MemcpyH2D` | 15,567,349 | 720 |
| `ncclDevKernel_ReduceScatter_Sum_bf16_RING_LL(...)` | 10,286,880 | 336 |
| `MemcpyD2H` | 6,191,228 | 142 |
Step timing (GPU):
| Stat | Value (us) |
|------|-----------|
| count | 6 |
| min | 284,241 |
| median | 303,642 |
| mean | 303,200 |
| max | 324,602 |
| p90 | 314,282 |
</details>
<details><summary>TPU canary results — no regression</summary>
| Metric | After (this PR) |
|--------|-----------------|
| `hot_ops` | **25** (fusion.*, copy.*, reshape.*) |
| `communication_ops` | **4 types** (all-reduce: 520, all-gather: 2,040,
all-to-all: 80, async-collective: 6,240) |
| `gap_before_ops` | **461** |
| `step_time.all_steps.count` | **60** (via TPU-native `"Steps"` thread
— host-side fallback did not fire) |
Top 5 hot ops (TPU):
| Op | Exclusive duration (us) | Count |
|----|------------------------|-------|
| `fusion.492` | 2,858,184 | 640 |
| `copy.319` | 1,143,337 | 640 |
| `fusion.483` | 1,016,882 | 640 |
| `reshape.436` | 836,632 | 640 |
| `fusion.481` | 828,741 | 640 |
Time breakdown (TPU):
| Category | Share |
|----------|-------|
| Compute | 44.6% |
| Stall | 50.6% |
| Host | 4.5% |
| Communication | 0.3% |
</details>
<details><summary>How to reproduce / deeply inspect</summary>
#### Re-summarize from W&B artifacts
```bash
# GPU canary
uv run python -m marin.profiling.cli summarize \
--run-target "canary-gpu-22788717705-1" \
--entity marin-community --project marin
# TPU canary
uv run python -m marin.profiling.cli summarize \
--run-target "canary-tpu-22788717179-1" \
--entity marin-community --project marin
```
#### Verify step annotation scope
The `StepTraceAnnotation` is scoped to the compiled step body only. To
confirm hooks aren't included in the measured interval, look at the raw
trace in the W&B artifact:
- Find `name="train"` events with `step_num` in args on the `/host:CPU`
process
- Their `dur` should be consistent across steps (no periodic spikes on
eval/checkpoint hook steps)
- The GPU canary shows a tight step time distribution (min 284k, max
325k us, ~14% spread) — consistent with no hook contamination
#### Verify TPU fallback did not fire
The TPU canary has `step_time.all_steps.count = 60` with a median of
~1,327 us — these are device-side step markers from the `"Steps"` thread
(microsecond-scale durations). The GPU canary has
`step_time.all_steps.count = 6` with a median of ~303,642 us — these are
host-side `StepTraceAnnotation` events (millisecond-scale). The two
paths produce different-scale measurements as expected, confirming the
fallback only fires on GPU.
</details>
## Not addressed
- **Trace truncation** (marin-community#3345, secondary): Both canary traces hit
exactly 1,000,000 complete events (`suspected_truncation: true`). This
is orthogonal — it affects data volume but not the format mismatches
fixed here. Worth a separate PR to filter host threads at capture time
  or increase the cap.