Skip to content

Commit 94fc381

Browse files
committed
test: #706 reachability repros (reverted in next commit; not for merge)
Runnable repros that reach each path #706 touches and show the dropped ones are unreachable: aten_ops mis-routing, scaled-addmm fail-closed, rms_norm non-last-dim mis-fusion, capture_jax strided-slice aliasing, plus the unreachable SDPA mask/dropout and no-training batchnorm cases. This commit is reverted immediately so it stays referenceable by SHA without entering the squash/merge. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
1 parent e22e099 commit 94fc381

4 files changed

Lines changed: 248 additions & 0 deletions

File tree

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# PR #706 reachability repros
2+
3+
Scripts that reach (or show as unreachable) each code path touched by #706, on the
4+
real capture/import pipeline. This commit is intentionally reverted in the next
5+
commit so the repros are referenceable by SHA but are **not** part of the squash/merge.
6+
7+
## Run
8+
9+
```bash
10+
# PyTorch paths (routing, batchnorm export form, scaled-addmm, SDPA mask, rms_norm)
11+
python python/tests/transpile/tools/repro_706/repro_pytorch.py
12+
13+
# JAX path (strided-slice aliasing)
14+
pip install "jax[cpu]"
15+
python python/tests/transpile/tools/repro_706/repro_jax.py
16+
17+
# origin/main vs this branch, one file swapped at a time
18+
bash python/tests/transpile/tools/repro_706/origin_vs_pr.sh
19+
```
20+
21+
## What each shows
22+
23+
| Path | Kept? | Repro result |
24+
|---|---|---|
25+
| `aten_ops` longest-prefix | keep | origin mis-routes `addmm→add`, `minimum→min`, `maximum→max`, `slice_scatter→slice`, `select_scatter→index`; a `torch.minimum/maximum` model imports to `{minimum,maximum,add}` here |
26+
| `importers` scaled-addmm | keep | `torch.addmm(b,a,c, beta=2.0)` fails closed |
27+
| `fusion/rms_norm` guard | keep | without the guard, an FP16 channel-wise (non-last-dim) RMS norm is mis-fused into a last-dim kernel with weight shape `(5,)` even though the channel dim is 4; with the guard that pattern is skipped, and last-dim RMS norm still fuses |
28+
| `capture_jax` strided slice | keep | `x[::2]` → origin aliases full length-6 input (`ops=[]`), branch gives length-3 `slice` |
29+
| `lower.py` | drop | only the non-FP16 branch changes; FP16 engine never executes it |
30+
| `import_semantics` dropout | drop | `dropout_p` always 0 after `model.eval()` |
31+
| `import_semantics` mask | drop | `attrs["mask"]` only set for a non-tensor literal; SDPA `attn_mask` is always Tensor/None (3 forms tested, none set it) |
32+
| `importers` batchnorm no-training | drop | `_native_batch_norm_legit_no_training` only appears after `run_decompositions`, which only feeds CoreML; the IR importer always sees 8-arg `aten.batch_norm.default` |
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#!/bin/bash
2+
# Origin-vs-PR comparison for PR #706. Run from the repo root with this branch
3+
# checked out. Swaps a single file to origin/main, runs the probe, then restores.
4+
# The cactus package is editable-installed against this tree, so a file swap +
5+
# fresh interpreter is the way to compare; worktrees would all import this tree.
6+
set -euo pipefail
7+
cd "$(git rev-parse --show-toplevel)"
8+
BR=audit/transpiler-lowering
9+
F=python/cactus/transpile
10+
11+
# Put a single file back to the PR branch's version after it was swapped to
# origin/main. $1 is the path of the file, relative to the repo root.
restore() { git checkout -q "$BR" -- "$1"; }
# `set -e` aborts the script on the first failing probe, which would otherwise
# leave a file swapped to origin/main in the working tree. Restore both files
# this script swaps on every exit path (restoring an unswapped file is a no-op
# for a clean tree).
trap 'restore "$F/aten_ops.py"; restore "$F/capture_jax.py"' EXIT
12+
13+
echo "########## aten_ops: ORIGIN routing (expect add/min/max/slice/index) ##########"
14+
git checkout -q origin/main -- "$F/aten_ops.py"
15+
python -c "
16+
from cactus.transpile.normalize import normalize_target
17+
import torch
18+
for op in ['addmm','minimum','maximum','slice_scatter','select_scatter']:
19+
print(' ORIGIN', op, '->', normalize_target(getattr(torch.ops.aten, op).default))"
20+
restore "$F/aten_ops.py"
21+
22+
echo "########## rms_norm: WITH guard, non-last must NOT fuse / last must fuse ##########"
23+
# this branch ships the guard; show it explicitly
24+
python -c "
25+
import torch, torch.nn as nn
26+
from cactus.transpile.capture_pytorch import capture_model
27+
from cactus.transpile.optimize_graph import fuse_rms_norm
28+
def rms(x,d,w,e=1e-6): return x*torch.rsqrt(x.pow(2).mean(d,keepdim=True)+e)*w
29+
class Chan(nn.Module):
30+
def __init__(s): super().__init__(); s.w=nn.Parameter(torch.ones(4,1,1,dtype=torch.float16))
31+
def forward(s,x): return rms(x,1,s.w)
32+
class Last(nn.Module):
33+
def __init__(s): super().__init__(); s.w=nn.Parameter(torch.ones(8,dtype=torch.float16))
34+
def forward(s,x): return rms(x,-1,s.w)
35+
for name,m,ex in [('non-last',Chan(),torch.randn(2,4,5,5,dtype=torch.float16)),('last',Last(),torch.randn(2,3,8,dtype=torch.float16))]:
36+
g=capture_model(m.half().eval(),(ex,)).ir_graph
37+
ch=fuse_rms_norm(g); rn=sum(1 for n in g.order if g.nodes[n].op=='rms_norm')
38+
print(' WITH-GUARD',name,'fused=',ch,'rms_norm=',rn)"
39+
40+
echo "########## capture_jax: ORIGIN aliases strided slice (expect ops=[] shape=(6,)) ##########"
41+
git checkout -q origin/main -- "$F/capture_jax.py"
42+
python python/tests/transpile/tools/repro_706/repro_jax.py 2>/dev/null || echo " (needs jax[cpu])"
43+
restore "$F/capture_jax.py"
44+
45+
echo "########## batch_norm: run_decompositions turns it into the no_training op ##########"
46+
python -c "
47+
import torch, torch.nn as nn
48+
m=nn.BatchNorm2d(4).eval()
49+
ep=torch.export.export(m,(torch.randn(1,4,8,8),))
50+
print(' default export :', [str(n.target) for n in ep.graph.nodes if 'batch_norm' in str(n.target)])
51+
ep2=ep.run_decompositions()
52+
print(' after decompose :', [str(n.target) for n in ep2.graph.nodes if 'batch_norm' in str(n.target)])
53+
print(' (decompose only runs on the NPU->CoreML path, never into import_captured_to_ir)')"
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""Reachability repro for PR #706 (JAX capture path).
2+
3+
pip install "jax[cpu]"
4+
python python/tests/transpile/tools/repro_706/repro_jax.py
5+
6+
Shows the stride-aware slice-aliasing fix in capture_jax.py: a strided slice
7+
(x[::2]) must not be treated as a no-op alias.
8+
"""
9+
from __future__ import annotations
10+
import numpy as np
11+
import jax, jax.numpy as jnp
12+
13+
from cactus.transpile.capture_jax import capture_jax_function
14+
15+
16+
def describe(tag, fn, args):
    """Capture ``fn`` with ``args`` and print a one-line summary of the IR.

    The summary lists the op of every node in ``ir.order``, the shape of the
    first output value (``None`` when the value is missing or has no ``shape``
    attribute), and whether the first output id equals the first input id.
    """
    ir = capture_jax_function(fn, args)
    ops = []
    for node_id in ir.order:
        ops.append(ir.nodes[node_id].op)
    out_id = ir.outputs[0]
    in_id = ir.inputs[0]
    out_value = ir.values.get(out_id)
    out_shape = getattr(out_value, "shape", None)
    print(f"[{tag}] ops={ops} out_shape={out_shape} output_is_input_alias={out_id == in_id}")
22+
23+
24+
# Strided slice over a length-6 vector -> expected length 3.
25+
# origin: stride ignored in changed_axes -> start=0,limit=full -> aliased no-op (WRONG).
26+
# this branch: stride!=1 -> a real slice node, length 3.
27+
x = jnp.arange(6.0)
28+
print("eager jax x[::2] =", np.asarray(jax.jit(lambda v: v[::2])(x)))
29+
describe("strided-slice x[::2]", lambda v: v[::2], (x,))
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
"""Reachability repros for PR #706 (PyTorch capture/import path).
2+
3+
Run from the repo root with the branch checked out:
4+
5+
python python/tests/transpile/tools/repro_706/repro_pytorch.py
6+
7+
Each section reaches a code path the PR touches (or shows it is unreachable),
8+
on the *current* checkout. Pair with origin_vs_pr.sh to see the origin behaviour.
9+
"""
10+
from __future__ import annotations
11+
import traceback
12+
from collections import Counter
13+
import torch, torch.nn as nn, torch.nn.functional as F
14+
15+
from cactus.transpile.capture_pytorch import capture_model
16+
from cactus.transpile.optimize_graph import fuse_rms_norm
17+
from cactus.transpile.normalize import normalize_target
18+
from cactus.transpile.import_semantics import apply_import_semantics
19+
20+
21+
def opc(g):
    """Return a ``Counter`` of node ops, visiting nodes in ``g.order``."""
    tally = Counter()
    for node_id in g.order:
        tally[g.nodes[node_id].op] += 1
    return tally
23+
24+
25+
def line(t):
    """Print a section banner for title ``t``, preceded by a blank line."""
    banner = f"\n========== {t} =========="
    print(banner)
27+
28+
29+
# ---- 1. aten_ops longest-prefix routing (KEEP: reached by natural ops) ----
30+
line("1. aten_ops routing (this branch)")
31+
for op in ["addmm", "minimum", "maximum", "slice_scatter", "select_scatter"]:
32+
t = getattr(torch.ops.aten, op).default
33+
print(f" normalize_target(aten.{op}) -> {normalize_target(t)}")
34+
35+
class MinMax(nn.Module):
    """Module whose forward adds the elementwise minimum and maximum of two tensors."""

    def forward(self, a, b):
        lower = torch.minimum(a, b)
        upper = torch.maximum(a, b)
        return lower + upper
38+
c = capture_model(MinMax(), (torch.randn(3), torch.randn(3)))
39+
print(" MinMax model IR ops:", dict(opc(c.ir_graph)))
40+
41+
42+
# ---- 2. BatchNorm: which aten op does the IR import path actually see? ----
43+
# REMOVED from PR: the no_training (7-arg) branch is unreachable here because
44+
# the capture path never runs run_decompositions (that only happens on the
45+
# NPU->CoreML path). torch.export emits the 8-arg aten.batch_norm.default.
46+
line("2. eval-mode BatchNorm export form")
47+
class ConvBN(nn.Module):
    """A 3->4 channel 3x3 convolution (padding 1) followed by BatchNorm2d(4)."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(num_features=4)

    def forward(self, x):
        convolved = self.conv(x)
        return self.bn(convolved)
