Commit c199df1

[Relax][Torch] Fix from_exported_program crash with FakeTensor/lifted tensors (#18407)
1 parent c75b5ac commit c199df1

2 files changed: +122 -25 lines changed

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 25 additions & 25 deletions
@@ -36,7 +36,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):

    @staticmethod
    def _convert_pytorch_tensor_to_tvm(tensor_value: torch.Tensor) -> tvm.runtime.Tensor:
-        """Convert a PyTorch tensor to TVM tensor, handling sparse tensors, FakeTensors, and lifted tensors.
+        """Convert a PyTorch tensor to TVM tensor, handling sparse tensors.

        Parameters
        ----------
@@ -47,19 +47,12 @@ def _convert_pytorch_tensor_to_tvm(tensor_value: torch.Tensor) -> tvm.runtime.Tensor:
        -------
        tvm.runtime.Tensor
            The converted TVM tensor.
+
+        Raises
+        ------
+        RuntimeError
+            If the tensor is a FakeTensor or other tensor subclass that cannot be converted.
        """
-        # Fix for Issue #18407: Handle FakeTensor and lifted tensors (from torch.export)
-        # Check if this is a FakeTensor or tensor subclass that doesn't support .numpy()
-        try:
-            # Check if it's a FakeTensor
-            if hasattr(torch, '_subclasses') and hasattr(torch._subclasses, 'fake_tensor'):
-                if isinstance(tensor_value, torch._subclasses.fake_tensor.FakeTensor):
-                    # Create a real tensor with the same shape and dtype
-                    real_tensor = torch.zeros(tensor_value.shape, dtype=tensor_value.dtype)
-                    return tvm.runtime.tensor(real_tensor.numpy())
-        except (AttributeError, ImportError):
-            pass
-
        # PyTorch sparse tensors (layout != torch.strided) must be converted to dense.
        if tensor_value.layout != torch.strided:
            tensor_to_convert = tensor_value.to_dense()
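
For orientation, here is a simplified, self-contained sketch of the conversion flow this helper keeps after the change: densify sparse tensors, detach, then convert to TVM. It is an illustration only, not the exact TVM code; it assumes tvm.runtime.tensor / tvm.runtime.Tensor are available as used in this diff, and it omits the zero-copy DLPack attempt whose fallback is simplified in the hunk below.

# Illustrative sketch only (assumes a real, strided torch.Tensor; FakeTensors
# raise RuntimeError on .numpy(), which the caller now handles instead).
import torch
import tvm


def to_tvm_tensor(t: torch.Tensor) -> "tvm.runtime.Tensor":
    if t.layout != torch.strided:
        # Sparse layouts (COO/CSR/...) must be densified before conversion.
        t = t.to_dense()
    t = t.detach()
    # Fallback path from the diff: move to CPU, make contiguous, round-trip
    # through NumPy, then build a TVM tensor from the NumPy array.
    return tvm.runtime.tensor(t.cpu().contiguous().numpy())


x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
print(to_tvm_tensor(x).shape)
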
@@ -73,17 +66,8 @@ def _convert_pytorch_tensor_to_tvm(tensor_value: torch.Tensor) -> tvm.runtime.Tensor:
        except (RuntimeError, BufferError):
            # Fallback: convert to numpy and then to TVM tensor
            # This handles cases where DLPack conversion fails
-            try:
-                tensor_cpu = tensor_detached.cpu().contiguous()
-                return tvm.runtime.tensor(tensor_cpu.numpy())
-            except RuntimeError as e:
-                # Fix for Issue #18407: Handle tensor subclasses that don't support .numpy()
-                # This can happen with lifted tensors from torch.export
-                if "tensor subclasses" in str(e) or "FakeTensor" in str(e):
-                    # Create a dummy tensor with the same shape and dtype
-                    dummy_tensor = torch.zeros(tensor_value.shape, dtype=tensor_value.dtype)
-                    return tvm.runtime.tensor(dummy_tensor.numpy())
-                raise
+            tensor_cpu = tensor_detached.cpu().contiguous()
+            return tvm.runtime.tensor(tensor_cpu.numpy())

    ########## Unary Ops ##########

@@ -1709,11 +1693,27 @@ def from_exported_program(
        binding = {}
        for tensor_name, tensor_value in to_bind_parameters.items():
            # find relax var name from graph signature
+            bind_name = None
            for spec in exported_program.graph_signature.input_specs:
                if tensor_name == spec.target:
                    bind_name = spec.arg.name
                    break
-            binding[bind_name] = self._convert_pytorch_tensor_to_tvm(tensor_value)
+            if bind_name is None:
+                # Skip tensors that don't have corresponding input specs
+                # (e.g., lifted_tensor from torch.export)
+                continue
+            try:
+                binding[bind_name] = self._convert_pytorch_tensor_to_tvm(tensor_value)
+            except RuntimeError as e:
+                # Skip FakeTensor/lifted tensors that cannot be converted
+                # These are typically intermediate tensors that torch.export couldn't properly lift
+                import warnings
+
+                warnings.warn(
+                    f"Skipping parameter '{tensor_name}' (bind_name: '{bind_name}'): "
+                    f"Cannot convert tensor to TVM format: {e}"
+                )
+                continue

        mod = self.block_builder.get()
        mod = relax.transform.BindParams("main", binding)(mod)
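
Outside the diff, a minimal runnable sketch of the new binding strategy, using toy stand-ins (Spec, Arg, and convert below are hypothetical placeholders for torch.export's input specs and for _convert_pytorch_tensor_to_tvm): resolve each parameter's graph-level name from the input specs, skip parameters with no matching spec (e.g. lifted_tensor constants), and warn-and-skip rather than crash when conversion raises.

# Minimal, self-contained sketch of the skip-and-warn binding strategy above.
# Spec/Arg/convert are hypothetical stand-ins, not TVM or torch.export objects.
import warnings
from dataclasses import dataclass


@dataclass
class Arg:
    name: str


@dataclass
class Spec:
    target: str
    arg: Arg


def convert(value):
    if value is None:  # stand-in for an unconvertible FakeTensor
        raise RuntimeError("FakeTensor cannot be converted")
    return value


input_specs = [Spec("w", Arg("p_w")), Spec("b", Arg("p_b"))]
parameters = {"w": 1.0, "b": None, "lifted_tensor_0": 2.0}

binding = {}
for tensor_name, tensor_value in parameters.items():
    bind_name = None
    for spec in input_specs:
        if tensor_name == spec.target:
            bind_name = spec.arg.name
            break
    if bind_name is None:
        continue  # no input spec, e.g. a lifted_tensor constant
    try:
        binding[bind_name] = convert(tensor_value)
    except RuntimeError as exc:
        warnings.warn(f"Skipping parameter '{tensor_name}': {exc}")

print(binding)  # {'p_w': 1.0}
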
Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test handling of FakeTensor and lifted tensors in from_exported_program"""
import pytest

torch = pytest.importorskip("torch", "2.1")

import math
import torch.nn as nn
from torch.export import export as torch_export

import tvm
from tvm.relax.frontend.torch import from_exported_program


def test_lifted_tensor_with_masked_fill():
    """Test Issue #18407: FakeTensor/lifted tensors from eq+expand+masked_fill_"""

    def get_attn_pad_mask(seq_q, seq_k):
        B, Lq = seq_q.size()
        B2, Lk = seq_k.size()
        assert B == B2
        pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # (B,1,Lk)
        return pad_attn_mask.expand(B, Lq, Lk)  # (B,Lq,Lk)

    class TinyMHA(nn.Module):
        def __init__(self, d_model=64, d_k=16, n_heads=4, dropout=0.1):
            super().__init__()
            self.h, self.dk = n_heads, d_k
            self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
            self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
            self.W_V = nn.Linear(d_model, d_k * n_heads, bias=False)
            self.proj = nn.Linear(d_k * n_heads, d_model, bias=False)
            self.ln = nn.LayerNorm(d_model)
            self.drop = nn.Dropout(dropout)

        def forward(self, x, attn_mask):
            B, L, _ = x.shape
            q = self.W_Q(x).view(B, L, self.h, self.dk).transpose(1, 2)
            k = self.W_K(x).view(B, L, self.h, self.dk).transpose(1, 2)
            v = self.W_V(x).view(B, L, self.h, self.dk).transpose(1, 2)
            scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.dk)
            # This masked_fill_ with eq+expand mask triggers lifted_tensor
            scores.masked_fill_(attn_mask.unsqueeze(1), -1e9)
            attn = torch.softmax(scores, dim=-1)
            ctx = torch.matmul(attn, v).transpose(1, 2).reshape(B, L, self.h * self.dk)
            out = self.drop(self.proj(ctx))
            return self.ln(out + x)

    class MiniModel(nn.Module):
        def __init__(self, vocab=1000, d_model=64):
            super().__init__()
            self.emb = nn.Embedding(vocab, d_model)
            self.mha = TinyMHA(d_model=d_model, d_k=16, n_heads=4, dropout=0.1)
            self.proj = nn.Linear(d_model, vocab, bias=False)

        def forward(self, enc_inputs):
            x = self.emb(enc_inputs)
            mask = get_attn_pad_mask(enc_inputs, enc_inputs)
            y = self.mha(x, mask)
            logits = self.proj(y)
            return logits.reshape(-1, logits.size(-1))

    torch.manual_seed(42)
    model = MiniModel().eval()
    enc = torch.randint(0, 1000, (2, 5))
    enc[0, 0] = 0  # Ensure eq(0) path is taken

    # Export with torch.export (may emit warnings about lifted_tensor)
    ep = torch_export(model, (enc,))

    # This should not crash (Issue #18407)
    mod = from_exported_program(ep)

    # Verify the module was created successfully
    assert isinstance(mod, tvm.IRModule)
    # The module should have a main function
    assert len(mod.functions) > 0


if __name__ == "__main__":
    test_lifted_tensor_with_masked_fill()
    print("Test passed!")
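
For reference, a much smaller sketch of how such lifted tensors arise (hypothetical repro; whether a lifted constant actually appears depends on the PyTorch version): masked_fill_ with a Python scalar can lead torch.export to lift a constant tensor that has no corresponding parameter, which is exactly the case the new skip logic tolerates.

# Hypothetical minimal repro: inspect which inputs torch.export lifted.
import torch
from torch.export import export as torch_export


class MaskedScores(torch.nn.Module):
    def forward(self, x):
        scores = x.clone()
        scores.masked_fill_(x.eq(0), -1e9)  # scalar fill value may get lifted
        return scores


ep = torch_export(MaskedScores(), (torch.randn(2, 5),))
for spec in ep.graph_signature.input_specs:
    # Kinds include user inputs, parameters, buffers, and lifted constant tensors.
    print(spec.kind, spec.target, spec.arg)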
