125 changes: 105 additions & 20 deletions examples/distributed_inference/tensor_parallel_initialize_dist.py
@@ -3,7 +3,7 @@
Tensor Parallel Initialize Distributed Environment
==================================================

This module provides functions to initialize and clean up the distributed environment for tensor parallel distributed inference.
This module provides functions to initialize and clean up the distributed environment for tensor parallel distributed inference. These utilities support the tensor parallel inference examples that use torch.distributed.
"""

import logging
@@ -14,32 +14,68 @@
import tensorrt as trt
import torch
import torch.distributed as dist
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh


def find_repo_root(max_depth=10):
dir_path = os.path.dirname(os.path.realpath(__file__))
for i in range(max_depth):
files = os.listdir(dir_path)
if "MODULE.bazel" in files:
return dir_path
else:
dir_path = os.path.dirname(dir_path)
def initialize_logger(
rank, logger_file_name, file_level=logging.DEBUG, console_level=logging.INFO
):
"""Initialize rank-specific Torch-TensorRT logger with configurable handler levels.

raise RuntimeError("Could not find repo root")
    The logger level is set to DEBUG (pass-through); the handlers filter what reaches the log file and the console.

Args:
        rank: Process rank in the multi-GPU run
        logger_file_name: Base name for the log file (a "_{rank}.log" suffix is appended)
        file_level: Level for the file handler - default DEBUG (everything)
        console_level: Level for the console handler - default INFO (clean output)
"""
logger = logging.getLogger("torch_tensorrt")
logger.setLevel(logging.DEBUG)
logger.handlers.clear()

def initialize_logger(rank, logger_file_name):
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# File handler
fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
fh.setLevel(logging.INFO)
fh.setLevel(file_level)
fh.setFormatter(
logging.Formatter(
f"[Rank {rank}] %(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
)
logger.addHandler(fh)

    # Console handler
    ch = logging.StreamHandler()
    ch.setLevel(console_level)  # Controls what is printed to the console
ch.setFormatter(logging.Formatter(f"[Rank {rank}] %(levelname)s: %(message)s"))
logger.addHandler(ch)

    # Safeguard against duplicate messages via the root logger (not strictly required)
logger.propagate = False
return logger
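

def _example_initialize_logger_usage() -> None:
    # Illustrative sketch only (not used by the examples): shows how the two
    # handler levels interact. The base name "tp_example" is a placeholder.
    logger = initialize_logger(
        rank=0,
        logger_file_name="tp_example",
        file_level=logging.DEBUG,
        console_level=logging.WARNING,
    )
    logger.debug("recorded in tp_example_0.log only")
    logger.warning("recorded in the log file and printed to the console")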


# This is required for env initialization since we use mpirun
def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
def initialize_distributed_env(
logger_file_name,
rank=0,
world_size=1,
port=29500,
file_level="debug",
console_level="info",
):
"""Initialize distributed environment with handler-based logging.

Args:
logger_file_name: Base name for log files
rank: Initial rank (overridden by OMPI env vars)
world_size: Initial world size (overridden by OMPI env vars)
port: Master port for distributed communication
file_level: File handler level - "debug", "info", "warning" (default: "debug")
console_level: Console handler level - "debug", "info", "warning" (default: "info")
"""
local_rank = int(
os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
)
@@ -50,9 +86,6 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(port)
os.environ["TRTLLM_PLUGINS_PATH"] = (
find_repo_root() + "/lib/libnvinfer_plugin_tensorrt_llm.so"
)

# Necessary to assign a device to each rank.
torch.cuda.set_device(local_rank)
@@ -66,16 +99,68 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
rank = device_mesh.get_rank()
assert rank == local_rank
logger = initialize_logger(rank, logger_file_name)
# Convert string handler levels to logging constants
level_map = {
"debug": logging.DEBUG,
"info": logging.INFO,
"warning": logging.WARNING,
"error": logging.ERROR,
}
file_level_int = level_map.get(file_level.lower(), logging.DEBUG)
console_level_int = level_map.get(console_level.lower(), logging.INFO)

# Initialize logger with handler-specific levels
# Logger itself is always DEBUG - handlers do the filtering
logger = initialize_logger(
rank,
logger_file_name,
file_level=file_level_int,
console_level=console_level_int,
)
device_id = (
rank % torch.cuda.device_count()
) # Ensure each rank gets a unique device
torch.cuda.set_device(device_id)

# Set C++ TensorRT runtime log level based on most verbose handler
# Use the most verbose level to ensure all important logs are captured
cpp_level = min(file_level_int, console_level_int)
try:
import torch_tensorrt.logging as torchtrt_logging

torchtrt_logging.set_level(cpp_level)
except Exception as e:
logger.warning(f"Could not set C++ TensorRT log level: {e}")

return device_mesh, world_size, rank, logger


def cleanup_distributed_env():
"""Clean up distributed process group to prevent resource leaks."""
if dist.is_initialized():
dist.destroy_process_group()
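

def _example_initialize_distributed_env_usage() -> None:
    # Illustrative sketch only: how an example script launched with
    # `mpirun -n 2 python script.py` (assumed launch command) would use these
    # helpers. "tp_example" is a placeholder log-file base name.
    device_mesh, world_size, rank, logger = initialize_distributed_env(
        "tp_example", file_level="debug", console_level="info"
    )
    logger.info(f"rank {rank}/{world_size} initialized, mesh: {device_mesh}")
    # ... build and run the tensor-parallel model here ...
    cleanup_distributed_env()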


def check_tensor_parallel_device_number(world_size: int) -> None:
if world_size % 2 != 0:
raise ValueError(
f"TP examples require even number of GPUs, but got {world_size} gpus"
)


def get_tensor_parallel_device_mesh(
rank: int = 0, world_size: int = 1
) -> tuple[DeviceMesh, int, int]:
local_rank = int(
os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
)
world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
rank = device_mesh.get_rank()
assert rank == local_rank
device_id = (
rank % torch.cuda.device_count()
) # Ensure each rank gets a unique device
torch.cuda.set_device(device_id)

return device_mesh, world_size, rank
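

def _example_get_tensor_parallel_device_mesh_usage() -> None:
    # Illustrative sketch only: build a 1-D CUDA device mesh for tensor
    # parallelism when MPI has already launched the processes, then attach a
    # rank-specific logger. "tp_mesh_example" is a placeholder file name.
    device_mesh, world_size, rank = get_tensor_parallel_device_mesh()
    check_tensor_parallel_device_number(world_size)
    logger = initialize_logger(rank, "tp_mesh_example")
    logger.info(f"rank {rank} of {world_size} attached to mesh {device_mesh}")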
examples/distributed_inference/tensor_parallel_rotary_embedding.py
@@ -9,20 +9,19 @@

"""

import logging
import os
import time

import torch
import torch_tensorrt
from rotary_embedding import RotaryAttention, parallel_rotary_block
from tensor_parallel_initialize_dist import (
cleanup_distributed_env,
initialize_distributed_env,
)

# Initialize distributed environment and logger BEFORE importing torch_tensorrt
# This ensures logging is configured before any import-time log messages
device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_rotary_embedding"
"tensor_parallel_rotary_embedding"
)


15 changes: 9 additions & 6 deletions examples/distributed_inference/tensor_parallel_simple_example.py
@@ -1,7 +1,7 @@
"""
.. _tensor_parallel_simple_example:

Torch Parallel Distributed example for simple model
Tensor Parallel Distributed Inference with Torch-TensorRT
==========================================================

Below example shows how to use Torch-TensorRT backend for distributed inference with tensor parallelism.
@@ -25,22 +25,25 @@
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_tensorrt
from tensor_parallel_initialize_dist import (
cleanup_distributed_env,
get_tensor_parallel_device_mesh,
initialize_distributed_env,
)

# Initialize distributed environment and logger BEFORE importing torch_tensorrt
# This ensures logging is configured before any import-time log messages
device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"tensor_parallel_simple_example"
)

from torch.distributed._tensor import Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
)

device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_simple_example"
)

"""
This example takes some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
"""
36 changes: 34 additions & 2 deletions py/torch_tensorrt/_features.py
@@ -76,10 +76,29 @@

def _enabled_features_str() -> str:
enabled = lambda x: "ENABLED" if x else "DISABLED"
out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n - Refit: {enabled(_REFIT_AVAIL)}\n - QDP Plugin: {enabled(_QDP_PLUGIN_AVAIL)} \n - TensorRT-RTX: {enabled(_TENSORRT_RTX)}\n" # type: ignore[no-untyped-call]
out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n - Refit: {enabled(_REFIT_AVAIL)}\n - QDP Plugin: {enabled(_QDP_PLUGIN_AVAIL)} \n - TensorRT-RTX: {enabled(_TENSORRT_RTX)}\n - TensorRT-LLM for NCCL: {enabled(_TRTLLM_AVAIL)}\n" # type: ignore[no-untyped-call]
return out_str


# Inline helper functions for checking feature availability
def has_torch_tensorrt_runtime() -> bool:
"""Check if Torch-TensorRT C++ runtime is available.

Returns:
bool: True if libtorchtrt_runtime.so or libtorchtrt.so is available
"""
return bool(ENABLED_FEATURES.torch_tensorrt_runtime)


def has_torchscript_frontend() -> bool:
"""Check if TorchScript frontend is available.

Returns:
bool: True if libtorchtrt.so is available
"""
return bool(ENABLED_FEATURES.torchscript_frontend)
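

# Illustrative usage sketch (the calling pattern is an assumption, not part of
# this module): callers can branch on feature availability instead of wrapping
# imports in try/except.
#
#     from torch_tensorrt._features import (
#         has_torch_tensorrt_runtime,
#         has_torchscript_frontend,
#     )
#
#     if has_torch_tensorrt_runtime():
#         ...  # safe to execute serialized TensorRT engines via the C++ runtime
#     if not has_torchscript_frontend():
#         ...  # fall back to the Dynamo frontend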


def needs_tensorrt_rtx(f: Callable[..., Any]) -> Callable[..., Any]:
def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
if ENABLED_FEATURES.tensorrt_rtx:
@@ -165,14 +184,27 @@ def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:


def needs_trtllm_for_nccl(f: Callable[..., Any]) -> Callable[..., Any]:
"""
Runtime check decorator for TensorRT-LLM NCCL plugin availability.

WARNING: This decorator CANNOT prevent registration of converters at import time.
When used with @dynamo_tensorrt_converter, the converter is always registered
regardless of decorator order, because registration happens at import time before
the wrapper is called.

This decorator is kept for potential non-registration use cases where
runtime checks are appropriate.
@apbose: to discuss if this is required
"""

def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
if ENABLED_FEATURES.trtllm_for_nccl:
return f(*args, **kwargs)
else:

def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
raise NotImplementedError(
"Refit feature is currently not available in Python 3.13 or higher"
"TensorRT-LLM plugin for NCCL is not available"
)

return not_implemented(*args, **kwargs)
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -7,6 +7,7 @@

import torch
import torch._dynamo as td
import torch_tensorrt.logging as torchtrt_logging
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.utils import detect_fake_mode
from torch._functorch.aot_autograd import aot_export_joint_simple
@@ -23,7 +24,6 @@
from torch_tensorrt.dynamo.utils import (
parse_dynamo_kwargs,
prepare_inputs,
set_log_level,
)

logger = logging.getLogger(__name__)
@@ -40,7 +40,7 @@ def torch_tensorrt_backend(
and "debug" in kwargs["options"]
and kwargs["options"]["debug"]
) or ("debug" in kwargs and kwargs["debug"]):
set_log_level(logger.parent, logging.DEBUG)
torchtrt_logging.set_level(logging.DEBUG)

DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend
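
# Illustrative sketch of how the debug path above is reached from user code
# (assumed call site, not part of this module):
#
#     compiled = torch.compile(model, backend="torch_tensorrt", options={"debug": True})
#     compiled(example_input)  # torch_tensorrt logging switches to DEBUG when the backend runs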