54+
ep = torch.export.export(ConvBN().eval(), (torch.randn(1, 3, 8, 8),))
55+
print(" exported BN target(s):", [str(n.target) for n in ep.graph.nodes if "batch_norm" in str(n.target)])
56+
print(" (run_decompositions would instead emit _native_batch_norm_legit_no_training -> only feeds CoreML)")
57+
58+
59+
# ---- 3. scaled addmm fail-closed (KEEP: reached) ----
60+
line("3. scaled addmm (beta=2.0)")
61+
class ScaledAddmm(nn.Module):
    """Module whose forward is ``torch.addmm`` with a non-default ``beta`` of 2.0."""

    def forward(self, bias, a, b):
        scaled = torch.addmm(bias, a, b, beta=2.0)
        return scaled
64+
try:
65+
capture_model(ScaledAddmm(), (torch.randn(4), torch.randn(4, 5), torch.randn(5, 4)))
66+
print(" imported WITHOUT failing (origin behaviour: silently drops beta)")
67+
except Exception as e:
68+
print(" fail-closed OK:", type(e).__name__, "->", str(e).split(":")[-1].strip()[:80])
69+
70+
71+
# ---- 4. SDPA literal mask -> attrs['mask']? (REMOVED: never produced) ----
72+
line("4. SDPA mask -> attrs['mask']? (3 natural forms)")
73+
def check(name, mod, args):
    """Capture ``mod``, apply import semantics, and report on its first attention node.

    Prints the sorted attr names, whether ``'mask'`` is among them, and the
    input count of the first node in ``g.order`` whose op is ``"attention"``;
    prints a "no attention node" line when there is none.
    """
    g = capture_model(mod, args).ir_graph
    apply_import_semantics(g)
    candidates = (g.nodes[n] for n in g.order)
    nd = next((node for node in candidates if node.op == "attention"), None)
    if nd is None:
        print(f" [{name}] no attention node")
        return
    print(f" [{name}] attrs={sorted(nd.attrs)} mask_in_attrs={'mask' in nd.attrs} n_inputs={len(nd.inputs)}")
82+
83+
class MaskTensorArg(nn.Module):
    """SDPA where the attention mask is passed in as a forward argument."""

    def forward(self, q, k, v, m):
        attended = F.scaled_dot_product_attention(q, k, v, attn_mask=m)
        return attended
86+
check("tensor-mask-as-input", MaskTensorArg(), (torch.randn(1, 1, 4, 8),) * 3 + (torch.zeros(1, 1, 4, 4),))
87+
88+
class MaskInlineConst(nn.Module):
    """SDPA where the attention mask is a ``torch.zeros(4, 4)`` built inside forward."""

    def forward(self, q, k, v):
        mask = torch.zeros(4, 4)
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
91+
check("inline-const-mask", MaskInlineConst(), (torch.randn(1, 1, 4, 8),) * 3)
92+
93+
class CausalFlag(nn.Module):
    """SDPA invoked with ``is_causal=True`` and no explicit mask."""

    def forward(self, q, k, v):
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return attended
96+
check("is_causal-flag", CausalFlag(), (torch.randn(1, 1, 4, 8),) * 3)
97+
98+
99+
# ---- 6. rms_norm guard (KEEP: non-last-dim mis-fuses without it) ----
100+
# NOTE: fuse_rms_norm skips FP32 inputs (kept unfused on purpose), so the inputs
101+
# must be FP16 to reach the fusion -- as they are in real transpilation.
102+
def rms(x, dim, w, eps=1e-6):
    """Scale ``x`` by the reciprocal root of its mean square over ``dim``, then by ``w``.

    The mean is taken with ``keepdim=True`` and ``eps`` is added before ``rsqrt``.
    """
    mean_square = x.pow(2).mean(dim, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + eps)
    return x * inv_rms * w
