
Extract executor code #1

Draft
eric-czech wants to merge 3 commits into main from executor-extraction

Conversation

@eric-czech
Member

This contains deletions necessary to move Executor logic to https://github.com/Open-Athena/thalas (or wherever it ultimately ends up).

eric-czech pushed a commit that referenced this pull request Mar 30, 2026
## Summary

Fixes marin-community#3345. The profiling summary pipeline produced empty `hot_ops`,
`communication_ops`, `gap_before_ops`, and `step_time` on GPU traces
because of three independent mismatches between the ingestion code
(written for TPU traces) and the actual GPU/NCCL trace format.

| Cause | Before | After |
|-------|--------|-------|
| Thread filter mismatch | 6 functions gate on `"XLA Ops"` — GPU uses `Stream #N(...)` | New `_is_device_op_event()` matches both |
| Comm op naming | `_COMM_PATTERNS` misses `ncclDevKernel_AllGather_RING_LL` etc. | Added `nccl`, `allgather`, `allreduce`, `reducescatter` |
| No step markers | TPU uses `"Steps"` thread with numeric names — GPU has neither | `StepTraceAnnotation` in trainer + host-side `step_num` fallback in ingest |

## Changes

### `ingest.py` — thread and op recognition

<details><summary>Cause 1: Replace 6 hardcoded thread checks with
<code>_is_device_op_event()</code></summary>

The old code filtered device ops with:
```python
if event.thread_name not in {"XLA Ops", "Async XLA Ops"}:
    continue
```

GPU traces use stream-based thread names like `Stream #0(compute)`,
`Stream #1(nccl)`, so every device op was silently dropped.

New predicate:
```python
_DEVICE_OP_THREAD_NAMES = frozenset({"XLA Ops", "Async XLA Ops"})

def _is_device_op_thread(thread_name: str | None) -> bool:
    if thread_name is None:
        return False
    if thread_name in _DEVICE_OP_THREAD_NAMES:
        return True
    if thread_name.startswith("Stream #"):
        return True
    return False

def _is_device_op_event(event: _CompleteTraceEvent) -> bool:
    return _is_device_event(event) and _is_device_op_thread(event.thread_name)
```

Updated call sites: `_summarize_hot_ops`, `_summarize_communication`,
`_summarize_pre_op_gaps`, `_summarize_hierarchical_regions`,
`_summarize_gap_region_contexts`, `_preferred_region_path_by_op`.

</details>

<details><summary>Cause 2: NCCL collective classification</summary>

Added to `_COMM_PATTERNS`: `"nccl"`, `"allgather"`, `"allreduce"`,
`"reducescatter"`.

Updated `_collective_kind()` to normalize unseparated NCCL names (e.g.
`ncclDevKernel_AllGather_RING_LL` → `"all-gather"`).

Updated `semantics.py` `collective` regex to match the same patterns for
family classification.

</details>

<details><summary>Cause 3: Host-side step markers</summary>

**Capture side** (`trainer.py`): Wrapped the compiled step body (the
`_maybe_save_jaxpr` call) in `jax.profiler.StepTraceAnnotation("train",
step_num=int(state.step))`. The annotation is scoped to the compiled
step only — hooks, logging, and tracker calls happen outside — so the
measured interval matches TPU device-side `"Steps"` semantics.

**Ingest side** (`ingest.py`): Added `step_num: int | None` to
`_CompleteTraceEvent`. When the TPU-style `"Steps"` thread produces no
results, falls back to host-side events filtered to `name == "train"` on
`/host:*` processes:

```python
if not per_step:
    for event in events:
        if event.step_num is None:
            continue
        if event.name != "train":
            continue
        if not event.process_name or not event.process_name.startswith("/host:"):
            continue
        per_step[event.step_num].append(event.dur)
```

The fallback only fires when no TPU-style step markers exist, so
existing TPU behavior is unchanged.

</details>

### `semantics.py` — collective family regex

Extended the `collective` family pattern to match NCCL naming (`nccl`,
`allgather`, `allreduce`, `reducescatter`).
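
A minimal sketch of what the extended pattern might look like (the actual regex in `semantics.py` is not shown in this description, so treat the pattern below as an approximation):

```python
import re

# Hypothetical approximation of the extended `collective` family pattern;
# the real regex in semantics.py may include additional alternatives.
COLLECTIVE_RE = re.compile(
    r"(all-reduce|all-gather|reduce-scatter|all-to-all"
    r"|nccl|allgather|allreduce|reducescatter)",
    re.IGNORECASE,
)

def is_collective(op_name: str) -> bool:
    """Classify an op name into the `collective` family."""
    return COLLECTIVE_RE.search(op_name) is not None
```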

### `trainer.py` — step annotation

Wrapped the compiled step body in
`jax.profiler.StepTraceAnnotation("train", step_num=...)` inside
`train_step()`. Scoped narrowly to exclude hooks/logging so GPU step
timing is comparable to TPU device-side timing.

## Test plan

### Unit tests

All 12 tests pass (11 existing TPU + 1 new GPU):

```
tests/profiling/test_profile_summary.py    12 passed
```

**New: `test_gpu_stream_threads_and_nccl_ops`** — synthetic GPU trace
with `Stream #N` threads, NCCL kernel names, and host-side `step_num`
events. Asserts:
- `step_time.all_steps.count == 3` (host-side fallback works)
- `hot_ops` contains `fusion.1` (stream threads recognized)
- `communication_ops` contains `all-gather` and `reduce-scatter` (NCCL
classified)
- `gap_before_ops` is non-empty (gap analysis on stream threads)

### Pre-merge canary runs (end-to-end)

Both canary workflows triggered on this branch via `workflow_dispatch`.
Both passed.

| Canary | Workflow run | Result |
|--------|-------------|--------|
| **TPU** (v5p-8, Qwen3 30M, 1B tokens) | [#22788717179](https://github.com/marin-community/marin/actions/runs/22788717179) | **Passed** — no regression |
| **GPU** (8xH100 CW, Llama 150M, 1B tokens) | [#22788717705](https://github.com/marin-community/marin/actions/runs/22788717705) | **Passed** — all fields now populated |

W&B runs:
[`canary-tpu-22788717179-1`](https://wandb.ai/marin-community/marin/runs/canary-tpu-22788717179-1),
[`canary-gpu-22788717705-1`](https://wandb.ai/marin-community/marin/runs/canary-gpu-22788717705-1)

<details><summary>GPU canary results — before vs after</summary>

Profile summaries downloaded from W&B and re-summarized locally with the
branch code.

#### GPU: all three causes fixed

| Metric | Before (marin-community#3345) | After (this PR) |
|--------|-------------------------------|-----------------|
| `hot_ops` | **0** | **25** |
| `communication_ops` | **0** | **4 collective types, 1,208 events** |
| `gap_before_ops` | **0** | **238** |
| `step_time.all_steps.count` | **0** | **6** (median 303,642 us) |
| `time_breakdown.communication` | 0.04% (misclassified) | **1.08%** |
| `time_breakdown.compute` | 22% (inflated by NCCL) | **16.0%** |

Top 5 hot ops (GPU):
| Op | Exclusive duration (us) | Count |
|----|------------------------|-------|
| `sm90_xmma_gemm_f32f32_tf32f32_f32_nt_n_...cublas` | 1,730,500 | 768 |
| `input_scatter_fusion_1` | 1,576,241 | 48 |
| `loop_multiply_fusion_6` | 1,095,015 | 768 |
| `sm90_xmma_gemm_f32f32_tf32f32_f32_tn_n_...cublas` | 1,020,830 | 768 |
| `nvjet_tss_192x192_64x3_1x2_h_bz_coopB_NNN` | 965,529 | 768 |

Communication ops (GPU):
| Collective | Count | Total duration (us) |
|-----------|-------|-------------------|
| `all-reduce` | 104 | 594,030 |
| `reduce-scatter` | 336 | 115,259 |
| `all-gather` | 672 | 63,331 |
| `send-recv` | 96 | 22,926 |

Top 3 pre-op gaps (GPU):
| Op | Total gap (us) | Count |
|----|---------------|-------|
| `MemcpyH2D` | 15,567,349 | 720 |
| `ncclDevKernel_ReduceScatter_Sum_bf16_RING_LL(...)` | 10,286,880 | 336 |
| `MemcpyD2H` | 6,191,228 | 142 |

Step timing (GPU):
| Stat | Value (us) |
|------|-----------|
| count | 6 |
| min | 284,241 |
| median | 303,642 |
| mean | 303,200 |
| max | 324,602 |
| p90 | 314,282 |

</details>

<details><summary>TPU canary results — no regression</summary>

| Metric | After (this PR) |
|--------|-----------------|
| `hot_ops` | **25** (fusion.*, copy.*, reshape.*) |
| `communication_ops` | **4 types** (all-reduce: 520, all-gather: 2,040, all-to-all: 80, async-collective: 6,240) |
| `gap_before_ops` | **461** |
| `step_time.all_steps.count` | **60** (via TPU-native `"Steps"` thread — host-side fallback did not fire) |

Top 5 hot ops (TPU):
| Op | Exclusive duration (us) | Count |
|----|------------------------|-------|
| `fusion.492` | 2,858,184 | 640 |
| `copy.319` | 1,143,337 | 640 |
| `fusion.483` | 1,016,882 | 640 |
| `reshape.436` | 836,632 | 640 |
| `fusion.481` | 828,741 | 640 |

Time breakdown (TPU):
| Category | Share |
|----------|-------|
| Compute | 44.6% |
| Stall | 50.6% |
| Host | 4.5% |
| Communication | 0.3% |

</details>

<details><summary>How to reproduce / deeply inspect</summary>

#### Re-summarize from W&B artifacts

```bash
# GPU canary
uv run python -m marin.profiling.cli summarize \
  --run-target "canary-gpu-22788717705-1" \
  --entity marin-community --project marin

# TPU canary
uv run python -m marin.profiling.cli summarize \
  --run-target "canary-tpu-22788717179-1" \
  --entity marin-community --project marin
```

#### Verify step annotation scope

The `StepTraceAnnotation` is scoped to the compiled step body only. To
confirm hooks aren't included in the measured interval, look at the raw
trace in the W&B artifact:
- Find `name="train"` events with `step_num` in args on the `/host:CPU`
process
- Their `dur` should be consistent across steps (no periodic spikes on
eval/checkpoint hook steps)
- The GPU canary shows a tight step time distribution (min 284k, max
325k us, ~14% spread) — consistent with no hook contamination

#### Verify TPU fallback did not fire

The TPU canary has `step_time.all_steps.count = 60` with a median of
~1,327 us — these are device-side step markers from the `"Steps"` thread
(microsecond-scale durations). The GPU canary has
`step_time.all_steps.count = 6` with a median of ~303,642 us — these are
host-side `StepTraceAnnotation` events (millisecond-scale). The two
paths produce different-scale measurements as expected, confirming the
fallback only fires on GPU.

</details>

## Not addressed

- **Trace truncation** (marin-community#3345, secondary): Both canary traces hit
exactly 1,000,000 complete events (`suspected_truncation: true`). This
is orthogonal — it affects data volume but not the format mismatches
fixed here. Worth a separate PR to filter host threads at capture time
or increase the cap.