Open
Description
Bug Description
A model exported with dynamic shapes is saved successfully, but loading it back with `torch.export.load` throws an error.
To Reproduce
Steps to reproduce the behavior:
test code to reproduce:
import torch
from torch.export import Dim
import torch.nn as nn
import torch_tensorrt as torchtrt
import os
import tempfile
class bitwise_and(nn.Module):
    """Minimal module that applies an element-wise bitwise AND to two tensors."""

    def forward(self, lhs_val, rhs_val):
        """Return the bitwise AND of ``lhs_val`` and ``rhs_val``.

        Calls the ATen overload ``aten.bitwise_and.Tensor`` directly rather
        than using the ``&`` operator.

        Args:
            lhs_val: Left-hand operand tensor.
            rhs_val: Right-hand operand tensor.

        Returns:
            The tensor produced by ``torch.ops.aten.bitwise_and.Tensor``.
        """
        return torch.ops.aten.bitwise_and.Tensor(lhs_val, rhs_val)
# Reproduction flow: export with a dynamic dimension -> compile with
# Torch-TensorRT -> save to disk -> load with torch.export.load -> run.

# Symbolic dimension named "dyn_dim", constrained to the range [3, 6].
dyn_dim = Dim("dyn_dim", min=3, max=6)
# Example inputs: random boolean CUDA tensors of shape (2, 4, 2) and (4, 2).
lhs = torch.randint(0, 2, (2, 4, 2), dtype=bool, device="cuda")
rhs = torch.randint(0, 2, (4, 2), dtype=bool, device="cuda")
inputs = (lhs, rhs)
# Input specs built from the static shapes of the example tensors.
# NOTE(review): torchtrt_inputs is never referenced again in this script;
# the compile call below receives the concrete tensors in `inputs` instead.
torchtrt_inputs = [torchtrt.Input(shape=lhs.shape, dtype=torch.bool),
torchtrt.Input(shape=rhs.shape, dtype=torch.bool)]
mod = bitwise_and()
# Export with dim 1 of lhs_val and dim 0 of rhs_val both bound to dyn_dim.
fx_mod=torch.export.export(mod, inputs, dynamic_shapes={"lhs_val": {1: dyn_dim}, "rhs_val": {0: dyn_dim}})
print(f"lan added fx_mod={fx_mod}")
# Compile the exported program with Torch-TensorRT.
# NOTE(review): the keyword here is spelled `enable_precisions`; the
# Torch-TensorRT docs appear to name it `enabled_precisions` — confirm
# whether this argument is actually honored.
trt_model = torchtrt.dynamo.compile(fx_mod, inputs=inputs, enable_precisions={torch.bool}, min_block_size=1)
# The serialized program is written to "trt.ep" in the system temp directory.
trt_ep_path = os.path.join(tempfile.gettempdir(), "trt.ep")
# Second input set with the dynamic dimension equal to 5 (inside [3, 6]).
lhs1 = torch.randint(0, 2, (2, 5, 2), dtype=bool, device="cuda")
rhs1 = torch.randint(0, 2, (5, 2), dtype=bool, device="cuda")
# Save the compiled model, passing the second input set as `inputs`.
torchtrt.save(trt_model, trt_ep_path, inputs=[lhs1, rhs1])
print(f"lan added saved model to {trt_ep_path}")
# Per the traceback in this report, this load call is where the failure
# occurs (KeyError: s0 raised while deserializing tensor metadata).
loaded_trt_module = torch.export.load(trt_ep_path)
print(f"lan added load model from {trt_ep_path}")
# NOTE(review): torch.export.load is documented to return an ExportedProgram;
# some PyTorch versions require calling `.module()` before invoking it —
# confirm once the load itself succeeds.
output = loaded_trt_module(lhs1, rhs1)
print(f"lan added got {output=}")
Expected behavior
`torch.export.load` should load the saved model without raising an error.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0):
- PyTorch Version (e.g. 1.0):
- CPU Architecture:
- OS (e.g., Linux):
- How you installed PyTorch (`conda`, `pip`, `libtorch`, source):
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version:
- CUDA version:
- GPU models and configuration:
- Any other relevant information:
Additional context
The detailed error output is shown below:
E0923 09:53:27.553340 1300597 site-packages/torch/fx/experimental/recording.py:298] failed while running evaluate_expr(*(s0 >= 0, True), **{'fx_node': False})
Traceback (most recent call last):
File "/home/lanl/git/script/python/export_dynamic_shape_save_load_torchtrt_example.py", line 35, in <module>
loaded_trt_module = torch.export.load(trt_ep_path)
File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/export/__init__.py", line 569, in load
ep = deserialize(artifact, expected_opset_version)
File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_export/serde/serialize.py", line 2436, in deserialize
ExportedProgramDeserializer(expected_opset_version)
File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_export/serde/serialize.py", line 2315, in deserialize
GraphModuleDeserializer()
File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_export/serde/serialize.py", line 1906, in deserialize
self.deserialize_graph(serialized_graph_module.graph)
File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_export/serde/serialize.py", line 1612, in deserialize_graph
meta_val = self.deserialize_tensor_meta(tensor_value)
File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_export/serde/serialize.py", line 1579, in deserialize_tensor_meta
torch.empty_strided(
File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/utils/_stats.py", line 21, in wrapper
return fn(*args, **kwargs)
File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 1238, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 1692, in dispatch
return self._cached_dispatch_impl(func, types, args, kwargs)
File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 1339, in _cached_dispatch_impl
output = self._dispatch_impl(func, types, args, kwargs)
File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 2009, in _dispatch_impl
op_impl_out = op_impl(self, func, *args, **kwargs)
File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_subclasses/fake_impls.py", line 176, in constructors
r = func(*args, **new_kwargs)
File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_ops.py", line 716, in __call__
return self._op(*args, **kwargs)
File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/fx/experimental/sym_node.py", line 479, in expect_size
r = b.expect_true(file, line)
File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/fx/experimental/sym_node.py", line 465, in expect_true
return self.guard_bool(file, line)
File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/fx/experimental/sym_node.py", line 449, in guard_bool
r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/fx/experimental/recording.py", line 262, in wrapper
return retlog(fn(*args, **kwargs))
File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5207, in evaluate_expr
return self._evaluate_expr(orig_expr, hint, fx_node, size_oblivious, forcing_spec=forcing_spec)
File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5283, in _evaluate_expr
static_expr = self._maybe_evaluate_static(expr,
File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/fx/experimental/symbolic_shapes.py", line 1604, in wrapper
return fn_cache(self, *args, **kwargs)
File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/fx/experimental/symbolic_shapes.py", line 4552, in _maybe_evaluate_static
vr = var_ranges[k]
KeyError: s0