-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Expand file tree
/
Copy pathbackend.py
More file actions
170 lines (145 loc) · 6.22 KB
/
backend.py
File metadata and controls
170 lines (145 loc) · 6.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import os
from typing import List, Optional
import torch
import torch._inductor.config as inductor_config
from torch._functorch.aot_autograd import aot_module_simplified
from torch._inductor.compile_fx import compile_fx, select_decomp_table
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._subclasses import FakeTensor
from torch.fx import GraphModule
import tensorrt_llm
from tensorrt_llm import logger
from tensorrt_llm.mapping import Mapping
from .multi_stream.auto_multi_stream import multi_stream_schedule
from .patterns.ar_residual_norm import register_ar_fusions
from .patterns.residual_add_norm import register_add_norm
from .piecewise_optimizer import piecewise_optimizer
from .recover_pass import recover_pass
from .remove_copy_pass import remove_copy_for_mutates_args
class Backend:
    """torch.compile backend for the TensorRT-LLM PyTorch flow.

    Given the FX ``GraphModule`` produced by Dynamo, this backend applies
    custom pattern-matcher fusion passes (allreduce+residual+norm,
    add+rmsnorm), optionally schedules independent work onto multiple CUDA
    streams, and then hands the graph to one of three finishers:
    piecewise CUDA-graph capture, Inductor (``compile_fx``), or eager
    execution of the transformed graph.
    """

    # Process-wide cache of fusion passes, built lazily by get_custom_pass()
    # and shared by every Backend instance (pass registration is global).
    _custom_pass_instances: Optional[List[PatternMatcherPass]] = None
    # Shared CUDA-graph memory pool handle so all captured graphs from all
    # Backend instances allocate out of one pool.
    _graph_pool_handle: Optional[tuple[int, int]] = None

    # Following classes are used to let weakref ref the stream and event
    # list objects (plain ``list`` instances cannot be weakly referenced).
    class Streams(list):
        pass

    class Events(list):
        pass

    def __init__(
        self,
        enable_inductor: bool = True,
        enable_userbuffers: bool = False,
        enable_piecewise_cuda_graph: bool = False,
        capture_num_tokens: Optional[List[int]] = None,
        max_num_streams: int = 1,
        mapping: Optional[Mapping] = None,
    ) -> None:
        """Configure the backend.

        Args:
            enable_inductor: Lower the final graph with Inductor
                (``compile_fx``) instead of running the FX graph eagerly.
            enable_userbuffers: Request userbuffer-based fusion patterns;
                only honored when the userbuffer runtime reports support.
            enable_piecewise_cuda_graph: Split the graph and capture the
                pieces as CUDA graphs via ``piecewise_optimizer``.
            capture_num_tokens: Token counts to capture CUDA graphs for;
                stored sorted. ``None`` means an empty list.
            max_num_streams: Maximum CUDA streams for the multi-stream
                scheduling optimization (1 disables it).
            mapping: Parallelism mapping forwarded to the fusion passes.
        """
        super().__init__()
        # Timing/diagnostic bookkeeping (accumulated elsewhere).
        self.elapsed_time = 0
        self.module_inference_event = []
        self.module_inference_time = 0
        self.call_count = 0
        self.mapping = mapping
        self.custom_passes = Backend.get_custom_pass(enable_userbuffers,
                                                     mapping)
        self.rank = tensorrt_llm.mpi_rank()
        self.enable_inductor = enable_inductor
        self.capture_num_tokens = sorted(capture_num_tokens or [])
        self.piecewise_cuda_graph = enable_piecewise_cuda_graph
        # When True, __call__ skips all optimization and returns the graph
        # module unchanged (eager fallback).
        self.no_optimization = False
        self.num_streams = max_num_streams
        self.events = Backend.Events()
        # NOTE(review): v2 auto-functionalization is force-disabled here —
        # presumably it conflicts with the custom passes; confirm before
        # removing.
        inductor_config.enable_auto_functionalized_v2 = False
        if Backend._graph_pool_handle is None:
            Backend._graph_pool_handle = torch.cuda.graph_pool_handle()
        # Per-pass match counts recorded by optimize(); useful for debugging
        # whether the fusion patterns actually fired.
        self.match_count = []

    @classmethod
    def get_custom_pass(cls, enable_userbuffers: bool,
                        mapping: Mapping) -> List[PatternMatcherPass]:
        """Build (once per process) and return the fusion pass list.

        Multi-rank runs register allreduce-based fusions plus a fallback
        add+rmsnorm pass; single-rank runs register only add+rmsnorm.
        """
        world_size = tensorrt_llm.mpi_world_size()
        if not cls._custom_pass_instances:
            # Really naive pass manager here
            cls._custom_pass_instances = [PatternMatcherPass()]
            if world_size > 1:
                # Currently torch compile cannot work properly with lamport fusion kernel
                # TO-DO: Fix this issue
                os.environ["DISABLE_LAMPORT_REDUCE_NORM_FUSION"] = "1"
                ub_enabled = enable_userbuffers and tensorrt_llm.bindings.internal.userbuffers.ub_supported(
                )
                register_ar_fusions(cls._custom_pass_instances, mapping,
                                    ub_enabled)
                # Fallback: fuse remaining add+rmsnorm not preceded by allreduce
                cls._custom_pass_instances.append(PatternMatcherPass())
                register_add_norm(cls._custom_pass_instances[-1])
            else:
                register_add_norm(cls._custom_pass_instances[0])
        return cls._custom_pass_instances

    def bypass_optimization(self) -> None:
        """Make subsequent __call__ invocations return the graph unmodified."""
        self.no_optimization = True

    def enable_optimization(self) -> None:
        """Re-enable optimization after bypass_optimization()."""
        self.no_optimization = False

    def generate_events(self, num_events: int) -> None:
        """Grow the shared event list to at least ``num_events`` CUDA events."""
        if num_events > len(self.events):
            self.events += [
                torch.cuda.Event() for _ in range(num_events - len(self.events))
            ]

    def optimize(
        self,
        gm: GraphModule,
        example_inputs: List[torch.Tensor],
    ):
        """Forward-compiler callback run under AOTAutograd.

        Applies each fusion pass to fixpoint, strips copies for mutated
        args, optionally schedules multiple streams, then finishes with
        piecewise CUDA graphs, Inductor, or the plain recompiled module.
        """
        graph = gm.graph
        for custom_pass in self.custom_passes:
            # Re-apply each pass until it stops matching, since one fusion
            # can expose further match opportunities.
            self.match_count.append(custom_pass.apply(graph))
            while self.match_count[-1]:
                self.match_count.append(custom_pass.apply(graph))
            graph.eliminate_dead_code()

        # After this pass, cannot run any dce!!!
        # (DCE would remove the in-place mutation nodes it rewires.)
        remove_copy_for_mutates_args(graph)

        # Do not apply multi-stream if enable piecewise cuda graph or inductor
        # For piecewise cuda graph, we will apply the multi-stream optimization in piecewise_optimizer
        # For inductor, we do not control the passes inside inductor.
        if self.num_streams > 1 and not self.piecewise_cuda_graph and not self.enable_inductor:
            num_events = multi_stream_schedule(gm, self.num_streams)
            self.generate_events(num_events)

        # Pick up any graph edits made above before handing the module on.
        gm.recompile()

        if self.piecewise_cuda_graph:
            gm, num_events = piecewise_optimizer(
                gm,
                example_inputs,
                self.enable_inductor,
                self.input_num_tokens,
                self.capture_num_tokens,
                self._graph_pool_handle,
                self.num_streams,
            )
            self.generate_events(num_events)
            return gm
        elif self.enable_inductor:
            return compile_fx(gm, example_inputs)
        else:
            return gm

    def __call__(self, gm: GraphModule,
                 example_inputs: List[torch.Tensor]) -> callable:
        """Entry point invoked by torch.compile with the Dynamo-traced graph.

        Detects the token count from the ``input_ids`` placeholder (needed
        for piecewise CUDA graphs), runs the recover pass, and defers the
        real optimization to :meth:`optimize` via AOTAutograd.
        """
        if self.no_optimization:
            logger.warning(
                "Bypassing torch.compile optimization and fallback to eager execution!"
            )
            return gm

        # Locate the input_ids placeholder to learn the number of tokens in
        # the example input; the accepted names come from how Dynamo mangles
        # the model's ``input_ids`` argument/kwarg.
        self.input_num_tokens = None
        for node in gm.graph.nodes:
            if node.op == "placeholder":
                if node.name in ["l_input_ids_", "l_kwargs_input_ids_"]:
                    example_value = node.meta["example_value"]
                    assert isinstance(example_value, FakeTensor)
                    self.input_num_tokens = example_value.shape[0]
                    break

        if self.piecewise_cuda_graph:
            # Piecewise capture cannot size its graphs without this.
            assert (
                self.input_num_tokens is not None
            ), "Cannot detect input_num_tokens. Cannot use piecewise CUDA graph. What is the name of `input_ids`?"

        gm = recover_pass(gm)

        return aot_module_simplified(
            gm,
            example_inputs,
            fw_compiler=self.optimize,
            decompositions=select_decomp_table(),
        )