
Commit 8cf5d71

TensorRT-LLM import fix; expose aot_joint_export as an explicit setting in dynamo.compile

- Add TRT-LLM installation utilities and test cases
- Add the option in _compiler.py
- Rework the TRT-LLM loading tool: remove install_wget, install_unzip, install_mpi
- Improve error logging in the TRT-LLM installation tool
- Move load_tensorrt_llm to dynamo/utils.py
- Correct a misprint in the TRT-LLM load path
- Use a Python library for the download to make it platform agnostic
- Update the .dll file path for Windows
- Fix a non-critical lint error
- Include the version in versions.txt
1 parent 4cbd28f commit 8cf5d71

File tree

- dev_dep_versions.yml
- py/torch_tensorrt/dynamo/_compiler.py
- py/torch_tensorrt/dynamo/conversion/converter_utils.py
- py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
- py/torch_tensorrt/dynamo/utils.py
- setup.py

6 files changed (+149, -67 lines)


dev_dep_versions.yml

Lines changed: 1 addition & 0 deletions
@@ -1,2 +1,3 @@
 __cuda_version__: "12.8"
 __tensorrt_version__: "10.9.0"
+__tensorrt_llm_version__: "0.17.0.post1"

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 12 additions & 0 deletions
@@ -100,6 +100,7 @@ def cross_compile_for_windows(
     tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
     offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
+    use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -175,6 +176,7 @@ def cross_compile_for_windows(
         enable_weight_streaming (bool): Enable weight streaming.
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
+        use_distributed_mode_trace (bool): Use aot_autograd to trace the graph; enable this when DTensors or distributed tensors are present in the model.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -334,6 +336,7 @@ def cross_compile_for_windows(
         "enable_weight_streaming": enable_weight_streaming,
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
+        "use_distributed_mode_trace": use_distributed_mode_trace,
     }
 
     # disable the following settings is not supported for cross compilation for windows feature
@@ -435,6 +438,7 @@ def compile(
     tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
     offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
+    use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -512,7 +516,8 @@ def compile(
         enable_weight_streaming (bool): Enable weight streaming.
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
         offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage.
+        use_distributed_mode_trace (bool): Use aot_autograd to trace the graph; enable this when DTensors or distributed tensors are present in the model.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -681,6 +689,7 @@ def compile(
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
         "offload_module_to_cpu": offload_module_to_cpu,
+        "use_distributed_mode_trace": use_distributed_mode_trace,
     }
 
     settings = CompilationSettings(**compilation_options)
@@ -986,6 +995,7 @@ def convert_exported_program_to_serialized_trt_engine(
     tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
     offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
+    use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
     **kwargs: Any,
 ) -> bytes:
     """Convert an ExportedProgram to a serialized TensorRT engine
@@ -1051,6 +1061,7 @@ def convert_exported_program_to_serialized_trt_engine(
         enable_weight_streaming (bool): Enable weight streaming.
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
+        use_distributed_mode_trace (bool): Use aot_autograd to trace the graph; enable this when DTensors or distributed tensors are present in the model.
     Returns:
         bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
     """
@@ -1170,6 +1181,7 @@ def convert_exported_program_to_serialized_trt_engine(
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
         "offload_module_to_cpu": offload_module_to_cpu,
+        "use_distributed_mode_trace": use_distributed_mode_trace,
     }
 
     settings = CompilationSettings(**compilation_options)
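
For context, a minimal sketch of how the new explicit setting might be passed to dynamo.compile (the toy module, inputs, and export call are illustrative assumptions; only the use_distributed_mode_trace keyword comes from this change):

import torch
import torch_tensorrt

# Illustrative toy module and input; any ExportedProgram would do.
model = torch.nn.Linear(8, 4).eval().cuda()
inputs = [torch.randn(2, 8).cuda()]
exported_program = torch.export.export(model, tuple(inputs))

# use_distributed_mode_trace routes tracing through aot_autograd, intended for
# models that contain DTensors / distributed tensors (per the docstring above).
trt_module = torch_tensorrt.dynamo.compile(
    exported_program,
    inputs=inputs,
    use_distributed_mode_trace=True,
)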

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 2 additions & 65 deletions
@@ -1,8 +1,6 @@
 import collections
-import ctypes
 import functools
 import logging
-import os
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload
 
 import numpy as np
@@ -13,6 +11,7 @@
 from torch.fx.node import Argument, Target
 from torch.fx.passes.shape_prop import TensorMetadata
 from torch_tensorrt import _enums
+from torch_tensorrt._enums import Platform
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -1012,69 +1011,6 @@ def args_bounds_check(
     return args[i] if len(args) > i and args[i] is not None else replacement
 
 
-def load_tensorrt_llm() -> bool:
-    """
-    Attempts to load the TensorRT-LLM plugin and initialize it.
-
-    Returns:
-        bool: True if the plugin was successfully loaded and initialized, False otherwise.
-    """
-    try:
-        import tensorrt_llm as trt_llm  # noqa: F401
-
-        _LOGGER.info("TensorRT-LLM successfully imported")
-        return True
-    except (ImportError, AssertionError) as e_import_error:
-        # Check for environment variable for the plugin library path
-        plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
-        if not plugin_lib_path:
-            _LOGGER.warning(
-                "TensorRT-LLM is not installed. Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops",
-            )
-            return False
-
-        _LOGGER.info(f"TensorRT-LLM Plugin lib path found: {plugin_lib_path}")
-        try:
-            # Load the shared library
-            handle = ctypes.CDLL(plugin_lib_path)
-            _LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}")
-        except OSError as e_os_error:
-            _LOGGER.error(
-                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}"
-                f"Ensure the path is correct and the library is compatible",
-                exc_info=e_os_error,
-            )
-            return False
-
-        try:
-            # Configure plugin initialization arguments
-            handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
-            handle.initTrtLlmPlugins.restype = ctypes.c_bool
-        except AttributeError as e_plugin_unavailable:
-            _LOGGER.warning(
-                "Unable to initialize the TensorRT-LLM plugin library",
-                exc_info=e_plugin_unavailable,
-            )
-            return False
-
-        try:
-            # Initialize the plugin
-            TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
-            if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")):
-                _LOGGER.info("TensorRT-LLM plugin successfully initialized")
-                return True
-            else:
-                _LOGGER.warning("TensorRT-LLM plugin library failed in initialization")
-                return False
-        except Exception as e_initialization_error:
-            _LOGGER.warning(
-                "Exception occurred during TensorRT-LLM plugin library initialization",
-                exc_info=e_initialization_error,
-            )
-            return False
-    return False
-
-
 def promote_trt_tensors_to_same_dtype(
     ctx: ConversionContext, lhs: TRTTensor, rhs: TRTTensor, name_prefix: str
 ) -> tuple[TRTTensor, TRTTensor]:
@@ -1112,3 +1048,4 @@ def promote_trt_tensors_to_same_dtype(
     rhs_cast = cast_trt_tensor(ctx, rhs, promoted_dtype, f"{name_prefix}rhs_cast")
 
     return lhs_cast, rhs_cast
+
py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py

Lines changed: 1 addition & 1 deletion
@@ -11,11 +11,11 @@
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     dynamo_tensorrt_converter,
 )
-from torch_tensorrt.dynamo.conversion.converter_utils import load_tensorrt_llm
 from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
     tensorrt_fused_nccl_all_gather_op,
     tensorrt_fused_nccl_reduce_scatter_op,
 )
+from torch_tensorrt.dynamo.utils import load_tensorrt_llm
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
py/torch_tensorrt/dynamo/utils.py

Lines changed: 129 additions & 1 deletion
@@ -1,7 +1,10 @@
 from __future__ import annotations
 
+import ctypes
 import gc
 import logging
+import os
+import urllib.request
 import warnings
 from dataclasses import fields, replace
 from enum import Enum
@@ -14,9 +17,10 @@
 from torch._subclasses.fake_tensor import FakeTensor
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch_tensorrt._Device import Device
-from torch_tensorrt._enums import dtype
+from torch_tensorrt._enums import Platform, dtype
 from torch_tensorrt._features import ENABLED_FEATURES
 from torch_tensorrt._Input import Input
+from torch_tensorrt._version import __tensorrt_llm_version__
 from torch_tensorrt.dynamo import _defaults
 from torch_tensorrt.dynamo._defaults import default_device
 from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
@@ -817,3 +821,127 @@ def is_tegra_platform() -> bool:
     if torch.cuda.get_device_capability() in [(8, 7), (7, 2)]:
         return True
     return False
+
+
+def download_plugin_lib_path(py_version: str, platform: str) -> str:
+    plugin_lib_path = None
+
+    # Downloading TRT-LLM lib
+    base_url = "https://pypi.nvidia.com/tensorrt-llm/"
+    file_name = f"tensorrt_llm-{__tensorrt_llm_version__}-{py_version}-{py_version}-{platform}.whl"
+    download_url = base_url + file_name
+    if not (os.path.exists(file_name)):
+        try:
+            logger.debug(f"Downloading {download_url} ...")
+            urllib.request.urlretrieve(download_url, file_name)
+            logger.debug("Download succeeded and TRT-LLM wheel is now present")
+        except urllib.error.HTTPError as e:
+            logger.error(
+                f"HTTP error {e.code} when trying to download {download_url}: {e.reason}"
+            )
+        except urllib.error.URLError as e:
+            logger.error(
+                f"URL error when trying to download {download_url}: {e.reason}"
+            )
+        except OSError as e:
+            logger.error(f"Local file write error: {e}")
+
+    # Proceeding with the unzip of the wheel file
+    # This will exist if the filename was already downloaded
+    if "linux" in platform:
+        lib_filename = "libnvinfer_plugin_tensorrt_llm.so"
+    else:
+        lib_filename = "libnvinfer_plugin_tensorrt_llm.dll"
+    plugin_lib_path = os.path.join("./tensorrt_llm/libs", lib_filename)
+    if os.path.exists(plugin_lib_path):
+        return plugin_lib_path
+    try:
+        import zipfile
+    except ImportError as e:
+        raise ImportError(
+            "zipfile module is required but not found. Please install zipfile"
+        )
+    with zipfile.ZipFile(file_name, "r") as zip_ref:
+        zip_ref.extractall(".")  # Extract to a folder named 'tensorrt_llm'
+    plugin_lib_path = "./tensorrt_llm/libs/" + lib_filename
+    return plugin_lib_path
+
+
+def load_tensorrt_llm() -> bool:
+    """
+    Attempts to load the TensorRT-LLM plugin and initialize it.
+    Either the environment variable TRTLLM_PLUGINS_PATH can specify the plugin library path,
+    or the user can set USE_TRTLLM_PLUGINS to one of (1, true, yes, on) to download the TRT-LLM distribution and load it.
+
+    Returns:
+        bool: True if the plugin was successfully loaded and initialized, False otherwise.
+    """
+    plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
+    if not plugin_lib_path:
+        # this option can be used by the user if TRTLLM_PLUGINS_PATH is not set
+        use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
+            "1",
+            "true",
+            "yes",
+            "on",
+        )
+        if not use_trtllm_plugin:
+            logger.warning(
+                "Neither TRTLLM_PLUGINS_PATH is set nor is USE_TRTLLM_PLUGINS enabled to download the shared library. Please set one of the two to use TRT-LLM libraries in torchTRT"
+            )
+            return False
+        else:
+            # this is used as the default py version
+            py_version = "cp310"
+            platform = Platform.current_platform()
+
+            platform = str(platform).lower()
+            plugin_lib_path = download_plugin_lib_path(py_version, platform)
+
+    try:
+        # Load the shared TRT-LLM file
+        handle = ctypes.CDLL(plugin_lib_path)
+        logger.info(f"Successfully loaded plugin library: {plugin_lib_path}")
+    except OSError as e_os_error:
+        if "libmpi" in str(e_os_error):
+            logger.warning(
+                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. "
+                f"The dependency libmpi.so is missing. "
+                f"Please install the packages libmpich-dev and libopenmpi-dev.",
+                exc_info=e_os_error,
+            )
+        else:
+            logger.warning(
+                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. "
+                f"Ensure the path is correct and the library is compatible",
+                exc_info=e_os_error,
+            )
+        return False
+
+    try:
+        # Configure plugin initialization arguments
+        handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
+        handle.initTrtLlmPlugins.restype = ctypes.c_bool
+    except AttributeError as e_plugin_unavailable:
+        logger.warning(
+            "Unable to initialize the TensorRT-LLM plugin library",
+            exc_info=e_plugin_unavailable,
+        )
+        return False
+
+    try:
+        # Initialize the plugin
+        TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
+        if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")):
+            logger.info("TensorRT-LLM plugin successfully initialized")
+            return True
+        else:
+            logger.warning("TensorRT-LLM plugin library failed in initialization")
+            return False
+    except Exception as e_initialization_error:
+        logger.warning(
+            "Exception occurred during TensorRT-LLM plugin library initialization",
+            exc_info=e_initialization_error,
+        )
+        return False
+    return False
setup.py

Lines changed: 4 additions & 0 deletions
@@ -28,6 +28,7 @@
 __version__: str = "0.0.0"
 __cuda_version__: str = "0.0"
 __tensorrt_version__: str = "0.0"
+__tensorrt_llm_version__: str = "0.0"
 
 LEGACY_BASE_VERSION_SUFFIX_PATTERN = re.compile("a0$")
 
@@ -63,6 +64,7 @@ def get_base_version() -> str:
 def load_dep_info():
     global __cuda_version__
     global __tensorrt_version__
+    global __tensorrt_llm_version__
     with open("dev_dep_versions.yml", "r") as stream:
         versions = yaml.safe_load(stream)
         if (gpu_arch_version := os.environ.get("CU_VERSION")) is not None:
@@ -72,6 +74,7 @@ def load_dep_info():
         else:
             __cuda_version__ = versions["__cuda_version__"]
         __tensorrt_version__ = versions["__tensorrt_version__"]
+        __tensorrt_llm_version__ = versions["__tensorrt_llm_version__"]
 
 
 load_dep_info()
@@ -223,6 +226,7 @@ def gen_version_file():
         f.write('__version__ = "' + __version__ + '"\n')
         f.write('__cuda_version__ = "' + __cuda_version__ + '"\n')
         f.write('__tensorrt_version__ = "' + __tensorrt_version__ + '"\n')
+        f.write('__tensorrt_llm_version__ = "' + __tensorrt_llm_version__ + '"\n')
 
 
 def copy_libtorchtrt(multilinux=False, rt_only=False):
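
For reference, a short sketch of how the pinned version flows from dev_dep_versions.yml through the generated version file into the wheel name built by download_plugin_lib_path (the py_version and platform tags shown are assumed example values):

# After setup.py's gen_version_file() has run, the constant is importable:
from torch_tensorrt._version import __tensorrt_llm_version__  # e.g. "0.17.0.post1"

# dynamo/utils.py uses it to construct the TRT-LLM wheel file name, roughly:
py_version, platform = "cp310", "linux_x86_64"
file_name = f"tensorrt_llm-{__tensorrt_llm_version__}-{py_version}-{py_version}-{platform}.whl"
# -> tensorrt_llm-0.17.0.post1-cp310-cp310-linux_x86_64.whl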
