Skip to content

Clear _tracer_cls after lowering to prevent save_compiled/load_compiled crash#286

Open
menglcai wants to merge 1 commit into
masterfrom
menglcai/symbolically_error
Open

Clear _tracer_cls after lowering to prevent save_compiled/load_compiled crash#286
menglcai wants to merge 1 commit into
masterfrom
menglcai/symbolically_error

Conversation

@menglcai
Copy link
Copy Markdown

@menglcai menglcai commented May 26, 2026

Descriptions

torch.compile with save_compiled and load_compiled crashes on load for pure Conv models (e.g. SqueezeNet) with:

  .....
  File "/opt/venv/lib/python3.12/site-packages/torch/fx/proxy.py", line 439, in to_bool
    raise TraceError(
torch._dynamo.exc.BackendCompilerFailed: backend='migraphx' raised:
TraceError: symbolically traced variables cannot be used as inputs to control flow

while models with Linear layers (e.g. AlexNet, ResNet) work fine.

Analysis

torch.save serializes _tracer_cls = PythonKeyTracer (torch.fx.experimental.proxy_tensor.PythonKeyTracer) alongside the compiled model. On torch.load, PyTorch uses this stored tracer to re-trace the graph. PythonKeyTracer.call_module unconditionally calls forward() on every submodule (see torch/fx/experimental/proxy_tensor.py#L1364-L1371) — including MGXModule. Since MGXModule.forward() branches on tensor shape, and tensors are symbolic during tracing, a TraceError is raised.

Why TinyLinear / AlexNet / ResNet pass?

const_fold calls split_const_subgraphs, which — when const-foldable ops exist (e.g. t(weight) from Linear layers) — creates a brand-new FoldedGraphModule backed by a brand-new graph (see torch/fx/experimental/const_fold.py#L388-L393 ). The new graph has _tracer_cls = None, so _deserialize_graph_module falls back to the safe base Tracer.

Why TinyConv/ squeezenet fail?

Pure-conv models have no const-foldable ops, so split_const_subgraphs hits the early return (see torch/fx/experimental/const_fold.py#L279-L280) and reuses the original mod_traced.graph — with _tracer_cls = PythonKeyTracer still attached — which is then written into the saved file.

Solution

This PR clear _tracer_cls on the lowered GraphModule before it is returned (and subsequently saved). This ensures _deserialize_graph_module falls back to the base Tracer (see torch/fx/graph_module.py#L183-L187 ), which respects is_leaf_module and treats MGXModule as a leaf.

# torch_migraphx/dynamo/lower_dynamo.py

    lowered_gm = post_lowering_pass(optim_gm)
+   lowered_gm._tracer_cls = None
    return lowered_gm

Steps to reproduce

Run the following test code to reproduce the issue. Models with Linear layers (e.g. TinyLinear, AlexNet, ResNet) are not affected, but pure Conv models (e.g. TinyConv, SqueezeNet) crash on load.

import torch
import torch_migraphx


class TinyConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3, padding=1)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3, padding=1)
        self.relu = torch.nn.ReLU()
        self.fc   = torch.nn.Linear(8 * 32 * 32, 10)

    def forward(self, x):
        x = self.relu(self.conv(x))
        return self.fc(x.flatten(1))


@torch.no_grad
def run_test(model_cls, path, bypass_const_fold=False):
    x = torch.randn(1, 3, 32, 32, device="cuda")

    # save compiled
    model = torch.compile(
        model_cls().to("cuda").eval(),
        backend="migraphx",
        options={"deallocate": True, "save_compiled": path},
        dynamic=False,
    )
    _ = model(x)
    print(f"[{model_cls.__name__}] Save OK")

    # load compiled
    model = torch.compile(
        model_cls().to("cuda").eval(),
        backend="migraphx",
        options={"deallocate": True, "load_compiled": path},
        dynamic=False,
    )
    _ = model(x)
    print(f"[{model_cls.__name__}] Load OK")


if __name__ == "__main__":
    run_test(TinyLinear, "test_linear.mgx", bypass_const_fold=False)    # Pass
    run_test(TinyConv,   "test_conv.mgx",   bypass_const_fold=True)     # Fail

@menglcai menglcai marked this pull request as ready for review May 26, 2026 05:55
@menglcai menglcai requested a review from shivadbhavsar as a code owner May 26, 2026 05:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant