Skip to content

Commit 8774dc9

Browse files
authored
fix: cherry pick PR of 3445 (#3457)
1 parent a674f31 commit 8774dc9

File tree

12 files changed

+246
-132
lines changed

12 files changed

+246
-132
lines changed

.github/workflows/build-test-linux.yml

+8-1
Original file line number | Diff line number | Diff line change
@@ -173,7 +173,13 @@ jobs:
173173
cd tests/py
174174
python -m pip install -r requirements.txt
175175
cd dynamo
176-
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/
176+
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_models.xml --ir dynamo models/test_models.py
177+
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_models_dynamic.xml --ir dynamo models/test_dyn_models.py
178+
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/engine_cache.xml --ir dynamo models/test_engine_cache.py
179+
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dtype_support.xml --ir dynamo models/test_dtype_support.py
180+
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/model_refit.xml --ir dynamo models/test_model_refit.py
181+
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/modelopt_models.xml --ir dynamo models/test_modelopt_models.py
182+
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/weight_stripped_engine.xml --ir dynamo models/test_weight_stripped_engine.py
177183
popd
178184
179185
tests-py-dynamo-serde:
@@ -206,6 +212,7 @@ jobs:
206212
cd dynamo
207213
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
208214
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/reexport_test_results.xml --ir dynamo models/test_reexport.py
215+
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_kwargs_serde_test_results.xml --ir dynamo models/test_export_kwargs_serde.py
209216
popd
210217
211218
tests-py-torch-compile-be:

py/torch_tensorrt/dynamo/_refit.py

+4-4
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,7 @@
4848
logger = logging.getLogger(__name__)
4949

5050

51-
@needs_refit
51+
@needs_refit # type: ignore
5252
def construct_refit_mapping(
5353
module: torch.fx.GraphModule,
5454
inputs: Sequence[Input],
@@ -81,7 +81,7 @@ def construct_refit_mapping(
8181
return interpreter.ctx.mapping
8282

8383

84-
@needs_refit
84+
@needs_refit # type: ignore
8585
def construct_refit_mapping_from_weight_name_map(
8686
weight_name_map: dict[Any, Any],
8787
state_dict: dict[Any, Any],
@@ -111,7 +111,7 @@ def construct_refit_mapping_from_weight_name_map(
111111
return engine_weight_map
112112

113113

114-
@needs_refit
114+
@needs_refit # type: ignore
115115
def _refit_single_trt_engine_with_gm(
116116
new_gm: torch.fx.GraphModule,
117117
old_engine: trt.ICudaEngine,
@@ -192,7 +192,7 @@ def _refit_single_trt_engine_with_gm(
192192
raise AssertionError("Refitting failed.")
193193

194194

195-
@needs_refit
195+
@needs_refit # type: ignore
196196
def refit_module_weights(
197197
compiled_module: torch.fx.GraphModule | ExportedProgram,
198198
new_weight_module: ExportedProgram,

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -893,7 +893,7 @@ def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray:
893893
else:
894894
constant_tensor = frozen_attr
895895

896-
return to_torch(constant_tensor)
896+
return to_torch(constant_tensor)
897897

898898
def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
899899
assert isinstance(target, str)

py/torch_tensorrt/dynamo/conversion/converter_utils.py

+96-97
Original file line number | Diff line number | Diff line change
@@ -358,31 +358,27 @@ def create_constant(
358358
shape = trt.Dims()
359359
else:
360360
shape = list(torch_value.shape)
361-
if torch_value is not None:
362-
if torch_value.dtype == torch.bfloat16:
363-
torch_value_fp32 = torch_value.to(torch.float32)
364-
numpy_value = torch_value_fp32.numpy()
365-
else:
366-
numpy_value = torch_value.numpy()
367361

368-
ctx.mapping[name + " CONSTANT"] = numpy_value.reshape(-1)
369-
constant = ctx.net.add_constant(
370-
shape,
371-
numpy_value,
372-
)
373-
constant.name = name
374-
if torch_value.dtype == torch.bfloat16:
375-
return cast_trt_tensor(
376-
ctx,
377-
constant.get_output(0),
378-
trt.DataType.BF16,
379-
name + "_bf16_cast",
380-
)
381-
return constant.get_output(0)
362+
if torch_value.dtype == torch.bfloat16:
363+
torch_value_fp32 = torch_value.to(torch.float32)
364+
numpy_value = torch_value_fp32.numpy()
382365
else:
383-
raise ValueError(
384-
f"Cannot convert tensor '{name}' to a TensorRT constant because its value is None."
366+
numpy_value = torch_value.numpy()
367+
368+
ctx.mapping[name + " CONSTANT"] = numpy_value.reshape(-1)
369+
constant = ctx.net.add_constant(
370+
shape,
371+
numpy_value,
372+
)
373+
constant.name = name
374+
if torch_value.dtype == torch.bfloat16:
375+
return cast_trt_tensor(
376+
ctx,
377+
constant.get_output(0),
378+
trt.DataType.BF16,
379+
name + "_bf16_cast",
385380
)
381+
return constant.get_output(0)
386382

387383

388384
def get_trt_tensor(
@@ -423,53 +419,6 @@ def get_trt_tensor(
423419
raise AssertionError(f"Cannot convert {input_val} to TRT constant")
424420

425421

426-
def to_torch(
427-
value: Optional[Union[torch.Tensor, np.ndarray, int, float, bool]],
428-
dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]] = None,
429-
) -> Optional[torch.Tensor]:
430-
"""
431-
Convert a Numpy array, or scalar to a PyTorch tensor and move it to CPU
432-
Args:
433-
value (Optional[Union[torch.Tensor, np.ndarray, int, float, bool]]):
434-
A PyTorch tensor, Numpy array, int, float, or bool
435-
dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
436-
If a dtype is given, we will convert the type of the given `value` to this dtype.
437-
Returns:
438-
A PyTorch tensor or None, if the input was None.
439-
"""
440-
441-
cpu_device = torch.device("cpu")
442-
torch_dtype = (
443-
_enums.dtype._from(dtype).to(torch.dtype, use_default=True) if dtype else None
444-
)
445-
446-
with unset_fake_temporarily():
447-
if value is None:
448-
return None
449-
450-
elif isinstance(value, torch.Tensor):
451-
output = value.to(cpu_device).contiguous()
452-
453-
elif isinstance(value, np.ndarray):
454-
output = torch.from_numpy(value).to(cpu_device).contiguous()
455-
456-
elif isinstance(value, int):
457-
output = torch.tensor([value], device=cpu_device, dtype=torch.int32)
458-
459-
elif isinstance(value, float):
460-
output = torch.tensor([value], device=cpu_device, dtype=torch.float32)
461-
462-
elif isinstance(value, bool):
463-
output = torch.tensor([value], device=cpu_device, dtype=torch.bool)
464-
465-
else:
466-
raise AssertionError(
467-
f"to_torch can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got an object of type: {type(value)}"
468-
)
469-
470-
return output.to(torch_dtype) if torch_dtype else output
471-
472-
473422
@overload
474423
def get_positive_dim(dim: int, dim_size: int) -> int: ...
475424

@@ -633,42 +582,92 @@ def to_numpy(
633582
Returns:
634583
A Numpy array or None, if the input was None.
635584
"""
636-
output = None
585+
with unset_fake_temporarily():
586+
output = None
637587

638-
if value is None or isinstance(value, np.ndarray):
639-
output = value
588+
if value is None or isinstance(value, np.ndarray):
589+
output = value
640590

641-
elif isinstance(value, torch.Tensor):
642-
if value.is_quantized:
643-
value = value.dequantize()
644-
elif value.dtype == torch.bfloat16:
645-
# TODO: Remove when numpy has a BF16 type
646-
_LOGGER.warning(
647-
"Requested a conversion of bfloat16 tensor from torch to numpy which isn't supported. Casting this tensor to FP32 precision currently. Please use to_torch() API for better data representation",
591+
elif isinstance(value, torch.Tensor):
592+
if value.is_quantized:
593+
value = value.dequantize()
594+
elif value.dtype == torch.bfloat16:
595+
# TODO: Remove when numpy has a BF16 type
596+
_LOGGER.warning(
597+
"Requested a conversion of bfloat16 tensor from torch to numpy which isn't supported. Casting this tensor to FP32 precision currently. Please use to_torch() API for better data representation",
598+
)
599+
value = value.to(torch.float)
600+
601+
output = value.cpu().detach().contiguous().numpy()
602+
603+
elif isinstance(value, int):
604+
output = np.array([value], dtype=np.int32)
605+
606+
elif isinstance(value, float):
607+
output = np.array([value], dtype=np.float32)
608+
609+
elif isinstance(value, bool):
610+
output = np.array([value], dtype=np.bool_)
611+
612+
if isinstance(output, np.ndarray) or output is None:
613+
return (
614+
output
615+
if (dtype is None or output is None)
616+
else output.astype(
617+
_enums.dtype._from(dtype).to(np.dtype, use_default=True)
618+
)
619+
)
620+
else:
621+
raise AssertionError(
622+
f"to_numpy can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got: {value}"
648623
)
649-
value = value.to(torch.float)
650624

651-
output = value.cpu().detach().contiguous().numpy()
652625

653-
elif isinstance(value, int):
654-
output = np.array([value], dtype=np.int32)
626+
def to_torch(
627+
value: Optional[Union[torch.Tensor, np.ndarray, int, float, bool]],
628+
dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]] = None,
629+
) -> Optional[torch.Tensor]:
630+
"""
631+
Convert a Numpy array, or scalar to a PyTorch tensor and move it to CPU
632+
Args:
633+
value (Optional[Union[torch.Tensor, np.ndarray, int, float, bool]]):
634+
A PyTorch tensor, Numpy array, int, float, or bool
635+
dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
636+
If a dtype is given, we will convert the type of the given `value` to this dtype.
637+
Returns:
638+
A PyTorch tensor or None, if the input was None.
639+
"""
655640

656-
elif isinstance(value, float):
657-
output = np.array([value], dtype=np.float32)
641+
cpu_device = torch.device("cpu")
642+
torch_dtype = (
643+
_enums.dtype._from(dtype).to(torch.dtype, use_default=True) if dtype else None
644+
)
658645

659-
elif isinstance(value, bool):
660-
output = np.array([value], dtype=np.bool_)
646+
with unset_fake_temporarily():
647+
if value is None:
648+
return None
661649

662-
if isinstance(output, np.ndarray) or output is None:
663-
return (
664-
output
665-
if (dtype is None or output is None)
666-
else output.astype(_enums.dtype._from(dtype).to(np.dtype, use_default=True))
667-
)
668-
else:
669-
raise AssertionError(
670-
f"to_numpy can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got: {value}"
671-
)
650+
elif isinstance(value, torch.Tensor):
651+
output = value.to(cpu_device).contiguous()
652+
653+
elif isinstance(value, np.ndarray):
654+
output = torch.from_numpy(value).to(cpu_device).contiguous()
655+
656+
elif isinstance(value, int):
657+
output = torch.tensor([value], device=cpu_device, dtype=torch.int32)
658+
659+
elif isinstance(value, float):
660+
output = torch.tensor([value], device=cpu_device, dtype=torch.float32)
661+
662+
elif isinstance(value, bool):
663+
output = torch.tensor([value], device=cpu_device, dtype=torch.bool)
664+
665+
else:
666+
raise AssertionError(
667+
f"to_torch can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got an object of type: {type(value)}"
668+
)
669+
670+
return output.to(torch_dtype) if torch_dtype else output
672671

673672

674673
def flatten_dims(

tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py

+18-16
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,6 @@
1-
import pytest
2-
3-
flashinfer = pytest.importorskip("flashinfer")
41
import unittest
52

3+
import pytest
64
import torch
75
import torch.nn as nn
86
import torch_tensorrt
@@ -12,25 +10,29 @@
1210

1311
from ..conversion.harness import DispatchTestCase
1412

13+
# Toggle this flag to enable/disable flashinfer-based overrides
14+
enable_flashinfer: bool = False
15+
if enable_flashinfer:
16+
import flashinfer
1517

16-
@torch.library.custom_op("flashinfer::rmsnorm", mutates_args=()) # type: ignore[misc]
17-
def flashinfer_rmsnorm(
18-
input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
19-
) -> torch.Tensor:
20-
return flashinfer.norm.rmsnorm(input, weight)
18+
@torch.library.custom_op("flashinfer::rmsnorm", mutates_args=()) # type: ignore[misc]
19+
def flashinfer_rmsnorm(
20+
input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
21+
) -> torch.Tensor:
22+
return flashinfer.norm.rmsnorm(input, weight)
2123

24+
@torch.library.register_fake("flashinfer::rmsnorm")
25+
def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tensor:
26+
return input
2227

23-
@torch.library.register_fake("flashinfer::rmsnorm")
24-
def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tensor:
25-
return input
28+
torch_tensorrt.dynamo.conversion.plugins.custom_op(
29+
"flashinfer::rmsnorm", supports_dynamic_shapes=True
30+
)
2631

2732

28-
torch_tensorrt.dynamo.conversion.plugins.custom_op(
29-
"flashinfer::rmsnorm", supports_dynamic_shapes=True
33+
@unittest.skip(
34+
"Flashinfer RMSNorm test is disabled due to error: SM75 support not available"
3035
)
31-
32-
33-
@unittest.skip("Not Available")
3436
class TestAutomaticPlugin(DispatchTestCase):
3537
@parameterized.expand(
3638
[

tests/py/dynamo/backend/test_backend_compiler.py

+1-7
Original file line number | Diff line number | Diff line change
@@ -2,11 +2,10 @@
22
from copy import deepcopy
33

44
import torch
5+
import torch_tensorrt
56
from torch.testing._internal.common_utils import TestCase, run_tests
67
from torch_tensorrt.dynamo.partitioning import fast_partition
78

8-
import torch_tensorrt
9-
109
from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
1110

1211

@@ -51,7 +50,6 @@ def forward(self, x, y):
5150
pass_through_build_failures=True,
5251
torch_executed_ops={"torch.ops.aten.add.Tensor"},
5352
use_python_runtime=False,
54-
debug=True,
5553
)
5654
optimized_model_results = optimized_model(*inputs).detach().cpu()
5755
torch_model_results = fx_graph(*inputs).detach().cpu()
@@ -132,7 +130,6 @@ def forward(self, x, y):
132130
pass_through_build_failures=True,
133131
torch_executed_ops={"torch.ops.aten.add.Tensor"},
134132
use_python_runtime=False,
135-
debug=True,
136133
)
137134
optimized_model_results = optimized_model(*inputs).detach().cpu()
138135
torch_model_results = model(*inputs).detach().cpu()
@@ -177,7 +174,6 @@ def forward(self, x, y):
177174
optimization_level=4,
178175
version_compatible=True,
179176
max_aux_streams=5,
180-
debug=True,
181177
)
182178
optimized_model_results = optimized_model(*inputs).detach().cpu()
183179
torch_model_results = fx_graph(*inputs).detach().cpu()
@@ -225,7 +221,6 @@ def forward(self, x, y):
225221
min_block_size=1,
226222
pass_through_build_failures=True,
227223
truncate_double=True,
228-
debug=True,
229224
)
230225
optimized_model_results = optimized_model(*inputs).detach().cpu()
231226
torch_model_results = fx_graph(*inputs).detach().cpu()
@@ -298,7 +293,6 @@ def forward(self, x, y):
298293
min_block_size=1,
299294
pass_through_build_failures=True,
300295
truncate_double=False,
301-
debug=True,
302296
torch_executed_ops={"torch.ops.aten.add.Tensor"},
303297
)
304298
optimized_model_results = optimized_model(*inputs).detach().cpu()

0 commit comments

Comments
 (0)