-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Open
Labels
needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug
Description
Summary
Compiling a Relax IRModule converted from a PyTorch torch.export program crashes with a segmentation fault inside TVM’s TIR pass pipeline, specifically in:
- `tvm::tir::transform::InjectPTXLDG32(bool)`
- `tvm::tir::PTXRewriter::VisitStmt_(tvm::tir::BufferStoreNode const*)`
- `tvm::tir::BufferStore::BufferStore(...)`
This occurs while invoking tvm.compile(...) with:
- `target = tvm.target.Target("llvm")` (CPU-only)
- `tir_pipeline = tir.get_default_tir_pipeline(target)`
- `relax_pipeline = "default"`
- `PassContext.config` includes `"tir.ptx_ldg32": 1` plus several other flags
Even though the target is LLVM CPU, the stack trace indicates a PTX-specific pass / rewriter is running and then segfaulting.
This is not a Python exception; it is a hard crash (`Segmentation fault (core dumped)`). It therefore likely indicates a bug in pass gating or pipeline selection, or an unsafe assumption in the InjectPTXLDG32 pass when it runs under this pipeline.
Environment
From the repro output:
- TVM version: 0.22.0
- TVM git commit: 9dbf3f22ff6f44962472f9af310fda368ca85ef2
- LLVM: 17.0.6
- PyTorch: 2.9.0+cu128
- Python: 3.10.16 (inferred from stack paths)
- NumPy: 2.2.6 (labeled “Python version” in the script output this was copied from; the script below prints it as “NumPy version”)
- OS: Linux x86_64
Minimal Repro Script
import random
import numpy as np
import torch
import torch.nn as nn
import tvm
from tvm import tir
def print_env_info():
    """Print the interpreter and library versions relevant to this repro.

    Reports the Python version, the TVM version, TVM's git commit and linked
    LLVM version, and the NumPy and PyTorch versions. The two TVM build-info
    fields fall back to ``"unknown"`` when the build does not expose them, so
    a missing key cannot abort the repro before it reaches ``tvm.compile``.
    """
    # Local import: only needed to report the interpreter version.
    import platform

    print("==== Environment Info ====")
    print("Python version:", platform.python_version())
    print("TVM version:", getattr(tvm, "__version__", "unknown"))
    # Query the build-info mapping once and read both fields from it.
    info = tvm.support.libinfo()
    # .get() rather than [] so both fields are handled the same way and a
    # build lacking the key prints "unknown" instead of raising KeyError.
    print("TVM git commit:", info.get("GIT_COMMIT_HASH", "unknown"))
    print("TVM LLVM version:", info.get("LLVM_VERSION", "unknown"))
    print("NumPy version:", np.__version__)
    print("PyTorch version:", torch.__version__)
    print("==========================\n")
def set_seed(seed: int = 0):
    """Seed every random-number generator the repro draws from.

    Args:
        seed: Value handed, in order, to Python's ``random`` module, NumPy's
            global RNG, and PyTorch's default generator.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed)
    for seeder in seeders:
        seeder(seed)
class Model(nn.Module):
    """Small transposed-conv / conv network used as the export subject.

    ``net`` is a single ``nn.Sequential``. It upsamples with two
    ``ConvTranspose2d`` layers (100 -> 64 -> 3 channels) and then applies two
    ``Conv2d`` layers (3 -> 8 -> 1 channels). ``forward`` flattens the result
    to 1-D.
    """

    def __init__(self):
        super().__init__()
        stages = [
            # Upsampling stage: 100 -> 64 -> 3 channels.
            nn.ConvTranspose2d(100, 64, 4, 1, 0, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
            # Refinement stage: 3 -> 8 -> 1 channels, squashed by a sigmoid.
            nn.Conv2d(3, 8, 3, 1, 1, bias=False),
            nn.BatchNorm2d(8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(8, 1, 1, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``net`` on ``x`` and return the output flattened to 1-D."""
        out = self.net(x)
        return out.reshape(-1)
