|
| 1 | +"""Reachability repros for PR #706 (PyTorch capture/import path). |
| 2 | +
|
| 3 | +Run from the repo root with the branch checked out: |
| 4 | +
|
| 5 | + python python/tests/transpile/tools/repro_706/repro_pytorch.py |
| 6 | +
|
| 7 | +Each section reaches a code path the PR touches (or shows it is unreachable), |
| 8 | +on the *current* checkout. Pair with origin_vs_pr.sh to see the origin behaviour. |
| 9 | +""" |
| 10 | +from __future__ import annotations |
| 11 | +import traceback |
| 12 | +from collections import Counter |
| 13 | +import torch, torch.nn as nn, torch.nn.functional as F |
| 14 | + |
| 15 | +from cactus.transpile.capture_pytorch import capture_model |
| 16 | +from cactus.transpile.optimize_graph import fuse_rms_norm |
| 17 | +from cactus.transpile.normalize import normalize_target |
| 18 | +from cactus.transpile.import_semantics import apply_import_semantics |
| 19 | + |
| 20 | + |
| 21 | +def opc(g): |
| 22 | + return Counter(g.nodes[n].op for n in g.order) |
| 23 | + |
| 24 | + |
| 25 | +def line(t): |
| 26 | + print(f"\n========== {t} ==========") |
| 27 | + |
| 28 | + |
| 29 | +# ---- 1. aten_ops longest-prefix routing (KEEP: reached by natural ops) ---- |
| 30 | +line("1. aten_ops routing (this branch)") |
| 31 | +for op in ["addmm", "minimum", "maximum", "slice_scatter", "select_scatter"]: |
| 32 | + t = getattr(torch.ops.aten, op).default |
| 33 | + print(f" normalize_target(aten.{op}) -> {normalize_target(t)}") |
| 34 | + |
| 35 | +class MinMax(nn.Module): |
| 36 | + def forward(self, a, b): |
| 37 | + return torch.minimum(a, b) + torch.maximum(a, b) |
| 38 | +c = capture_model(MinMax(), (torch.randn(3), torch.randn(3))) |
| 39 | +print(" MinMax model IR ops:", dict(opc(c.ir_graph))) |
| 40 | + |
| 41 | + |
| 42 | +# ---- 2. BatchNorm: which aten op does the IR import path actually see? ---- |
| 43 | +# REMOVED from PR: the no_training (7-arg) branch is unreachable here because |
| 44 | +# the capture path never runs run_decompositions (that only happens on the |
| 45 | +# NPU->CoreML path). torch.export emits the 8-arg aten.batch_norm.default. |
| 46 | +line("2. eval-mode BatchNorm export form") |
| 47 | +class ConvBN(nn.Module): |
| 48 | + def __init__(self): |
| 49 | + super().__init__() |
| 50 | + self.conv = nn.Conv2d(3, 4, 3, padding=1) |
| 51 | + self.bn = nn.BatchNorm2d(4) |
| 52 | + def forward(self, x): |
| 53 | + return self.bn(self.conv(x)) |
| 54 | +ep = torch.export.export(ConvBN().eval(), (torch.randn(1, 3, 8, 8),)) |
| 55 | +print(" exported BN target(s):", [str(n.target) for n in ep.graph.nodes if "batch_norm" in str(n.target)]) |
| 56 | +print(" (run_decompositions would instead emit _native_batch_norm_legit_no_training -> only feeds CoreML)") |
| 57 | + |
| 58 | + |
| 59 | +# ---- 3. scaled addmm fail-closed (KEEP: reached) ---- |
| 60 | +line("3. scaled addmm (beta=2.0)") |
| 61 | +class ScaledAddmm(nn.Module): |
| 62 | + def forward(self, bias, a, b): |
| 63 | + return torch.addmm(bias, a, b, beta=2.0) |
| 64 | +try: |
| 65 | + capture_model(ScaledAddmm(), (torch.randn(4), torch.randn(4, 5), torch.randn(5, 4))) |
| 66 | + print(" imported WITHOUT failing (origin behaviour: silently drops beta)") |
| 67 | +except Exception as e: |
| 68 | + print(" fail-closed OK:", type(e).__name__, "->", str(e).split(":")[-1].strip()[:80]) |
| 69 | + |
| 70 | + |
| 71 | +# ---- 4. SDPA literal mask -> attrs['mask']? (REMOVED: never produced) ---- |
| 72 | +line("4. SDPA mask -> attrs['mask']? (3 natural forms)") |
| 73 | +def check(name, mod, args): |
| 74 | + g = capture_model(mod, args).ir_graph |
| 75 | + apply_import_semantics(g) |
| 76 | + for n in g.order: |
| 77 | + nd = g.nodes[n] |
| 78 | + if nd.op == "attention": |
| 79 | + print(f" [{name}] attrs={sorted(nd.attrs)} mask_in_attrs={'mask' in nd.attrs} n_inputs={len(nd.inputs)}") |
| 80 | + return |
| 81 | + print(f" [{name}] no attention node") |
| 82 | + |
| 83 | +class MaskTensorArg(nn.Module): |
| 84 | + def forward(self, q, k, v, m): |
| 85 | + return F.scaled_dot_product_attention(q, k, v, attn_mask=m) |
| 86 | +check("tensor-mask-as-input", MaskTensorArg(), (torch.randn(1, 1, 4, 8),) * 3 + (torch.zeros(1, 1, 4, 4),)) |
| 87 | + |
| 88 | +class MaskInlineConst(nn.Module): |
| 89 | + def forward(self, q, k, v): |
| 90 | + return F.scaled_dot_product_attention(q, k, v, attn_mask=torch.zeros(4, 4)) |
| 91 | +check("inline-const-mask", MaskInlineConst(), (torch.randn(1, 1, 4, 8),) * 3) |
| 92 | + |
| 93 | +class CausalFlag(nn.Module): |
| 94 | + def forward(self, q, k, v): |
| 95 | + return F.scaled_dot_product_attention(q, k, v, is_causal=True) |
| 96 | +check("is_causal-flag", CausalFlag(), (torch.randn(1, 1, 4, 8),) * 3) |
| 97 | + |
| 98 | + |
| 99 | +# ---- 6. rms_norm guard (KEEP: non-last-dim mis-fuses without it) ---- |
| 100 | +# NOTE: fuse_rms_norm skips FP32 inputs (kept unfused on purpose), so the inputs |
| 101 | +# must be FP16 to reach the fusion -- as they are in real transpilation. |
| 102 | +def rms(x, dim, w, eps=1e-6): |
| 103 | + return x * torch.rsqrt(x.pow(2).mean(dim, keepdim=True) + eps) * w |
| 104 | + |
| 105 | +line("6a. LAST-dim RMSNorm, FP16") |
| 106 | +class RMSLast(nn.Module): |
| 107 | + def __init__(self): |
| 108 | + super().__init__() |
| 109 | + self.w = nn.Parameter(torch.ones(8, dtype=torch.float16)) |
| 110 | + def forward(self, x): |
| 111 | + return rms(x, -1, self.w) |
| 112 | +g = capture_model(RMSLast().half().eval(), (torch.randn(2, 3, 8, dtype=torch.float16),)).ir_graph |
| 113 | +changed = fuse_rms_norm(g) |
| 114 | +print(" fuse_rms_norm changed?", changed, "| rms_norm nodes:", opc(g).get("rms_norm", 0)) |
| 115 | + |
| 116 | +line("6b. NON-last-dim RMSNorm (channel dim=1), FP16") |
| 117 | +class RMSChan(nn.Module): |
| 118 | + def __init__(self): |
| 119 | + super().__init__() |
| 120 | + self.w = nn.Parameter(torch.ones(4, 1, 1, dtype=torch.float16)) |
| 121 | + def forward(self, x): |
| 122 | + return rms(x, 1, self.w) |
| 123 | +g = capture_model(RMSChan().half().eval(), (torch.randn(2, 4, 5, 5, dtype=torch.float16),)).ir_graph |
| 124 | +print(" mean axes:", [(n, g.nodes[n].attrs.get("axis")) for n in g.order if g.nodes[n].op == "mean"]) |
| 125 | +changed = fuse_rms_norm(g) |
| 126 | +for n in g.order: |
| 127 | + if g.nodes[n].op == "rms_norm": |
| 128 | + wv = g.values.get(g.nodes[n].inputs[1]) |
| 129 | + print(" fused rms_norm weight shape:", getattr(wv, "shape", None), "(channel=4, last dim=5)") |
| 130 | +print(" fuse_rms_norm changed?", changed, |
| 131 | + "| this branch ships the guard so changed=False (correct);", |
| 132 | + "checkout origin/main rms_norm.py to see changed=True with weight (5,)") |
| 133 | + |
| 134 | +print("\n[done]") |
0 commit comments