23 changes: 22 additions & 1 deletion python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -47,6 +47,11 @@ def _convert_pytorch_tensor_to_tvm(tensor_value: torch.Tensor) -> tvm.runtime.Te
-------
tvm.runtime.Tensor
The converted TVM tensor.

Raises
------
RuntimeError
If the tensor is a FakeTensor or other tensor subclass that cannot be converted.
"""
# PyTorch sparse tensors (layout != torch.strided) must be converted to dense.
if tensor_value.layout != torch.strided:
@@ -1688,11 +1693,27 @@ def from_exported_program(
binding = {}
for tensor_name, tensor_value in to_bind_parameters.items():
# find relax var name from graph signature
bind_name = None
for spec in exported_program.graph_signature.input_specs:
if tensor_name == spec.target:
bind_name = spec.arg.name
break
binding[bind_name] = self._convert_pytorch_tensor_to_tvm(tensor_value)
if bind_name is None:
# Skip tensors that don't have corresponding input specs
# (e.g., lifted_tensor from torch.export)
continue
try:
binding[bind_name] = self._convert_pytorch_tensor_to_tvm(tensor_value)
except RuntimeError as e:
# Skip FakeTensor/lifted tensors that cannot be converted
# These are typically intermediate tensors that torch.export couldn't properly lift
import warnings

warnings.warn(
f"Skipping parameter '{tensor_name}' (bind_name: '{bind_name}'): "
f"Cannot convert tensor to TVM format: {e}"
)
continue

mod = self.block_builder.get()
mod = relax.transform.BindParams("main", binding)(mod)
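For reference, a minimal caller-side sketch of the behavior this hunk introduces: parameters whose tensors cannot be converted are now left unbound and reported through `warnings.warn` instead of aborting the translation, so a caller can record the warnings to see what was skipped. The `TinyLinear` model below is a hypothetical stand-in, not taken from this PR; with a model this simple the skipped list will typically be empty.

```python
import warnings

import torch
from torch.export import export as torch_export

import tvm
from tvm.relax.frontend.torch import from_exported_program


class TinyLinear(torch.nn.Module):
    """Hypothetical stand-in model; a plain Linear usually converts cleanly."""

    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.lin(x)


ep = torch_export(TinyLinear().eval(), (torch.randn(2, 4),))

# Record warnings so any skipped parameters can be audited after translation.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    mod = from_exported_program(ep)  # no longer aborts when a parameter cannot be converted

assert isinstance(mod, tvm.IRModule)
skipped = [w for w in caught if "Skipping parameter" in str(w.message)]
print(f"{len(skipped)} parameter(s) skipped and left unbound")
```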
97 changes: 97 additions & 0 deletions tests/python/relax/test_frontend_torch_export_faketensor.py
@@ -0,0 +1,97 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test handling of FakeTensor and lifted tensors in from_exported_program"""
import pytest

torch = pytest.importorskip("torch", "2.1")

import math
import torch.nn as nn
from torch.export import export as torch_export

import tvm
from tvm.relax.frontend.torch import from_exported_program


def test_lifted_tensor_with_masked_fill():
"""Test Issue #18407: FakeTensor/lifted tensors from eq+expand+masked_fill_"""

def get_attn_pad_mask(seq_q, seq_k):
B, Lq = seq_q.size()
B2, Lk = seq_k.size()
assert B == B2
pad_attn_mask = seq_k.data.eq(0).unsqueeze(1) # (B,1,Lk)
return pad_attn_mask.expand(B, Lq, Lk) # (B,Lq,Lk)

class TinyMHA(nn.Module):
def __init__(self, d_model=64, d_k=16, n_heads=4, dropout=0.1):
super().__init__()
self.h, self.dk = n_heads, d_k
self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
self.W_V = nn.Linear(d_model, d_k * n_heads, bias=False)
self.proj = nn.Linear(d_k * n_heads, d_model, bias=False)
self.ln = nn.LayerNorm(d_model)
self.drop = nn.Dropout(dropout)

def forward(self, x, attn_mask):
B, L, _ = x.shape
q = self.W_Q(x).view(B, L, self.h, self.dk).transpose(1, 2)
k = self.W_K(x).view(B, L, self.h, self.dk).transpose(1, 2)
v = self.W_V(x).view(B, L, self.h, self.dk).transpose(1, 2)
scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.dk)
# This masked_fill_ with eq+expand mask triggers lifted_tensor
scores.masked_fill_(attn_mask.unsqueeze(1), -1e9)
attn = torch.softmax(scores, dim=-1)
ctx = torch.matmul(attn, v).transpose(1, 2).reshape(B, L, self.h * self.dk)
out = self.drop(self.proj(ctx))
return self.ln(out + x)

class MiniModel(nn.Module):
def __init__(self, vocab=1000, d_model=64):
super().__init__()
self.emb = nn.Embedding(vocab, d_model)
self.mha = TinyMHA(d_model=d_model, d_k=16, n_heads=4, dropout=0.1)
self.proj = nn.Linear(d_model, vocab, bias=False)

def forward(self, enc_inputs):
x = self.emb(enc_inputs)
mask = get_attn_pad_mask(enc_inputs, enc_inputs)
y = self.mha(x, mask)
logits = self.proj(y)
return logits.reshape(-1, logits.size(-1))

torch.manual_seed(42)
model = MiniModel().eval()
enc = torch.randint(0, 1000, (2, 5))
enc[0, 0] = 0 # Ensure eq(0) path is taken

# Export with torch.export (may emit warnings about lifted_tensor)
ep = torch_export(model, (enc,))

# This should not crash (Issue #18407)
mod = from_exported_program(ep)

# Verify the module was created successfully
assert isinstance(mod, tvm.IRModule)
# The module should have a main function
assert len(mod.functions) > 0


if __name__ == "__main__":
test_lifted_tensor_with_masked_fill()
print("Test passed!")