104+
105+
line("6a. LAST-dim RMSNorm, FP16")
106+
class RMSLast(nn.Module):
    """RMS norm over the last dimension with an FP16 weight of shape ``(8,)``."""

    def __init__(self):
        super().__init__()
        initial_weight = torch.ones(8, dtype=torch.float16)
        self.w = nn.Parameter(initial_weight)

    def forward(self, x):
        return rms(x, dim=-1, w=self.w)
112+
g = capture_model(RMSLast().half().eval(), (torch.randn(2, 3, 8, dtype=torch.float16),)).ir_graph
113+
changed = fuse_rms_norm(g)
114+
print(" fuse_rms_norm changed?", changed, "| rms_norm nodes:", opc(g).get("rms_norm", 0))
115+
116+
line("6b. NON-last-dim RMSNorm (channel dim=1), FP16")
117+
class RMSChan(nn.Module):
    """RMS norm over dimension 1 with an FP16 weight of shape ``(4, 1, 1)``."""

    def __init__(self):
        super().__init__()
        initial_weight = torch.ones(4, 1, 1, dtype=torch.float16)
        self.w = nn.Parameter(initial_weight)

    def forward(self, x):
        return rms(x, dim=1, w=self.w)
123+
g = capture_model(RMSChan().half().eval(), (torch.randn(2, 4, 5, 5, dtype=torch.float16),)).ir_graph
124+
print(" mean axes:", [(n, g.nodes[n].attrs.get("axis")) for n in g.order if g.nodes[n].op == "mean"])
125+
changed = fuse_rms_norm(g)
126+
for n in g.order:
127+
if g.nodes[n].op == "rms_norm":
128+
wv = g.values.get(g.nodes[n].inputs[1])
129+
print(" fused rms_norm weight shape:", getattr(wv, "shape", None), "(channel=4, last dim=5)")
130+
print(" fuse_rms_norm changed?", changed,
131+
"| this branch ships the guard so changed=False (correct);",
132+
"checkout origin/main rms_norm.py to see changed=True with weight (5,)")
133+
134+
print("\n[done]")

0 commit comments

Comments
 (0)