Commit 77811cb
Fix profiling summary pipeline on GPU traces (marin-community#3362)
## Summary

Fixes marin-community#3345. The profiling summary pipeline produced empty `hot_ops`, `communication_ops`, `gap_before_ops`, and `step_time` on GPU traces because of three independent mismatches between the ingestion code (written for TPU traces) and the actual GPU/NCCL trace format.

| Cause | Before | After |
|-------|--------|-------|
| Thread filter mismatch | 6 functions gate on `"XLA Ops"` — GPU uses `Stream #N(...)` | New `_is_device_op_event()` matches both |
| Comm op naming | `_COMM_PATTERNS` misses `ncclDevKernel_AllGather_RING_LL` etc. | Added `nccl`, `allgather`, `allreduce`, `reducescatter` |
| No step markers | TPU uses `"Steps"` thread with numeric names — GPU has neither | `StepTraceAnnotation` in trainer + host-side `step_num` fallback in ingest |

## Changes

### `ingest.py` — thread and op recognition

<details><summary>Cause 1: Replace 6 hardcoded thread checks with <code>_is_device_op_event()</code></summary>

The old code filtered device ops with:

```python
if event.thread_name not in {"XLA Ops", "Async XLA Ops"}:
    continue
```

GPU traces use stream-based thread names like `Stream #0(compute)`, `Stream #1(nccl)`, so every device op was silently dropped. New predicate:

```python
_DEVICE_OP_THREAD_NAMES = frozenset({"XLA Ops", "Async XLA Ops"})

def _is_device_op_thread(thread_name: str | None) -> bool:
    if thread_name is None:
        return False
    if thread_name in _DEVICE_OP_THREAD_NAMES:
        return True
    if thread_name.startswith("Stream #"):
        return True
    return False

def _is_device_op_event(event: _CompleteTraceEvent) -> bool:
    return _is_device_event(event) and _is_device_op_thread(event.thread_name)
```

Updated call sites: `_summarize_hot_ops`, `_summarize_communication`, `_summarize_pre_op_gaps`, `_summarize_hierarchical_regions`, `_summarize_gap_region_contexts`, `_preferred_region_path_by_op`.
</details>

<details><summary>Cause 2: NCCL collective classification</summary>

Added to `_COMM_PATTERNS`: `"nccl"`, `"allgather"`, `"allreduce"`, `"reducescatter"`. Updated `_collective_kind()` to normalize unseparated NCCL names (e.g. `ncclDevKernel_AllGather_RING_LL` → `"all-gather"`). Updated the `collective` regex in `semantics.py` to match the same patterns for family classification.

</details>

<details><summary>Cause 3: Host-side step markers</summary>

**Capture side** (`trainer.py`): Wrapped the compiled step body (the `_maybe_save_jaxpr` call) in `jax.profiler.StepTraceAnnotation("train", step_num=int(state.step))`. The annotation is scoped to the compiled step only — hooks, logging, and tracker calls happen outside — so the measured interval matches TPU device-side `"Steps"` semantics.

**Ingest side** (`ingest.py`): Added `step_num: int | None` to `_CompleteTraceEvent`. When the TPU-style `"Steps"` thread produces no results, the pipeline falls back to host-side events filtered to `name == "train"` on `/host:*` processes:

```python
if not per_step:
    for event in events:
        if event.step_num is None:
            continue
        if event.name != "train":
            continue
        if not event.process_name or not event.process_name.startswith("/host:"):
            continue
        per_step[event.step_num].append(event.dur)
```

The fallback only fires when no TPU-style step markers exist, so existing TPU behavior is unchanged.

</details>

### `semantics.py` — collective family regex

Extended the `collective` family pattern to match NCCL naming (`nccl`, `allgather`, `allreduce`, `reducescatter`).

### `trainer.py` — step annotation

Wrapped the compiled step body in `jax.profiler.StepTraceAnnotation("train", step_num=...)` inside `train_step()`. Scoped narrowly to exclude hooks/logging so GPU step timing is comparable to TPU device-side timing.
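The two recognition changes compose in a small way that is easy to sanity-check. The sketch below is a simplified, standalone restatement of the new predicate and classifier (the real versions live in `ingest.py`; the `"other"` fallthrough here stands in for the remaining branches):

```python
_DEVICE_OP_THREAD_NAMES = frozenset({"XLA Ops", "Async XLA Ops"})

def is_device_op_thread(thread_name):
    # TPU traces use fixed thread names; GPU traces use "Stream #N(...)".
    if thread_name is None:
        return False
    return thread_name in _DEVICE_OP_THREAD_NAMES or thread_name.startswith("Stream #")

def collective_kind(name):
    # Normalize both XLA-style ("all-gather") and NCCL-style ("AllGather") names.
    lowered = name.lower()
    if "all-reduce" in lowered or "allreduce" in lowered or "psum" in lowered:
        return "all-reduce"
    if "all-gather" in lowered or "all_gather" in lowered or "allgather" in lowered:
        return "all-gather"
    if "reduce-scatter" in lowered or "reducescatter" in lowered:
        return "reduce-scatter"
    return "other"

assert is_device_op_thread("Stream #1(nccl)")
assert not is_device_op_thread("python3")
assert collective_kind("ncclDevKernel_AllGather_RING_LL") == "all-gather"
```

Note that the NCCL patterns are substring checks on the lowercased kernel name, so no tokenization of the `_`-separated kernel name is needed.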
## Test plan

### Unit tests

All 12 tests pass (11 existing TPU + 1 new GPU):

```
tests/profiling/test_profile_summary.py 12 passed
```

**New: `test_gpu_stream_threads_and_nccl_ops`** — synthetic GPU trace with `Stream #N` threads, NCCL kernel names, and host-side `step_num` events. Asserts:

- `step_time.all_steps.count == 3` (host-side fallback works)
- `hot_ops` contains `fusion.1` (stream threads recognized)
- `communication_ops` contains `all-gather` and `reduce-scatter` (NCCL classified)
- `gap_before_ops` is non-empty (gap analysis on stream threads)

### Pre-merge canary runs (end-to-end)

Both canary workflows were triggered on this branch via `workflow_dispatch`. Both passed.

| Canary | Workflow run | Result |
|--------|-------------|--------|
| **TPU** (v5p-8, Qwen3 30M, 1B tokens) | [#22788717179](https://github.com/marin-community/marin/actions/runs/22788717179) | **Passed** — no regression |
| **GPU** (8xH100 CW, Llama 150M, 1B tokens) | [#22788717705](https://github.com/marin-community/marin/actions/runs/22788717705) | **Passed** — all fields now populated |

W&B runs: [`canary-tpu-22788717179-1`](https://wandb.ai/marin-community/marin/runs/canary-tpu-22788717179-1), [`canary-gpu-22788717705-1`](https://wandb.ai/marin-community/marin/runs/canary-gpu-22788717705-1)

<details><summary>GPU canary results — before vs after</summary>

Profile summaries downloaded from W&B and re-summarized locally with the branch code.
#### GPU: all three causes fixed

| Metric | Before (marin-community#3345) | After (this PR) |
|--------|---------------|-----------------|
| `hot_ops` | **0** | **25** |
| `communication_ops` | **0** | **4 collective types, 1,208 events** |
| `gap_before_ops` | **0** | **238** |
| `step_time.all_steps.count` | **0** | **6** (median 303,642 us) |
| `time_breakdown.communication` | 0.04% (misclassified) | **1.08%** |
| `time_breakdown.compute` | 22% (inflated by NCCL) | **16.0%** |

Top 5 hot ops (GPU):

| Op | Exclusive duration (us) | Count |
|----|------------------------|-------|
| `sm90_xmma_gemm_f32f32_tf32f32_f32_nt_n_...cublas` | 1,730,500 | 768 |
| `input_scatter_fusion_1` | 1,576,241 | 48 |
| `loop_multiply_fusion_6` | 1,095,015 | 768 |
| `sm90_xmma_gemm_f32f32_tf32f32_f32_tn_n_...cublas` | 1,020,830 | 768 |
| `nvjet_tss_192x192_64x3_1x2_h_bz_coopB_NNN` | 965,529 | 768 |

Communication ops (GPU):

| Collective | Count | Total duration (us) |
|-----------|-------|-------------------|
| `all-reduce` | 104 | 594,030 |
| `reduce-scatter` | 336 | 115,259 |
| `all-gather` | 672 | 63,331 |
| `send-recv` | 96 | 22,926 |

Top 3 pre-op gaps (GPU):

| Op | Total gap (us) | Count |
|----|---------------|-------|
| `MemcpyH2D` | 15,567,349 | 720 |
| `ncclDevKernel_ReduceScatter_Sum_bf16_RING_LL(...)` | 10,286,880 | 336 |
| `MemcpyD2H` | 6,191,228 | 142 |

Step timing (GPU):

| Stat | Value (us) |
|------|-----------|
| count | 6 |
| min | 284,241 |
| median | 303,642 |
| mean | 303,200 |
| max | 324,602 |
| p90 | 314,282 |

</details>

<details><summary>TPU canary results — no regression</summary>

| Metric | After (this PR) |
|--------|-----------------|
| `hot_ops` | **25** (fusion.*, copy.*, reshape.*) |
| `communication_ops` | **4 types** (all-reduce: 520, all-gather: 2,040, all-to-all: 80, async-collective: 6,240) |
| `gap_before_ops` | **461** |
| `step_time.all_steps.count` | **60** (via TPU-native `"Steps"` thread — host-side fallback did not fire) |

Top 5 hot ops (TPU):

| Op | Exclusive duration (us) | Count |
|----|------------------------|-------|
| `fusion.492` | 2,858,184 | 640 |
| `copy.319` | 1,143,337 | 640 |
| `fusion.483` | 1,016,882 | 640 |
| `reshape.436` | 836,632 | 640 |
| `fusion.481` | 828,741 | 640 |

Time breakdown (TPU):

| Category | Share |
|----------|-------|
| Compute | 44.6% |
| Stall | 50.6% |
| Host | 4.5% |
| Communication | 0.3% |

</details>

<details><summary>How to reproduce / deeply inspect</summary>

#### Re-summarize from W&B artifacts

```bash
# GPU canary
uv run python -m marin.profiling.cli summarize \
  --run-target "canary-gpu-22788717705-1" \
  --entity marin-community --project marin

# TPU canary
uv run python -m marin.profiling.cli summarize \
  --run-target "canary-tpu-22788717179-1" \
  --entity marin-community --project marin
```

#### Verify step annotation scope

The `StepTraceAnnotation` is scoped to the compiled step body only. To confirm hooks aren't included in the measured interval, look at the raw trace in the W&B artifact:

- Find `name="train"` events with `step_num` in args on the `/host:CPU` process
- Their `dur` should be consistent across steps (no periodic spikes on eval/checkpoint hook steps)
- The GPU canary shows a tight step-time distribution (min 284k, max 325k us, ~14% spread), consistent with no hook contamination

#### Verify the TPU fallback did not fire

The TPU canary has `step_time.all_steps.count = 60` with a median of ~1,327 us — these are device-side step markers from the `"Steps"` thread (microsecond-scale durations). The GPU canary has `step_time.all_steps.count = 6` with a median of ~303,642 us — these are host-side `StepTraceAnnotation` events (millisecond-scale). The two paths produce measurements at different scales, as expected, confirming the fallback only fires on GPU.

</details>

## Not addressed

- **Trace truncation** (marin-community#3345, secondary): Both canary traces hit exactly 1,000,000 complete events (`suspected_truncation: true`). This is orthogonal — it affects data volume but not the format mismatches fixed here. Worth a separate PR to filter host threads at capture time or increase the cap.
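The host-side fallback described above is small enough to sketch end-to-end. This is a standalone restatement (names simplified; the real logic lives in `_summarize_step_times`, which operates on `_CompleteTraceEvent` objects rather than dicts):

```python
from collections import defaultdict

def step_durations_from_host_events(events):
    """Group host-side StepTraceAnnotation durations by step number.

    `events` are dicts with keys name, dur, process_name, step_num —
    a simplified stand-in for the parsed trace events in ingest.py.
    """
    per_step = defaultdict(list)
    for event in events:
        if event.get("step_num") is None:
            continue
        if event.get("name") != "train":
            continue
        process = event.get("process_name")
        # Restrict to host processes so device-side events carrying
        # step_num are not averaged into the step time.
        if not process or not process.startswith("/host:"):
            continue
        per_step[event["step_num"]].append(event["dur"])
    return dict(per_step)

events = [
    {"name": "train", "dur": 500.0, "process_name": "/host:CPU", "step_num": 0},
    {"name": "train", "dur": 400.0, "process_name": "/host:CPU", "step_num": 1},
    {"name": "fusion.1", "dur": 100.0, "process_name": "/device:GPU:0", "step_num": None},
]
per_step = step_durations_from_host_events(events)
```

The device-side `fusion.1` event above is filtered out on both the `step_num` and process checks, leaving one duration list per host-annotated step.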
1 parent 3c9e438 commit 77811cb

File tree: 4 files changed, +125 −29 lines changed

### `lib/levanter/src/levanter/trainer.py`

Lines changed: 9 additions & 7 deletions
```diff
@@ -482,13 +482,15 @@ def train_step(self, state: S, *batch: X, **batch_kwargs) -> StepInfo[S]:
         hooks_this_time = any(state.step % h.every == 0 for h in self.hooks.jit_hooks)
 
         with capture_time() as step_time:
-            if hooks_this_time:
-                result = self._maybe_save_jaxpr("train_step", self._jit_train_step_fn, state, batch, batch_kwargs)
-                # force the loss so timing numbers are accurate. laziness isn't going to help here (i think?)
-            else:
-                result = self._maybe_save_jaxpr(
-                    "train_step_hooks", self._jit_train_step_fn_no_hook, state, batch, batch_kwargs
-                )
+            # Annotation scoped to the compiled step only (not hooks/logging below) so
+            # that GPU host-side step_num timing matches TPU device-side "Steps" semantics.
+            with jax.profiler.StepTraceAnnotation("train", step_num=int(state.step)):
+                if hooks_this_time:
+                    result = self._maybe_save_jaxpr("train_step", self._jit_train_step_fn, state, batch, batch_kwargs)
+                else:
+                    result = self._maybe_save_jaxpr(
+                        "train_step_hooks", self._jit_train_step_fn_no_hook, state, batch, batch_kwargs
+                    )
 
         loss = result.loss.item()
```

### `lib/marin/src/marin/profiling/ingest.py`

Lines changed: 61 additions & 21 deletions
```diff
@@ -63,8 +63,15 @@
     "psum",
     "send",
     "recv",
+    # GPU/NCCL-style (no separators)
+    "nccl",
+    "allgather",
+    "allreduce",
+    "reducescatter",
 )
 
+_DEVICE_OP_THREAD_NAMES = frozenset({"XLA Ops", "Async XLA Ops"})
+
 _STALL_PATTERN = re.compile(
     r"wait|barrier|dependency-wait|donation holds|semaphore|acquire|idle|blocked|sleep", re.IGNORECASE
 )
@@ -140,6 +147,7 @@ class _CompleteTraceEvent:
     run_id: str | None
     process_name: str | None
     thread_name: str | None
+    step_num: int | None
 
 
 @dataclass
@@ -550,6 +558,7 @@ def _parse_complete_events(
                 run_id=_string_like_arg(event.get("args"), "run_id"),
                 process_name=process_names.get(pid),
                 thread_name=thread_names.get((pid, tid)),
+                step_num=_int_like_arg(event.get("args"), "step_num"),
             )
         )
 
@@ -652,6 +661,8 @@ def _make_trace_provenance(events: list[_CompleteTraceEvent], *, trace_sha256: s
 
 def _summarize_step_times(events: list[_CompleteTraceEvent], *, warmup_steps: int) -> StepTimeSummary:
     per_step: dict[int, list[float]] = defaultdict(list)
+
+    # TPU path: device "Steps" thread with numeric event names.
     for event in events:
         if not _is_device_event(event):
             continue
@@ -663,6 +674,19 @@ def _summarize_step_times(events: list[_CompleteTraceEvent], *, warmup_steps: in
             continue
         per_step[step].append(event.dur)
 
+    # GPU fallback: host-side StepTraceAnnotation events (step_num in args).
+    # Filter to name="train" on /host:CPU to avoid averaging unrelated spans
+    # (e.g. device-side events that also carry step_num).
+    if not per_step:
+        for event in events:
+            if event.step_num is None:
+                continue
+            if event.name != "train":
+                continue
+            if not event.process_name or not event.process_name.startswith("/host:"):
+                continue
+            per_step[event.step_num].append(event.dur)
+
     averaged_steps: list[tuple[int, float]] = []
     for step, durations in per_step.items():
         if not durations:
@@ -823,9 +847,7 @@ def _summarize_hot_ops(
     aggregate: dict[str, dict[str, float | int | str | Counter[str] | list[float]]] = {}
 
     for event, exclusive_duration in zip(events, exclusive, strict=True):
-        if not _is_device_event(event):
-            continue
-        if event.thread_name not in {"XLA Ops", "Async XLA Ops"}:
+        if not _is_device_op_event(event):
             continue
 
         bucket = aggregate.setdefault(
@@ -972,12 +994,10 @@ def _summarize_communication(events: list[_CompleteTraceEvent], exclusive: list[
     aggregate: dict[str, tuple[int, float]] = {}
 
     for event, duration in zip(events, exclusive, strict=True):
-        if not _is_device_event(event):
+        if not _is_device_op_event(event):
             continue
         if not _is_communication_name(event.name):
             continue
-        if event.thread_name not in {"XLA Ops", "Async XLA Ops"}:
-            continue
 
         collective = _collective_kind(event.name)
         count, total = aggregate.get(collective, (0, 0.0))
@@ -1000,9 +1020,7 @@ def _summarize_pre_op_gaps(events: list[_CompleteTraceEvent], *, limit: int) ->
 
     by_track: dict[tuple[int, int], list[_CompleteTraceEvent]] = defaultdict(list)
     for event in events:
-        if not _is_device_event(event):
-            continue
-        if event.thread_name not in {"XLA Ops", "Async XLA Ops"}:
+        if not _is_device_op_event(event):
             continue
         by_track[(event.pid, event.tid)].append(event)
 
@@ -1060,9 +1078,7 @@ def _summarize_hierarchical_regions(
     aggregate: dict[str, dict[str, float | int]] = {}
 
     for event, exclusive_duration in zip(events, exclusive, strict=True):
-        if not _is_device_event(event):
-            continue
-        if event.thread_name not in {"XLA Ops", "Async XLA Ops"}:
+        if not _is_device_op_event(event):
             continue
 
         path_parts = _hierarchical_parts(event)
@@ -1140,9 +1156,7 @@ def _summarize_gap_region_contexts(events: list[_CompleteTraceEvent], *, limit:
 
     by_track: dict[tuple[int, int], list[_CompleteTraceEvent]] = defaultdict(list)
    for event in events:
-        if not _is_device_event(event):
-            continue
-        if event.thread_name not in {"XLA Ops", "Async XLA Ops"}:
+        if not _is_device_op_event(event):
             continue
         by_track[(event.pid, event.tid)].append(event)
 
@@ -1562,13 +1576,27 @@ def _is_device_event(event: _CompleteTraceEvent) -> bool:
     return bool(event.process_name and event.process_name.startswith("/device:"))
 
 
+def _is_device_op_thread(thread_name: str | None) -> bool:
+    if thread_name is None:
+        return False
+    if thread_name in _DEVICE_OP_THREAD_NAMES:
+        return True
+    if thread_name.startswith("Stream #"):
+        return True
+    return False
+
+
+def _is_device_op_event(event: _CompleteTraceEvent) -> bool:
+    return _is_device_event(event) and _is_device_op_thread(event.thread_name)
+
+
 def _collective_kind(name: str) -> str:
     lowered = name.lower()
-    if "all-reduce" in lowered or "psum" in lowered:
+    if "all-reduce" in lowered or "allreduce" in lowered or "psum" in lowered:
         return "all-reduce"
-    if "all-gather" in lowered or "all_gather" in lowered:
+    if "all-gather" in lowered or "all_gather" in lowered or "allgather" in lowered:
         return "all-gather"
-    if "reduce-scatter" in lowered:
+    if "reduce-scatter" in lowered or "reducescatter" in lowered:
         return "reduce-scatter"
     if "all-to-all" in lowered or "alltoall" in lowered:
         return "all-to-all"
@@ -1681,9 +1709,7 @@ def _preferred_region_path_by_op(events: list[_CompleteTraceEvent], *, max_depth
     counters: dict[str, dict[str, int]] = defaultdict(dict)
 
     for event in events:
-        if not _is_device_event(event):
-            continue
-        if event.thread_name not in {"XLA Ops", "Async XLA Ops"}:
+        if not _is_device_op_event(event):
             continue
         if not event.tf_op:
             continue
@@ -1814,3 +1840,17 @@ def _string_like_arg(args_value: Any, key: str) -> str | None:
     if isinstance(value, (int, float)):
         return str(value)
     return None
+
+
+def _int_like_arg(args_value: Any, key: str) -> int | None:
+    if not isinstance(args_value, dict):
+        return None
+    value = args_value.get(key)
+    if isinstance(value, int):
+        return value
+    if isinstance(value, str):
+        try:
+            return int(value)
+        except ValueError:
+            return None
+    return None
```
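The string branch in `_int_like_arg` matters because Chrome trace `args` frequently serialize numbers as strings (the new test writes `"step_num": "0"`). A standalone sketch of the same coercion, demonstrating the accepted and rejected shapes:

```python
def int_like_arg(args_value, key):
    # Mirror of ingest.py's _int_like_arg: accept ints and numeric strings,
    # reject everything else (missing key, non-dict args, non-numeric strings).
    if not isinstance(args_value, dict):
        return None
    value = args_value.get(key)
    if isinstance(value, int):
        return value
    if isinstance(value, str):
        try:
            return int(value)
        except ValueError:
            return None
    return None
```

Without this tolerance, a trace whose annotations round-trip through JSON as strings would silently produce no step markers at all.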

### `lib/marin/src/marin/profiling/semantics.py`

Lines changed: 2 additions & 1 deletion
```diff
@@ -19,7 +19,8 @@
     (
         "collective",
         re.compile(
-            r"all-reduce|all_gather|all-gather|reduce-scatter|all-to-all|alltoall|collective",
+            r"all-reduce|all_gather|all-gather|reduce-scatter|all-to-all|alltoall|collective"
+            r"|nccl|allgather|allreduce|reducescatter",
            re.IGNORECASE,
        ),
    ),
```
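The extended pattern can be checked directly against representative kernel names. The sketch below builds the regex exactly as it appears in the diff above and probes it with a GPU kernel name, a TPU-style op name, and a non-collective op:

```python
import re

# The "collective" family pattern from semantics.py after this change.
collective = re.compile(
    r"all-reduce|all_gather|all-gather|reduce-scatter|all-to-all|alltoall|collective"
    r"|nccl|allgather|allreduce|reducescatter",
    re.IGNORECASE,
)

assert collective.search("ncclDevKernel_AllGather_RING_LL")   # GPU/NCCL naming
assert collective.search("all-reduce-start.1")                # TPU/XLA naming
assert collective.search("fusion.42") is None                 # not a collective
```

Because `nccl` is one of the alternatives, every NCCL kernel falls into the collective family even if a new collective kind appears that the more specific alternatives miss.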

### `tests/profiling/test_profile_summary.py`

Lines changed: 53 additions & 0 deletions
```diff
@@ -338,6 +338,59 @@ def test_gap_marker_payload_resolution_does_not_cross_second_idle_gap(tmp_path:
     assert top_gap.marker_op == "iota.296"
 
 
+def test_gpu_stream_threads_and_nccl_ops(tmp_path: Path) -> None:
+    """GPU traces use 'Stream #N' threads for ops and NCCL naming for collectives.
+
+    Step markers come from host-side StepTraceAnnotation events (step_num in args)
+    rather than the TPU-style 'Steps' thread with numeric event names.
+    """
+    trace_path = tmp_path / "gpu_trace.json.gz"
+    payload = {
+        "displayTimeUnit": "ns",
+        "traceEvents": [
+            # GPU device process with stream-based threads (no "XLA Ops" thread).
+            {"ph": "M", "pid": 1, "name": "process_name", "args": {"name": "/device:GPU:0"}},
+            {"ph": "M", "pid": 1, "tid": 10, "name": "thread_name", "args": {"name": "Stream #0(compute)"}},
+            {"ph": "M", "pid": 1, "tid": 11, "name": "thread_name", "args": {"name": "Stream #1(nccl)"}},
+            # Host process with step annotations.
+            {"ph": "M", "pid": 2, "name": "process_name", "args": {"name": "/host:CPU"}},
+            {"ph": "M", "pid": 2, "tid": 1, "name": "thread_name", "args": {"name": "python3"}},
+            # Step annotations on host (as produced by jax.profiler.StepTraceAnnotation).
+            {"ph": "X", "pid": 2, "tid": 1, "name": "train", "ts": 0, "dur": 500, "args": {"step_num": "0"}},
+            {"ph": "X", "pid": 2, "tid": 1, "name": "train", "ts": 500, "dur": 400, "args": {"step_num": "1"}},
+            {"ph": "X", "pid": 2, "tid": 1, "name": "train", "ts": 900, "dur": 350, "args": {"step_num": "2"}},
+            # Compute ops on Stream #0.
+            {"ph": "X", "pid": 1, "tid": 10, "name": "fusion.1", "ts": 10, "dur": 100},
+            {"ph": "X", "pid": 1, "tid": 10, "name": "custom-call.2", "ts": 120, "dur": 80},
+            # NCCL collective on Stream #1.
+            {"ph": "X", "pid": 1, "tid": 11, "name": "ncclDevKernel_AllGather_RING_LL", "ts": 200, "dur": 50},
+            {"ph": "X", "pid": 1, "tid": 11, "name": "ncclDevKernel_ReduceScatter_RING_LL", "ts": 260, "dur": 40},
+        ],
+    }
+    with gzip.open(trace_path, "wt", encoding="utf-8") as handle:
+        json.dump(payload, handle)
+
+    summary = summarize_trace(trace_path, warmup_steps=1, hot_op_limit=10)
+
+    # Step markers detected via host-side step_num fallback.
+    assert summary.step_time.all_steps.count == 3
+    assert summary.step_time.steady_state_steps.count == 2
+
+    # Ops from Stream threads are recognized (not empty like the old code would produce).
+    assert len(summary.hot_ops) > 0
+    op_names = {op.name for op in summary.hot_ops}
+    assert "fusion.1" in op_names
+
+    # NCCL collectives are classified.
+    assert len(summary.communication_ops) > 0
+    collective_kinds = {op.collective for op in summary.communication_ops}
+    assert "all-gather" in collective_kinds
+    assert "reduce-scatter" in collective_kinds
+
+    # Gap analysis works on stream threads.
+    assert len(summary.gap_before_ops) > 0
+
+
 def _write_trace(path: Path, *, step_durations: list[float], softmax_duration: float) -> None:
     path.parent.mkdir(parents=True, exist_ok=True)
     payload = {
```
