[Bug] Segfault in tvm.compile (Relax, target=llvm) during TIR pass InjectPTXLDG32 / PTXRewriter::VisitStmt_(BufferStore) even though target is CPU-only #18599

@tinywisdom


Summary

Compiling a Relax IRModule converted from a PyTorch program exported with torch.export crashes with a segmentation fault inside TVM’s TIR pass pipeline, specifically in:

  • tvm::tir::transform::InjectPTXLDG32(bool)
  • tvm::tir::PTXRewriter::VisitStmt_(tvm::tir::BufferStoreNode const*)
  • tvm::tir::BufferStore::BufferStore(...)

This occurs while invoking tvm.compile(...) with:

  • target = tvm.target.Target("llvm") (CPU-only)
  • tir_pipeline = tir.get_default_tir_pipeline(target)
  • relax_pipeline = "default"
  • PassContext.config includes "tir.ptx_ldg32": 1 plus several other flags

Even though the target is a CPU-only LLVM target, the stack trace shows a PTX-specific pass (InjectPTXLDG32, via its PTXRewriter) running and then segfaulting.

This is not a Python exception but a hard crash (Segmentation fault (core dumped)). It therefore likely points to a bug in pass gating or pipeline selection, or to an unsafe assumption in InjectPTXLDG32 when it runs under this pipeline.
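
If the crash is gated on tir.ptx_ldg32, as the stack trace suggests, one possible workaround while this is investigated is to drop CUDA/PTX-oriented keys from the pass config before compiling for a CPU target. This is a minimal sketch, not a fix: strip_gpu_only_keys and GPU_ONLY_KEYS are hypothetical names, and the assumption that exactly these three keys only matter for GPU codegen is mine. The target kind string would come from target.kind.name, which is "llvm" for the target in this report.

```python
from typing import Dict

# Assumption: these keys only affect CUDA/PTX code generation, so they should be
# inert for a CPU ("llvm") target. The set is illustrative, not taken from TVM.
GPU_ONLY_KEYS = frozenset({
    "tir.ptx_ldg32",
    "tir.use_async_copy",
    "relax.backend.use_cuda_graph",
})


def strip_gpu_only_keys(config: Dict[str, int], target_kind: str) -> Dict[str, int]:
    """Hypothetical helper: return a copy of `config` without GPU-only keys,
    unless the target kind is "cuda"."""
    if target_kind == "cuda":
        return dict(config)
    return {k: v for k, v in config.items() if k not in GPU_ONLY_KEYS}


cfg = {"tir.ptx_ldg32": 1, "tir.use_async_copy": 1, "tir.enable_debug": 1}
print(strip_gpu_only_keys(cfg, "llvm"))  # {'tir.enable_debug': 1}
```

In the repro this would be applied as pass_config = strip_gpu_only_keys(pass_config, target.kind.name) before building pc_kwargs. Whether it actually avoids the segfault is untested.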

Environment

From the repro output:

  • TVM version: 0.22.0
  • TVM git commit: 9dbf3f22ff6f44962472f9af310fda368ca85ef2
  • LLVM: 17.0.6
  • PyTorch: 2.9.0+cu128
  • Python: 3.10.16 (inferred from stack paths)
  • NumPy: 2.2.6
  • OS: Linux x86_64
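
print_env_info() in the repro script does not report the Python version, which is why it is inferred from stack paths above. A small addition along these lines would record it directly (python_env_info is a hypothetical helper name; platform.python_version and sys.executable are standard library):

```python
import platform
import sys


def python_env_info() -> dict:
    """Hypothetical helper: interpreter details missing from print_env_info()."""
    return {
        "Python version": platform.python_version(),
        "Python executable": sys.executable,
    }


for key, value in python_env_info().items():
    print(f"{key}: {value}")
```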

Minimal Repro Script

import random
import numpy as np
import torch
import torch.nn as nn
import tvm
from tvm import tir


def print_env_info():
    print("==== Environment Info ====")
    print("TVM version:", getattr(tvm, "__version__", "unknown"))
    print("TVM git commit:", tvm.support.libinfo()["GIT_COMMIT_HASH"])
    print("TVM LLVM version:", tvm.support.libinfo().get("LLVM_VERSION", "unknown"))

    print("NumPy version:", np.__version__)
    print("PyTorch version:", torch.__version__)
    print("==========================\n")


def set_seed(seed: int = 0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(100, 64, 4, 1, 0, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
            nn.Conv2d(3, 8, 3, 1, 1, bias=False),
            nn.BatchNorm2d(8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(8, 1, 1, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        y = self.net(x)
        return y.reshape(-1)


def export_to_relax(mod: nn.Module, x: torch.Tensor) -> tvm.IRModule:
    mod = mod.to("cpu").eval()
    x = x.to("cpu")
    ep = torch.export.export(mod, (x,))
    from tvm.relax.frontend.torch import from_exported_program
    return from_exported_program(ep)


def main():
    print_env_info()
    set_seed(0)

    target = tvm.target.Target("llvm")
    tir_pipeline = tir.get_default_tir_pipeline(target)
    relax_pipeline = "default"

    B = 64
    x = torch.rand(B, 100, 1, 1, dtype=torch.float32)
    model = Model()

    print("[repro] exporting torch -> relax ...")
    ir_mod = export_to_relax(model, x)

    pass_config = {
        "relax.FuseOps.max_depth": 4,
        "relax.backend.use_cuda_graph": 1,
        "tir.disable_storage_rewrite": 1,
        "tir.disable_vectorize": 1,
        "tir.enable_debug": 1,
        "tir.enable_equiv_terms_in_cse_tir": 1,
        "tir.ptx_ldg32": 1,
        "tir.use_async_copy": 1,
    }

    pc_kwargs = {
        "opt_level": 1,
        "disabled_pass": [
            "CanonicalizeBindings",
            "Simplify",
            "VectorizeLoop",
            "RemoveNoOp",
        ],
        "config": pass_config,
    }

    print("[repro] target:", target)
    print("[repro] tir_pipeline: explicit_default")
    print("[repro] PassContext.config keys:", sorted(pass_config.keys()))
    print("[repro] compiling with tvm.compile ...")

    with tvm.transform.PassContext(**pc_kwargs):
        _ = tvm.compile(
            ir_mod,
            target=target,
            relax_pipeline=relax_pipeline,
            tir_pipeline=tir_pipeline,
        )

    print("[repro] compile finished (no crash).")


if __name__ == "__main__":
    main()
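
To narrow down which PassContext flag is needed to trigger the crash (for example, whether tir.ptx_ldg32 alone is sufficient), one approach is to rerun the repro with one flag removed at a time. A minimal sketch of generating those reduced configs; leave_one_out is a hypothetical helper, and the flag values are copied from pass_config in the script. Each reduced config should be compiled in its own process, since a segfault ends the interpreter.

```python
from typing import Dict, Iterator, Tuple

# Copied from pass_config in the repro script.
PASS_CONFIG: Dict[str, int] = {
    "relax.FuseOps.max_depth": 4,
    "relax.backend.use_cuda_graph": 1,
    "tir.disable_storage_rewrite": 1,
    "tir.disable_vectorize": 1,
    "tir.enable_debug": 1,
    "tir.enable_equiv_terms_in_cse_tir": 1,
    "tir.ptx_ldg32": 1,
    "tir.use_async_copy": 1,
}


def leave_one_out(config: Dict[str, int]) -> Iterator[Tuple[str, Dict[str, int]]]:
    """Hypothetical helper: yield (dropped_key, reduced_config) for each key,
    in sorted key order, leaving the input dict unchanged."""
    for key in sorted(config):
        reduced = {k: v for k, v in config.items() if k != key}
        yield key, reduced


for dropped, reduced in leave_one_out(PASS_CONFIG):
    print(f"without {dropped}: {sorted(reduced)}")
```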

Actual Behavior

tvm.compile(...) crashes with a segfault. The stack trace consistently includes:

  • tvm::tir::BufferStore::BufferStore(...)
  • tvm::tir::PTXRewriter::VisitStmt_(BufferStoreNode const*)
  • tvm::tir::transform::InjectPTXLDG32(bool)

Excerpt:

!!!!!!! Segfault encountered !!!!!!!
...
tvm::tir::BufferStore::BufferStore(...)
tvm::tir::PTXRewriter::VisitStmt_(tvm::tir::BufferStoreNode const*)
...
tvm::tir::transform::PrimFuncPassNode::operator()(...)
...
tvm::tir::transform::InjectPTXLDG32(bool)
Segmentation fault (core dumped)
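
Because this is a segfault rather than a Python exception, it kills the interpreter, so a script that tries several configurations in one process cannot report which one crashed. A minimal sketch of running each attempt in a child interpreter and classifying the result, assuming a POSIX system (where subprocess reports a child killed by signal N as return code -N); run_isolated is a hypothetical helper:

```python
import signal
import subprocess
import sys
from typing import List


def run_isolated(args: List[str], timeout: float = 600.0) -> str:
    """Hypothetical helper: run `python <args>` in a fresh interpreter and
    classify how it exited."""
    proc = subprocess.run(
        [sys.executable, *args],
        capture_output=True,
        text=True,
        errors="replace",  # tolerate non-UTF-8 bytes in native crash output
        timeout=timeout,
    )
    if proc.returncode == 0:
        return "ok"
    # On POSIX, a negative return code means the child was killed by that signal.
    if proc.returncode == -signal.SIGSEGV:
        return "segfault"
    return f"error (exit code {proc.returncode})"
```

With the script above saved as repro.py (a placeholder filename), run_isolated(["repro.py"]) would be expected to return "segfault" for the configuration in this report.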

Triage

  • needs-triage
  • bug
