🐛 [Bug] torch_tensorrt.load throws an error when loading a model saved with dynamic shape #3174

Open
@lanluo-nvidia

Bug Description

A model with a dynamic shape is saved successfully, but loading it throws an error.

To Reproduce

Run the following test script to reproduce the behavior:

import torch

from torch.export import Dim
import torch.nn as nn
import torch_tensorrt as torchtrt
import os
import tempfile

class bitwise_and(nn.Module):
    def forward(self, lhs_val, rhs_val):
        return torch.ops.aten.bitwise_and.Tensor(lhs_val, rhs_val)

dyn_dim = Dim("dyn_dim", min=3, max=6)
lhs = torch.randint(0, 2, (2, 4, 2), dtype=bool, device="cuda")
rhs = torch.randint(0, 2, (4, 2), dtype=bool, device="cuda")
inputs = (lhs, rhs)
torchtrt_inputs = [torchtrt.Input(shape=lhs.shape, dtype=torch.bool), 
                   torchtrt.Input(shape=rhs.shape, dtype=torch.bool)] 
mod = bitwise_and()


fx_mod=torch.export.export(mod, inputs, dynamic_shapes={"lhs_val": {1: dyn_dim}, "rhs_val": {0: dyn_dim}})
print(f"lan added fx_mod={fx_mod}")
trt_model = torchtrt.dynamo.compile(fx_mod, inputs=inputs, enable_precisions={torch.bool}, min_block_size=1)
trt_ep_path = os.path.join(tempfile.gettempdir(), "trt.ep")

lhs1 = torch.randint(0, 2, (2, 5, 2), dtype=bool, device="cuda")
rhs1 = torch.randint(0, 2, (5, 2), dtype=bool, device="cuda")
torchtrt.save(trt_model, trt_ep_path, inputs=[lhs1, rhs1])
print(f"lan added saved model to {trt_ep_path}")

loaded_trt_module = torch.export.load(trt_ep_path)
print(f"lan added load model from {trt_ep_path}") 

output = loaded_trt_module.module()(lhs1, rhs1)  # run the loaded ExportedProgram through .module()
print(f"lan added got {output=}")

Expected behavior

torch.export.load should load the saved model without raising an error.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0):
  • PyTorch Version (e.g. 1.0):
  • CPU Architecture:
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version:
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:
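
The environment fields above were left blank. A minimal sketch for collecting some of them is shown here; it is not part of the original report. It uses only the Python standard library, and the distribution names `torch` and `torch_tensorrt` are assumptions about how the packages were installed. PyTorch's own `python -m torch.utils.collect_env` reports more detail, including CUDA and GPU information.

```python
import platform
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version of a distribution, or a placeholder."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return "not installed"


def collect_environment():
    """Gather the template fields that can be read without importing torch."""
    return {
        # Distribution names below are assumptions; adjust them if your
        # install uses different package names.
        "Torch-TensorRT Version": installed_version("torch_tensorrt"),
        "PyTorch Version": installed_version("torch"),
        "CPU Architecture": platform.machine(),
        "OS": platform.system(),
        "Python version": platform.python_version(),
    }


if __name__ == "__main__":
    for field, value in collect_environment().items():
        print(f"{field}: {value}")
```

The remaining fields (install method, build command, CUDA version, GPU models) still need to be filled in by hand or taken from `torch.utils.collect_env`.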

Additional context

The full error output is shown below:

E0923 09:53:27.553340 1300597 site-packages/torch/fx/experimental/recording.py:298] failed while running evaluate_expr(*(s0 >= 0, True), **{'fx_node': False})
Traceback (most recent call last):
  File "/home/lanl/git/script/python/export_dynamic_shape_save_load_torchtrt_example.py", line 35, in <module>
    loaded_trt_module = torch.export.load(trt_ep_path)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/export/__init__.py", line 569, in load
    ep = deserialize(artifact, expected_opset_version)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_export/serde/serialize.py", line 2436, in deserialize
    ExportedProgramDeserializer(expected_opset_version)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_export/serde/serialize.py", line 2315, in deserialize
    GraphModuleDeserializer()
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_export/serde/serialize.py", line 1906, in deserialize
    self.deserialize_graph(serialized_graph_module.graph)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_export/serde/serialize.py", line 1612, in deserialize_graph
    meta_val = self.deserialize_tensor_meta(tensor_value)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_export/serde/serialize.py", line 1579, in deserialize_tensor_meta
    torch.empty_strided(
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 1238, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 1692, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 1339, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 2009, in _dispatch_impl
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_subclasses/fake_impls.py", line 176, in constructors
    r = func(*args, **new_kwargs)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_ops.py", line 716, in __call__
    return self._op(*args, **kwargs)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/fx/experimental/sym_node.py", line 479, in expect_size
    r = b.expect_true(file, line)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/fx/experimental/sym_node.py", line 465, in expect_true
    return self.guard_bool(file, line)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/fx/experimental/sym_node.py", line 449, in guard_bool
    r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/fx/experimental/recording.py", line 262, in wrapper
    return retlog(fn(*args, **kwargs))
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5207, in evaluate_expr
    return self._evaluate_expr(orig_expr, hint, fx_node, size_oblivious, forcing_spec=forcing_spec)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5283, in _evaluate_expr
    static_expr = self._maybe_evaluate_static(expr,
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/fx/experimental/symbolic_shapes.py", line 1604, in wrapper
    return fn_cache(self, *args, **kwargs)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/fx/experimental/symbolic_shapes.py", line 4552, in _maybe_evaluate_static
    vr = var_ranges[k]
KeyError: s0
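
The traceback ends in `vr = var_ranges[k]` while evaluating `s0 >= 0`, so the symbol `s0` appears to have no registered value range when the tensor metadata is deserialized. The toy model below illustrates that failure mode in plain Python. It is a simplified sketch, not PyTorch's implementation: `evaluate_lower_bound`, `exported_ranges`, and `loaded_ranges` are hypothetical names, and the idea that the range was dropped during save or load is an assumption, not a confirmed root cause.

```python
# Toy model of the lookup that fails in the traceback; not PyTorch's code.


def evaluate_lower_bound(symbol, bound, var_ranges):
    """Decide statically whether `symbol >= bound` holds over the symbol's range."""
    # Mirrors `vr = var_ranges[k]`: an unregistered symbol raises KeyError.
    lower, _upper = var_ranges[symbol]
    return lower >= bound


# Range as declared in the repro script: Dim("dyn_dim", min=3, max=6).
exported_ranges = {"s0": (3, 6)}
print(evaluate_lower_bound("s0", 0, exported_ranges))

# Hypothetical loader state in which the range for s0 was never restored.
loaded_ranges = {}
try:
    evaluate_lower_bound("s0", 0, loaded_ranges)
except KeyError as err:
    print(f"KeyError: {err}")
```

If this reading is right, checking whether the saved artifact carries a range constraint for `s0` would be a reasonable first debugging step.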

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
