Hi, thanks for creating this package. It helps us run Whisper with TensorRT.
However, we found that this package does not declare its dependencies (this is usually done with a requirements.txt),
so we ran it against the latest Whisper release (20240930). It does not seem to work with the SDPA (scaled_dot_product_attention) code path introduced in commit openai/whisper@27f9713.
Environment:
torch==2.5.0
torch2trt @ git+https://github.com/NVIDIA-AI-IOT/torch2trt@4e820ae31b4e35d59685935223b05b2e11d47b03
whisper-trt @ git+https://github.com/NVIDIA-AI-IOT/whisper_trt@3cea986723033151c845c87c83caa48e3ceff7ab
tensorrt==10.5.0
tensorrt-cu12==10.5.0
tensorrt-cu12-bindings==10.5.0
tensorrt-cu12-libs==10.5.0
onnx==1.17.0
onnx-graphsurgeon==0.5.2
onnxruntime==1.20.0
openai-whisper==20240930
Here is the stack trace. I hope it helps.
File "/usr/local/lib/python3.11/dist-packages/whisper_trt/model.py", line 312, in build
"text_decoder_engine": cls.build_text_decoder_engine().state_dict(),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/whisper_trt/model.py", line 206, in build_text_decoder_engine
engine = torch2trt.torch2trt(
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch2trt/torch2trt.py", line 608, in torch2trt
torch.onnx.export(
File "/usr/local/lib/python3.11/dist-packages/torch/onnx/__init__.py", line 375, in export
export(
File "/usr/local/lib/python3.11/dist-packages/torch/onnx/utils.py", line 502, in export
_export(
File "/usr/local/lib/python3.11/dist-packages/torch/onnx/utils.py", line 1564, in _export
graph, params_dict, torch_out = _model_to_graph(
^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/onnx/utils.py", line 1113, in _model_to_graph
graph, params, torch_out, module = _create_jit_graph(model, args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/onnx/utils.py", line 997, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/onnx/utils.py", line 904, in _trace_and_get_graph_from_model
trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/jit/_trace.py", line 1500, in _get_trace_graph
outs = ONNXTracedModule(
^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/jit/_trace.py", line 139, in forward
graph, out = torch._C._create_graph_by_tracing(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/jit/_trace.py", line 130, in wrapper
outs.append(self.inner(*trace_inputs))
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
result = self.forward(*input, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch2trt/flatten_module.py", line 34, in forward
output = self.module(*args)
^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
result = self.forward(*input, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/whisper_trt/model.py", line 94, in forward
x = block(x, xa, mask)
^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
result = self.forward(*input, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/whisper/model.py", line 167, in forward
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
result = self.forward(*input, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/whisper/model.py", line 111, in forward
wv, qk = self.qkv_attention(q, k, v, mask)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/whisper/model.py", line 124, in qkv_attention
a = scaled_dot_product_attention(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor
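A possible workaround, not verified on this setup: the trace shows the failure inside the SDPA branch of `qkv_attention`, where `is_causal` ends up as a Tensor while the model is being traced for ONNX export. If the installed Whisper exposes a class-level switch for that branch (assumed here to be a `use_sdpa` attribute on `whisper.model.MultiHeadAttention`; please check your installed version), turning it off during the engine build should route attention through the non-SDPA path. The sketch below uses a hypothetical stand-in class so it runs without Whisper installed; `MultiHeadAttentionStandIn`, `attention_path` and `sdpa_disabled` are illustrative names, not part of any library.

```python
from contextlib import contextmanager


class MultiHeadAttentionStandIn:
    """Hypothetical stand-in for whisper.model.MultiHeadAttention."""

    # Assumption: the real class has a class-level flag like this one.
    use_sdpa = True

    def attention_path(self) -> str:
        # Report which attention implementation would be used.
        return "sdpa" if type(self).use_sdpa else "manual"


@contextmanager
def sdpa_disabled(attention_cls):
    """Temporarily set attention_cls.use_sdpa to False, restoring it on exit."""
    previous = attention_cls.use_sdpa
    attention_cls.use_sdpa = False
    try:
        yield
    finally:
        # Restore the flag even if the wrapped code raises.
        attention_cls.use_sdpa = previous


layer = MultiHeadAttentionStandIn()
with sdpa_disabled(MultiHeadAttentionStandIn):
    # With the real packages, the engine build would run inside this block.
    path_during_build = layer.attention_path()
path_after_build = layer.attention_path()
```

With the real packages, the same pattern would wrap whatever call ends up in `build_text_decoder_engine`, passing the real attention class instead of the stand-in. Pinning `openai-whisper` to a release from before the SDPA commit is another option, though we have not confirmed which release that would be.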