|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | +"""Test handling of FakeTensor and lifted tensors in from_exported_program""" |
| 18 | +import pytest |
| 19 | + |
| 20 | +torch = pytest.importorskip("torch", "2.1") |
| 21 | + |
| 22 | +import math |
| 23 | +import torch.nn as nn |
| 24 | +from torch.export import export as torch_export |
| 25 | + |
| 26 | +import tvm |
| 27 | +from tvm.relax.frontend.torch import from_exported_program |
| 28 | + |
| 29 | + |
def test_lifted_tensor_with_masked_fill():
    """Test Issue #18407: FakeTensor/lifted tensors from eq+expand+masked_fill_.

    Builds a small embedding + multi-head-attention model whose padding mask
    is produced with ``eq(0)`` + ``unsqueeze`` + ``expand`` and applied with an
    in-place ``masked_fill_``, exports it with ``torch.export`` and checks that
    ``from_exported_program`` returns an ``IRModule`` that contains a ``main``
    function instead of crashing.
    """

    def get_attn_pad_mask(seq_q, seq_k):
        """Return a (B, Lq, Lk) mask that is True where ``seq_k`` equals 0."""
        B, Lq = seq_q.size()
        B2, Lk = seq_k.size()
        assert B == B2
        # NOTE(review): keep this expression exactly as written (including the
        # `.data` access); it is part of the pattern reproduced from the issue.
        pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # (B,1,Lk)
        return pad_attn_mask.expand(B, Lq, Lk)  # (B,Lq,Lk)

    class TinyMHA(nn.Module):
        """Minimal multi-head self-attention block with residual + LayerNorm."""

        def __init__(self, d_model=64, d_k=16, n_heads=4, dropout=0.1):
            super().__init__()
            self.h, self.dk = n_heads, d_k
            self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
            self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
            self.W_V = nn.Linear(d_model, d_k * n_heads, bias=False)
            self.proj = nn.Linear(d_k * n_heads, d_model, bias=False)
            self.ln = nn.LayerNorm(d_model)
            self.drop = nn.Dropout(dropout)

        def forward(self, x, attn_mask):
            """Apply masked scaled dot-product attention to ``x``."""
            B, L, _ = x.shape
            # Project and split into heads: (B, L, h*dk) -> (B, h, L, dk).
            q = self.W_Q(x).view(B, L, self.h, self.dk).transpose(1, 2)
            k = self.W_K(x).view(B, L, self.h, self.dk).transpose(1, 2)
            v = self.W_V(x).view(B, L, self.h, self.dk).transpose(1, 2)
            scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.dk)
            # This masked_fill_ with eq+expand mask triggers lifted_tensor
            scores.masked_fill_(attn_mask.unsqueeze(1), -1e9)
            attn = torch.softmax(scores, dim=-1)
            # Merge the heads back: (B, h, L, dk) -> (B, L, h*dk).
            ctx = torch.matmul(attn, v).transpose(1, 2).reshape(B, L, self.h * self.dk)
            out = self.drop(self.proj(ctx))
            return self.ln(out + x)

    class MiniModel(nn.Module):
        """Embedding -> TinyMHA -> vocabulary projection."""

        def __init__(self, vocab=1000, d_model=64):
            super().__init__()
            self.emb = nn.Embedding(vocab, d_model)
            self.mha = TinyMHA(d_model=d_model, d_k=16, n_heads=4, dropout=0.1)
            self.proj = nn.Linear(d_model, vocab, bias=False)

        def forward(self, enc_inputs):
            """Return logits flattened to 2-D, with the vocab axis last."""
            x = self.emb(enc_inputs)
            # The same token ids act as both query and key sequences.
            mask = get_attn_pad_mask(enc_inputs, enc_inputs)
            y = self.mha(x, mask)
            logits = self.proj(y)
            return logits.reshape(-1, logits.size(-1))

    torch.manual_seed(42)
    model = MiniModel().eval()
    enc = torch.randint(0, 1000, (2, 5))
    enc[0, 0] = 0  # Ensure eq(0) path is taken

    # Export with torch.export (may emit warnings about lifted_tensor)
    ep = torch_export(model, (enc,))

    # This should not crash (Issue #18407)
    mod = from_exported_program(ep)

    # Verify the module was created successfully
    assert isinstance(mod, tvm.IRModule)
    # The module should have a main function. Check for it by name: a bare
    # non-emptiness check would also pass for a module without `main`.
    func_names = {gv.name_hint for gv in mod.get_global_vars()}
    assert "main" in func_names, f"expected a 'main' function, got {sorted(func_names)}"
| 93 | + |
| 94 | + |
if __name__ == "__main__":
    # Allow running this file directly as a script, without going through pytest.
    test_lifted_tensor_with_masked_fill()
    print("Test passed!")
0 commit comments