def export_to_relax(mod: nn.Module, x: torch.Tensor) -> tvm.IRModule:
    """Export ``mod`` with ``torch.export`` and convert the result to Relax.

    The module is moved to CPU and put in eval mode, and the example input is
    moved to CPU, before export.

    Args:
        mod: PyTorch module to export.
        x: Example input, passed as the sole positional argument at export.

    Returns:
        The IRModule produced by TVM's ``from_exported_program``.
    """
    cpu_module = mod.to("cpu").eval()
    example_input = x.to("cpu")
    exported = torch.export.export(cpu_module, (example_input,))
    # Deferred import of the Relax PyTorch frontend, kept after export as before.
    from tvm.relax.frontend.torch import from_exported_program

    return from_exported_program(exported)
def main():
    """Reproduce the reported segfault when compiling the exported model.

    Exports ``Model`` to Relax and compiles it with ``tvm.compile`` for the
    CPU-only ``llvm`` target under the pass configuration below. In the issue
    this crashed inside ``tvm::tir::transform::InjectPTXLDG32``. Reaching the
    final print means the compile returned without crashing.
    """
    print_env_info()
    # Seed before drawing ``x`` and constructing ``Model``: the input values
    # and the initial weights both come from the seeded torch RNG, so do not
    # reorder these statements if the crash must reproduce exactly.
    set_seed(0)
    # CPU-only target; the reported stack trace nonetheless runs a PTX rewriter.
    target = tvm.target.Target("llvm")
    # Default TIR pipeline for the target, passed to tvm.compile explicitly.
    tir_pipeline = tir.get_default_tir_pipeline(target)
    relax_pipeline = "default"
    B = 64  # batch size of the example input
    x = torch.rand(B, 100, 1, 1, dtype=torch.float32)
    model = Model()
    print("[repro] exporting torch -> relax ...")
    ir_mod = export_to_relax(model, x)
    # PassContext config. "tir.ptx_ldg32" is the flag whose name matches the
    # crashing InjectPTXLDG32 frame; NOTE(review): confirm by removing the other
    # flags one at a time. Several entries (use_cuda_graph, ptx_ldg32,
    # use_async_copy) look GPU-oriented although the target is llvm.
    pass_config = {
        "relax.FuseOps.max_depth": 4,
        "relax.backend.use_cuda_graph": 1,
        "tir.disable_storage_rewrite": 1,
        "tir.disable_vectorize": 1,
        "tir.enable_debug": 1,
        "tir.enable_equiv_terms_in_cse_tir": 1,
        "tir.ptx_ldg32": 1,
        "tir.use_async_copy": 1,
    }
    # Keyword arguments for tvm.transform.PassContext: opt level, passes
    # disabled by name, and the config mapping above.
    pc_kwargs = {
        "opt_level": 1,
        "disabled_pass": [
            "CanonicalizeBindings",
            "Simplify",
            "VectorizeLoop",
            "RemoveNoOp",
        ],
        "config": pass_config,
    }
    print("[repro] target:", target)
    print("[repro] tir_pipeline: explicit_default")
    print("[repro] PassContext.config keys:", sorted(pass_config.keys()))
    print("[repro] compiling with tvm.compile ...")
    with tvm.transform.PassContext(**pc_kwargs):
        # The compiled result is discarded; only whether compilation returns matters.
        _ = tvm.compile(
            ir_mod,
            target=target,
            relax_pipeline=relax_pipeline,
            tir_pipeline=tir_pipeline,
        )
    print("[repro] compile finished (no crash).")
if __name__ == "__main__":
    main()

Actual Behavior
tvm.compile(...) crashes with a segfault. The stack trace consistently includes:
- `tvm::tir::BufferStore::BufferStore(...)`
- `tvm::tir::PTXRewriter::VisitStmt_(BufferStoreNode const*)`
- `tvm::tir::transform::InjectPTXLDG32(bool)`
Excerpt:
!!!!!!! Segfault encountered !!!!!!!
...
tvm::tir::BufferStore::BufferStore(...)
tvm::tir::PTXRewriter::VisitStmt_(tvm::tir::BufferStoreNode const*)
...
tvm::tir::transform::PrimFuncPassNode::operator()(...)
...
tvm::tir::transform::InjectPTXLDG32(bool)
Segmentation fault (core dumped)
Triage
- needs-triage
- bug
Metadata
Metadata
Assignees
Labels
needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug