
Commit c636b39

apbose and narendasan authored
using nccl ops from TRT-LLM namespace (#3250)
Signed-off-by: Naren Dasan <[email protected]> Co-authored-by: Naren Dasan <[email protected]>
1 parent 85e4332 commit c636b39

24 files changed: +2235 -273 lines changed

examples/distributed_inference/README.md

+34
@@ -14,3 +14,37 @@ See the examples started with `data_parallel` for more details.
 Here we use torch.distributed as an example, but compilation with tensor parallelism is agnostic to the implementation framework as long as the module is properly sharded.
 
 torchrun --nproc_per_node=2 tensor_parallel_llama2.py
+
+3. Tensor parallel distributed inference using nccl ops plugin
+
+    apt install libmpich-dev
+    apt install libopenmpi-dev
+
+    # For Python 3.10
+    pip install tensorrt-llm
+
+For other Python versions, you need to load libnvinfer_plugin_tensorrt_llm.so yourself. Point the environment variable at it with export TRTLLM_PLUGINS_PATH={lib_path}. In these examples the variable is already set inside initialize_distributed_env(); replace that with your own TRTLLM_PLUGINS_PATH or unset it there.
+
+    # then pip install the tensorrt and torch versions compatible with the installed torchTRT
+
+    mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py
+
+    # For other python
+
+4. Tensor parallel distributed llama3 inference using nccl ops plugin
+
+    apt install libmpich-dev
+    apt install libopenmpi-dev
+
+    # For Python 3.10
+    pip install tensorrt-llm
+
+For other Python versions, you need to load libnvinfer_plugin_tensorrt_llm.so (see the note in example 3).
+
+    # then pip install the tensorrt and torch versions compatible with the installed torchTRT
+
+    mpirun -n 2 --allow-run-as-root python tensor_parallel_llama3.py
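
The README steps above hinge on TRTLLM_PLUGINS_PATH resolving to a loadable libnvinfer_plugin_tensorrt_llm.so. A minimal sketch for sanity-checking that before launching mpirun (the fallback path below is an assumption; substitute your own):

import ctypes
import os

# Assumed default location; replace with your actual plugin path.
plugin_path = os.environ.get(
    "TRTLLM_PLUGINS_PATH", "/usr/local/lib/libnvinfer_plugin_tensorrt_llm.so"
)

try:
    # Loading the shared object directly confirms the path (and its dependencies)
    # resolve before the distributed example is launched.
    ctypes.CDLL(plugin_path)
    print(f"Loaded TRT-LLM plugin library from {plugin_path}")
except OSError as err:
    print(f"Could not load {plugin_path}: {err}")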
@@ -1,3 +1,4 @@
 accelerate
 transformers
-diffusers
+diffusers
+tensorrt-llm
examples/distributed_inference/tensor_parallel_initialize_dist.py

+67
@@ -0,0 +1,67 @@
+import logging
+import os
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import tensorrt as trt
+import torch
+import torch.distributed as dist
+from torch.distributed._tensor.device_mesh import init_device_mesh
+
+
+def find_repo_root(max_depth=10):
+    dir_path = os.path.dirname(os.path.realpath(__file__))
+    for i in range(max_depth):
+        files = os.listdir(dir_path)
+        if "MODULE.bazel" in files:
+            return dir_path
+        else:
+            dir_path = os.path.dirname(dir_path)
+
+    raise RuntimeError("Could not find repo root")
+
+
+def initialize_logger(rank, logger_file_name):
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)
+    fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
+    fh.setLevel(logging.INFO)
+    logger.addHandler(fh)
+    return logger
+
+
+# This is required for env initialization since we use mpirun
+def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
+    local_rank = int(
+        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
+    )
+    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))
+
+    # Set up environment variables to run with mpirun
+    os.environ["RANK"] = str(local_rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = str(port)
+    os.environ["TRTLLM_PLUGINS_PATH"] = (
+        find_repo_root() + "/lib/libnvinfer_plugin_tensorrt_llm.so"
+    )
+
+    # Necessary to assign a device to each rank.
+    torch.cuda.set_device(local_rank)
+
+    # We use the nccl backend
+    dist.init_process_group("nccl")
+
+    # Set a manual seed for reproducibility
+    torch.manual_seed(1111)
+
+    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
+    rank = device_mesh.get_rank()
+    assert rank == local_rank
+    logger = initialize_logger(rank, logger_file_name)
+    device_id = (
+        rank % torch.cuda.device_count()
+    )  # Ensure each rank gets a unique device
+    torch.cuda.set_device(device_id)
+
+    return device_mesh, world_size, rank, logger
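
For orientation, here is how the examples consume this helper; a minimal sketch with an illustrative script name (the pattern matches the diffs to tensor_parallel_llama3.py and tensor_parallel_simple_example.py further down):

# my_tp_script.py -- illustrative name only
from tensor_parallel_initialize_dist import initialize_distributed_env

# Must run before importing tensorrt_llm so TRTLLM_PLUGINS_PATH is already set.
device_mesh, world_size, rank, logger = initialize_distributed_env("./my_tp_run")

import tensorrt_llm  # noqa: E402  (import deliberately placed after env setup)

logger.info(f"Rank {rank}/{world_size} initialized on mesh {device_mesh}")

Launched with mpirun -n 2 --allow-run-as-root python my_tp_script.py, as in the README section above.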

examples/distributed_inference/tensor_parallel_llama3.py

+12 -14
@@ -5,27 +5,25 @@
 import time
 
 import torch
-import torch_tensorrt
 from llama3_model import ModelArgs, ParallelTransformer
+from tensor_parallel_initialize_dist import initialize_distributed_env
 from torch.distributed._composable.fsdp import MixedPrecisionPolicy
 from torch.distributed._composable.fsdp.fully_shard import fully_shard
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper,
 )
-from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 
-_rank = int(os.environ["RANK"])
-_world_size = int(os.environ["WORLD_SIZE"])
-tp_size = 2
-
-logger = logging.getLogger()
-logger.setLevel(logging.INFO)
-fh = logging.FileHandler(f"./tensor_parallel_log_{_rank}.log", mode="w")
-fh.setLevel(logging.INFO)
-logger.addHandler(fh)
+device_mesh, _world_size, _rank, logger = initialize_distributed_env(
+    "./tensor_parallel_llama3"
+)
+# Import should be after initialization of the TRT-LLM plugin .so path
+import tensorrt_llm
 
-tp_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
+logger.info(f"Starting PyTorch TP example on rank {_rank}.")
+assert (
+    _world_size % 2 == 0
+), f"TP examples require even number of GPUs, but got {_world_size} gpus"
 
 model_args = ModelArgs(
     vocab_size=32000,
@@ -38,7 +36,7 @@
 )
 
 with torch.no_grad():
-    model = ParallelTransformer(model_args, tp_mesh)
+    model = ParallelTransformer(model_args, device_mesh)
     torch.manual_seed(0)
     inp = torch.randint(32000, (8, 256), device="cuda")
     python_result = model(inp)
@@ -53,7 +51,7 @@
         "use_python_runtime": True,
         "workspace_size": 1 << 33,
         "debug": False,
-        "timing_cache_path": "/opt/file/cache/timing_cache_llama.bin",
+        "use_aot_joint_export": False,
     },
     dynamic=False,
 )
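
The hunks above only show the option lines that changed; for context, a sketch of the compile call they feed into, using a stand-in module (the option values are taken from the diff, the rest is an assumption):

import torch
import torch.nn as nn
import torch_tensorrt  # registers the "torch_tensorrt" torch.compile backend

# Stand-in module; in the example this is the sharded ParallelTransformer.
model = nn.Linear(16, 16).cuda()

trt_model = torch.compile(
    model,
    backend="torch_tensorrt",
    options={
        "use_python_runtime": True,
        "workspace_size": 1 << 33,
        "debug": False,
        # New in this commit: disable aot_export_joint_simple so the backend
        # wraps with aot_autograd, which the distributed tensor path requires.
        "use_aot_joint_export": False,
    },
    dynamic=False,
)
out = trt_model(torch.randn(4, 16, device="cuda"))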

examples/distributed_inference/tensor_parallel_simple_example.py

+11 -13
@@ -1,18 +1,22 @@
-import os
-import sys
 import time
 
+import tensorrt as trt
 import torch
 import torch.nn as nn
 import torch_tensorrt
+from tensor_parallel_initialize_dist import initialize_distributed_env
 from torch.distributed._tensor import Shard
-from torch.distributed._tensor.device_mesh import init_device_mesh
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     RowwiseParallel,
     parallelize_module,
 )
 
+device_mesh, _world_size, _rank, logger = initialize_distributed_env(
+    "./tensor_parallel_simple_example"
+)
+import tensorrt_llm
+
 """
 This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
 """
@@ -36,14 +40,7 @@ def forward(self, x):
         return x
 
 
-# create a device mesh based on the given world_size.
-_world_size = int(os.environ["WORLD_SIZE"])
-
-device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
-_rank = device_mesh.get_rank()
-
-
-print(f"Starting PyTorch TP example on rank {_rank}.")
+logger.info(f"Starting PyTorch TP example on rank {_rank}.")
 assert (
     _world_size % 2 == 0
 ), f"TP examples require even number of GPUs, but got {_world_size} gpus"
@@ -78,6 +75,7 @@ def forward(self, x):
         "enabled_precisions": {torch.float32, torch.float16},
         "use_python_runtime": True,
         "min_block_size": 1,
+        "use_aot_joint_export": False,
     },
     dynamic=False,
 )
@@ -91,9 +89,9 @@ def forward(self, x):
     output = tp_model(inp)
     end = time.time()
     if i == 0:
-        print(f"Compilation time is {end-start}")
+        logger.info(f"Compilation time is {end-start}")
         assert (
             python_result - output
         ).std() < 0.01, "Compilation result is not correct."
     elif _rank == 0:
-        print(f"Inference time is {end-start}")
+        logger.info(f"Inference time is {end-start}")
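
The ColwiseParallel/RowwiseParallel imports above belong to the tensor-parallel plan applied elsewhere in this file; as a reminder of the pattern, a generic sketch (not the exact plan from the example):

import torch.nn as nn
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyMLP(nn.Module):
    """Generic two-layer MLP standing in for the example's toy model."""

    def __init__(self, hidden: int = 128):
        super().__init__()
        self.in_proj = nn.Linear(hidden, 4 * hidden)
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(4 * hidden, hidden)

    def forward(self, x):
        return self.out_proj(self.relu(self.in_proj(x)))


def shard_model(model: nn.Module, device_mesh):
    # Column-parallel first projection plus row-parallel second projection is
    # the standard Megatron-style split for an MLP block; device_mesh comes
    # from initialize_distributed_env() as in the diff above.
    return parallelize_module(
        model,
        device_mesh,
        {"in_proj": ColwiseParallel(), "out_proj": RowwiseParallel()},
    )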

py/torch_tensorrt/dynamo/_defaults.py

+1
@@ -46,6 +46,7 @@
 IMMUTABLE_WEIGHTS = True
 ENABLE_WEIGHT_STREAMING = False
 ENABLE_CROSS_COMPILE_FOR_WINDOWS = False
+USE_AOT_JOINT_EXPORT = True
 
 
 def default_device() -> Device:

py/torch_tensorrt/dynamo/_settings.py

+3
@@ -33,6 +33,7 @@
     STRIP_ENGINE_WEIGHTS,
     TIMING_CACHE_PATH,
     TRUNCATE_DOUBLE,
+    USE_AOT_JOINT_EXPORT,
     USE_EXPLICIT_TYPING,
     USE_FAST_PARTITIONER,
     USE_FP32_ACC,
@@ -91,6 +92,7 @@ class CompilationSettings:
         enable_weight_streaming (bool): Enable weight streaming.
         enable_cross_compile_for_windows (bool): By default this is False means TensorRT engines can only be executed on the same platform where they were built.
             True will enable cross-platform compatibility which allows the engine to be built on Linux and run on Windows
+        use_aot_joint_export (bool): Use aot_export_joint_simple; if False, wrap the backend with aot_autograd, which is required for distributed tensors
     """
 
     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -131,6 +133,7 @@ class CompilationSettings:
     immutable_weights: bool = IMMUTABLE_WEIGHTS
     enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING
     enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS
+    use_aot_joint_export: bool = USE_AOT_JOINT_EXPORT
 
 
 _SETTINGS_TO_BE_ENGINE_INVARIANT = (
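
A quick sketch of the new field on the settings object, constructing CompilationSettings directly just for illustration (normal use passes the flag through the torch.compile options dict, as in the examples above):

from torch_tensorrt.dynamo import CompilationSettings

# Defaults pick up USE_AOT_JOINT_EXPORT = True from _defaults.py; the
# distributed examples override it to False to get the aot_autograd path.
settings = CompilationSettings(use_aot_joint_export=False)
print(settings.use_aot_joint_export)  # False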

py/torch_tensorrt/dynamo/backend/backends.py

+47 -9
@@ -1,17 +1,20 @@
 from __future__ import annotations
 
+import functools
 import logging
 import unittest
 from typing import Any, Callable, Sequence
 
 import torch
 import torch._dynamo as td
+from torch._dynamo.backends.common import aot_autograd
 from torch._dynamo.utils import detect_fake_mode
 from torch._functorch.aot_autograd import aot_export_joint_simple
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo._compiler import compile_module
 from torch_tensorrt.dynamo.lowering import (
     get_decompositions,
+    modify_reshape_complex_nodes,
     post_lowering,
     remove_detach,
     remove_sym_nodes,
@@ -49,7 +52,25 @@ def aot_torch_tensorrt_aten_backend(
     gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
 ) -> torch.nn.Module:
     settings, engine_cache = parse_dynamo_kwargs(kwargs)
-    return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
+    if settings.use_aot_joint_export:
+        return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
+    logger.debug("Wrapping the backend with aot_autograd\n")
+    _pretraced_backend_autograd = functools.partial(
+        _pretraced_backend, settings=settings, engine_cache=engine_cache
+    )
+    settings_aot_autograd = {}
+    settings_aot_autograd["decompositions"] = get_decompositions(
+        settings.enable_experimental_decompositions
+    )
+    # This is added since detach lowering leads to alias nodes
+    # Error - View operation returned a tensor that is the same as the input base tensor
+    # torch nop_decompositions in torch/_decomp/decompositions.py
+    if aten.detach in settings_aot_autograd["decompositions"]:
+        del settings_aot_autograd["decompositions"][aten.detach]
+    return aot_autograd(
+        fw_compiler=_pretraced_backend_autograd,
+        decompositions=settings_aot_autograd["decompositions"],
+    )(gm, sample_inputs)
 
 
 def _pretraced_backend(
@@ -89,22 +110,39 @@ def _pretraced_backend(
         # Remove detach nodes
         remove_detach(gm, settings)
 
+        complexInputIndices = []
+        for i, torch_input in enumerate(torch_inputs):
+            if torch_inputs[i].dtype == torch.complex64:
+                complexInputIndices.append(i)
+                torch_input_real = torch_inputs[i].real
+                torch_input_imaginary = torch_inputs[i].imag
+                torch_inputs[i] = torch.stack(
+                    (torch_input_real, torch_input_imaginary), dim=-1
+                )
+
         # Invoke AOTAutograd to translate operators to aten
-        gm = aot_export_joint_simple(
-            gm,
-            sample_inputs,
-            trace_joint=False,
-            decompositions=get_decompositions(
-                settings.enable_experimental_decompositions
-            ),
-        )
+        if settings.use_aot_joint_export:
+            gm = aot_export_joint_simple(
+                gm,
+                sample_inputs,
+                trace_joint=False,
+                decompositions=get_decompositions(
+                    settings.enable_experimental_decompositions
+                ),
+            )
 
         logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
 
         gm = post_lowering(gm, settings)
 
         logger.debug("Lowered Input graph:\n " + str(gm.graph))
 
+        if complexInputIndices:
+            modify_reshape_complex_nodes(gm, complexInputIndices)
+            logger.debug(
+                "Input graph after modifying complex nodes:\n " + str(gm.graph)
+            )
+
         torchtrt_inputs = prepare_inputs(
             torch_inputs, disable_memory_format_check=True
         )
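
The key pattern introduced above is wrapping the per-graph compiler with torch._dynamo's aot_autograd helper instead of calling aot_export_joint_simple. A minimal standalone illustration of that wrapping, independent of Torch-TensorRT (the compiler function here is a trivial stand-in for _pretraced_backend):

import torch
from torch._dynamo.backends.common import aot_autograd


def my_fw_compiler(gm: torch.fx.GraphModule, example_inputs):
    # In backends.py this role is played by _pretraced_backend, bound to its
    # settings and engine_cache via functools.partial; here we just log and
    # return the traced module unchanged.
    print(f"compiling a graph with {len(list(gm.graph.nodes))} nodes")
    return gm


# aot_autograd(...) turns the forward compiler into a dynamo backend callable.
backend = aot_autograd(fw_compiler=my_fw_compiler)


@torch.compile(backend=backend)
def f(x):
    return torch.relu(x) + 1


print(f(torch.randn(4)))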

py/torch_tensorrt/dynamo/conversion/__init__.py

+7 -1
@@ -1,4 +1,10 @@
-from . import aten_ops_converters, ops_evaluators, plugins, prims_ops_converters
+from . import (
+    aten_ops_converters,
+    custom_ops_converters,
+    ops_evaluators,
+    plugins,
+    prims_ops_converters,
+)
 from ._conversion import convert_module, interpret_module_to_result
 from ._ConversionContext import ConversionContext
 from ._ConverterRegistry import *  # noqa: F403
