
Commit 88ee123

fix: Repair integer inputs in dynamic shape cases (#2891)
1 parent 483e646 commit 88ee123

File tree

5 files changed: +159 -20 lines changed

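The failure mode being repaired: when a dimension is marked dynamic and the graph is partitioned between Torch and TensorRT, the runtime modules can receive plain Python integers (symbolic sizes coming out of a preceding Torch subgraph) alongside tensors, and the previous input handling rejected them. A minimal sketch of a setup that exercises dynamic shapes with the TensorRT backend; the model, shapes, and bounds here are illustrative, not taken from this commit:

import torch
import torch_tensorrt  # noqa: F401  (registers the "tensorrt" backend for torch.compile)

# Illustrative module; any partitioned graph that feeds symbolic sizes into a
# TensorRT submodule exercises the same input-handling path.
model = torch.nn.Linear(16, 16).eval().cuda()
x = torch.randn(4, 16, device="cuda")
torch._dynamo.mark_dynamic(x, 0, min=2, max=64)  # dimension 0 may vary at runtime
compiled = torch.compile(model, backend="tensorrt", dynamic=None)
out = compiled(x)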

.github/workflows/build-test-linux.yml

+1

@@ -145,6 +145,7 @@ jobs:
 cd tests/py/dynamo
 ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py
 ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py
+${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_generation_compile.xml models/test_hf_generate_dynamic.py
 popd

 tests-py-dynamo-serde:

.github/workflows/build-test-windows.yml

+1

@@ -99,6 +99,7 @@ jobs:
 cd tests/py/dynamo
 ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py
 ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py
+${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_generation_compile.xml models/test_hf_generate_dynamic.py
 popd

 tests-py-torch-compile-be:

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

+14 -5

@@ -128,6 +128,11 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
         self.context = self.engine.create_execution_context()

     def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+        # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
+        contiguous_inputs: List[torch.Tensor] = [
+            (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
+            for i in inputs
+        ]
         with (
             torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
             if self.profiling_enabled
@@ -174,7 +179,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 self.input_names
             ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."

-            contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
             for i, input_name in enumerate(self.input_names):
                 if not contiguous_inputs[i].is_cuda:
                     logger.warning(
@@ -193,12 +197,17 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     contiguous_inputs[i].dtype == self.input_dtypes[i]
                 ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."

+                # For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers
+                # as per TensorRT requirements
                 if self.engine.is_shape_inference_io(input_name):
-                    # Shape tensor inputs are casted to int32 explicitly.
-                    # Refer to https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435
-                    inputs_cpu = contiguous_inputs[i].cpu().to(torch.int32)
+                    # Shape tensor inputs are casted to int64 explicitly
+                    # Currently Torch CPU pointers are not working; numpy pointers are used instead
+                    # to refer to underlying memory
+                    inputs_cpu = (
+                        contiguous_inputs[i].cpu().to(torch.int64).numpy().copy()
+                    )
                     self.context.set_tensor_address(
-                        input_name, inputs_cpu.data_ptr()
+                        input_name, inputs_cpu.ctypes.data
                     )
                 else:
                     self.context.set_input_shape(
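In the shape-tensor branch above, the input value is staged on the host: moved to CPU, cast to int64, copied into a NumPy array, and that array's raw address is what gets registered with the execution context via set_tensor_address(). The pointer handling in isolation looks roughly like the following sketch; the tensor value and variable names are illustrative, not from the commit:

import torch

shape_value = torch.tensor([1, 8, 1024])                      # illustrative shape-tensor input
host_copy = shape_value.cpu().to(torch.int64).numpy().copy()  # host-resident int64 buffer; keeping this reference keeps the memory alive
host_ptr = host_copy.ctypes.data                              # raw CPU address of the buffer, as handed to set_tensor_address()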

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

+11 -15

@@ -146,7 +146,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         """Implementation of the forward pass for a TensorRT engine

         Args:
-            *inputs (torch.Tensor): Inputs to the forward function, must all be ``torch.Tensor``
+            *inputs (Union[torch.Tensor, int]): Inputs to the forward function

         Returns:
             torch.Tensor or Tuple(torch.Tensor): Result of the engine computation
@@ -158,22 +158,18 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             self.input_binding_names
         ), f"Wrong number of inputs, expected {len(self.input_binding_names)} got {len(inputs)}."

-        types: List[bool] = [issubclass(type(i), torch.Tensor) for i in inputs]
-
-        try:
-            assert all(types)
-        except AssertionError:
-
-            def is_non_tensor(i: Tuple[Any, bool]) -> bool:
-                return not i[1]
-
-            non_tensors = [i[0] for i in filter(is_non_tensor, zip(inputs, types))]
-            raise RuntimeError(
-                f"TorchTensorRTModule expects a flattened list of tensors as input, found non tensors: {non_tensors}"
-            )
+        # If the inputs are not Torch Tensors, which can occur in scenarios such as shape tensors
+        # which are outputs of a preceding Torch subgraph (where the Dynamic input may be an integer)
+        # directly cast the input to a Torch Tensor.
+        #
+        # This also avoids the need for type-checking inputs, since they are now explicitly casted to Torch tensors
+        input_tensors: List[torch.Tensor] = [
+            (i if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
+            for i in inputs
+        ]

         outputs: List[torch.Tensor] = torch.ops.tensorrt.execute_engine(
-            list(inputs), self.engine
+            list(input_tensors), self.engine
         )

         if len(outputs) == 1:
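The replacement above drops the tensor-only type check and instead promotes any non-tensor input to a CUDA tensor before handing the flattened list to the engine. Its effect on a mixed argument list, sketched with illustrative values:

import torch

inputs = (torch.randn(2, 3, device="cuda"), 128)  # a data tensor plus a symbolic integer
input_tensors = [
    i if isinstance(i, torch.Tensor) else torch.tensor(i).cuda()
    for i in inputs
]
# input_tensors now holds only torch.Tensor objects, ready for execute_engine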
tests/py/dynamo/models/test_hf_generate_dynamic.py

+132

@@ -0,0 +1,132 @@
+import pytest
+import torch
+import torch_tensorrt
+from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
+from transformers.generation.stopping_criteria import (
+    EosTokenCriteria,
+    MaxLengthCriteria,
+)
+
+
+@pytest.mark.unit
+def test_dynamic_generation_python_rt():
+    """
+    Tests HuggingFace Generate Code with dynamic shapes
+    Code Credit: @peri044
+    """
+    # Define tokenizer and model
+    tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    model = (
+        AutoModelForCausalLM.from_pretrained(
+            "gpt2", pad_token_id=tokenizer.eos_token_id, use_cache=False
+        )
+        .eval()
+        .to("cuda")
+    )
+
+    # Input prompt
+    model_inputs = tokenizer(("Repeat " * 128)[:-1], return_tensors="pt").to("cuda")
+    input_ids = model_inputs["input_ids"]
+    max_tokens = 40
+
+    # Pyt model outputs
+    greedy_output = model.generate(**model_inputs, max_new_tokens=max_tokens)
+    print(
+        "Pytorch model generated text: ",
+        tokenizer.decode(greedy_output[0], skip_special_tokens=True),
+    )
+
+    # Compile Torch-TRT model
+    torch._dynamo.mark_dynamic(input_ids, 1, min=2, max=1023)
+    model.forward = torch.compile(
+        model.forward,
+        backend="tensorrt",
+        dynamic=None,
+        options={
+            "enabled_precisions": {torch.float},
+            "torch_executed_ops": {"torch.ops.aten.slice.Tensor"},
+            "use_python_runtime": True,
+            "optimization_level": 0,
+            "min_block_size": 29,
+        },
+    )
+
+    # Auto-regressive generation loop for greedy search
+    stopping_criteria = StoppingCriteriaList(
+        [
+            MaxLengthCriteria(max_length=max_tokens),
+            EosTokenCriteria(eos_token_id=tokenizer.eos_token_id),
+        ]
+    )
+    while True:
+        trt_outputs = model(input_ids)
+        logits = trt_outputs.logits
+        next_token_logits = logits[:, -1, :]
+        next_tokens = torch.argmax(next_token_logits, dim=-1)
+        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+        if stopping_criteria(input_ids, logits).item():
+            break
+
+    # TODO: Add test for correctness
+
+
+@pytest.mark.unit
+def test_dynamic_generation_cpp_rt():
+    """
+    Tests HuggingFace Generate Code with dynamic shapes
+    Code Credit: @peri044
+    """
+    # Define tokenizer and model
+    tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    model = (
+        AutoModelForCausalLM.from_pretrained(
+            "gpt2", pad_token_id=tokenizer.eos_token_id, use_cache=False
+        )
+        .eval()
+        .to("cuda")
+    )
+
+    # Input prompt
+    model_inputs = tokenizer(("Repeat " * 128)[:-1], return_tensors="pt").to("cuda")
+    input_ids = model_inputs["input_ids"]
+    max_tokens = 40
+
+    # Pyt model outputs
+    greedy_output = model.generate(**model_inputs, max_new_tokens=max_tokens)
+    print(
+        "Pytorch model generated text: ",
+        tokenizer.decode(greedy_output[0], skip_special_tokens=True),
+    )
+
+    # Compile Torch-TRT model
+    torch._dynamo.mark_dynamic(input_ids, 1, min=2, max=1023)
+    model.forward = torch.compile(
+        model.forward,
+        backend="tensorrt",
+        dynamic=None,
+        options={
+            "enabled_precisions": {torch.float},
+            "torch_executed_ops": {"torch.ops.aten.slice.Tensor"},
+            "use_python_runtime": False,
+            "optimization_level": 0,
+            "min_block_size": 29,
+        },
+    )
+
+    # Auto-regressive generation loop for greedy search
+    stopping_criteria = StoppingCriteriaList(
+        [
+            MaxLengthCriteria(max_length=max_tokens),
+            EosTokenCriteria(eos_token_id=tokenizer.eos_token_id),
+        ]
+    )
+    while True:
+        trt_outputs = model(input_ids)
+        logits = trt_outputs.logits
+        next_token_logits = logits[:, -1, :]
+        next_tokens = torch.argmax(next_token_logits, dim=-1)
+        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+        if stopping_criteria(input_ids, logits).item():
+            break
+
+    # TODO: Add test for correctness