
Commit 4ffb8b3

rebased to main
1 parent 1d038a1 commit 4ffb8b3

9 files changed: +238 -32 lines changed


py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 12 additions & 3 deletions

@@ -42,6 +42,7 @@
 )
 from torch_tensorrt.dynamo.utils import (
     deallocate_module,
+    get_cpu_memory_usage,
     get_flat_args_with_check,
     get_output_metadata,
     parse_graph_io,

@@ -681,7 +682,7 @@ def compile(
         "offload_module_to_cpu": offload_module_to_cpu,
         "use_distributed_mode_trace": use_distributed_mode_trace,
     }
-
+    logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB")
     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
     exported_program = pre_export_lowering(exported_program, settings)

@@ -695,14 +696,17 @@

     # Apply lowering on the graph module
     gm = post_lowering(gm, settings)
+    logger.debug(f"CPU memory usage after post_lowering: {get_cpu_memory_usage()} MB")
    logger.debug("Lowered Input graph: " + str(gm.graph))

     # Move the weights in the state_dict to CPU
     if offload_module_to_cpu:
+        deallocate_module(gm, delete_module=False)
         deallocate_module(exported_program.module(), delete_module=False)
         logger.info(
             "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
         )
+        logger.debug(f"CPU memory usage after CPU offload: {get_cpu_memory_usage()} MB")
     else:
         remaining_memory, total_memory = torch.cuda.mem_get_info()
         if remaining_memory < total_memory // 2:

@@ -868,6 +872,11 @@ def preserve_module_specs(
     # Iterate over all components that can be accelerated
     # Generate the corresponding TRT Module for those

+    # Delete the frozen parameters from the graph module to release CPU memory.
+    # This does not affect the submodules; their frozen parameters are deleted later in convert_module.
+    for attr in dir(gm):
+        if attr.startswith("_frozen_param"):
+            delattr(gm, attr)
     for name, _ in partitioned_module.named_children():
         submodule = getattr(partitioned_module, name)
         # filter on the GraphModule

@@ -1243,7 +1252,7 @@ def convert_exported_program_to_serialized_trt_engine(

     # Prepare torch_trt inputs
     trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
-    trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
+    trt_kwarg_inputs: Optional[dict[str, Any]] = prepare_inputs(kwarg_inputs)
     device = to_torch_tensorrt_device(device)
     enabled_precisions = {dtype._from(p) for p in enabled_precisions}

@@ -1330,7 +1339,7 @@
     )

     flattened_input_list = get_flat_args_with_check(
-        exported_program, list(trt_arg_inputs), trt_kwarg_inputs
+        exported_program, list(trt_arg_inputs), trt_kwarg_inputs  # type: ignore
     )[0]

     try:
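
Note: get_cpu_memory_usage() is imported from torch_tensorrt.dynamo.utils, but its body is not part of this diff. A minimal sketch of what such a helper could look like, assuming it reports the resident set size of the current process in MB via psutil (the same dependency the partitioner change below already uses):

import psutil

def get_cpu_memory_usage() -> float:
    # Resident set size of the current process, converted to MB.
    # Sketch only; the helper shipped in torch_tensorrt.dynamo.utils may differ.
    return psutil.Process().memory_info().rss / (1024 * 1024)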

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 22 additions & 18 deletions

@@ -50,7 +50,12 @@
 from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
 from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger
 from torch_tensorrt.dynamo.observer import Observer
-from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device
+from torch_tensorrt.dynamo.utils import (
+    DYNAMIC_DIM,
+    deallocate_module,
+    get_cpu_memory_usage,
+    to_torch_device,
+)
 from torch_tensorrt.logging import TRT_LOGGER

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -65,7 +70,7 @@ class UnsupportedOperatorException(RuntimeError):


 class TRTInterpreterResult(NamedTuple):
-    serialized_engine: bytes
+    engine: trt.ICudaEngine
     input_names: Sequence[str]
     output_names: Sequence[str]
     weight_name_map: Optional[dict[Any, Any]]

@@ -512,8 +517,7 @@ def _save_weight_mapping(self) -> None:
         _LOGGER.info("Building weight name mapping...")
         # Stage 1: Name mapping
         torch_device = to_torch_device(self.compilation_settings.device)
-        self.module.to(torch_device)
-        sd = self.module.state_dict()
+        sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()}
         weight_name_map: dict[str, Any] = {}
         weight_refit_map = self.ctx.weight_refit_map
         constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}

@@ -592,13 +596,11 @@
         torch.cuda.empty_cache()

     @needs_refit  # type: ignore[misc]
-    def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
+    def _insert_engine_to_cache(self, hash_val: str, engine: trt.ICudaEngine) -> None:
+        serialized_engine = engine.serialize()
         # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
         # if not self.compilation_settings.strip_engine_weights:
         #     # set EXCLUDE_WEIGHTS flag to strip weights
-        #     runtime = trt.Runtime(TRT_LOGGER)
-        #     engine = runtime.deserialize_cuda_engine(serialized_engine)
-
         #     serialization_config = engine.create_serialization_config()
         #     serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
         #     serialized_engine = engine.serialize_with_config(

@@ -733,6 +735,9 @@ def run(
             return interpreter_result  # type: ignore[no-any-return]

         self._construct_trt_network_def()
+        _LOGGER.debug(
+            f"CPU memory usage after network construction: {get_cpu_memory_usage()} MB"
+        )

         if not self.compilation_settings.immutable_weights:
             self._save_weight_mapping()

@@ -750,16 +755,19 @@
         self._create_timing_cache(
             builder_config, self.compilation_settings.timing_cache_path
         )
-        serialized_engine = self.builder.build_serialized_network(
+
+        cuda_engine = self.builder.build_engine_with_config(
             self.ctx.net, builder_config
         )
-        assert serialized_engine
+        assert cuda_engine
+
+        _LOGGER.debug(
+            f"CPU memory usage after engine building: {get_cpu_memory_usage()} MB"
+        )

         _LOGGER.info(
             f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
         )
-        _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")
-
         self.ctx.clear_cpu_weights_reference_holder()

         self._save_timing_cache(

@@ -772,14 +780,10 @@
             and self.compilation_settings.cache_built_engines
             and self.engine_cache is not None
         ):
-            self._insert_engine_to_cache(hash_val, serialized_engine)
-
-        with io.BytesIO() as engine_bytes:
-            engine_bytes.write(serialized_engine)
-            engine_str = engine_bytes.getvalue()
+            self._insert_engine_to_cache(hash_val, cuda_engine)

         return TRTInterpreterResult(
-            engine_str,
+            cuda_engine,
             self._input_names,
             self._output_names,
             self.weight_name_map,
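
With this change TRTInterpreterResult carries the live trt.ICudaEngine, and serialization is deferred to the caller, so the serialized copy of the weights exists only where it is actually needed instead of alongside the engine for the whole run. A sketch of how a caller can obtain bytes on demand (the helper name is illustrative; the real call site is in _conversion.py below):

import io

def engine_to_bytes(result: "TRTInterpreterResult") -> bytes:
    # engine.serialize() returns a trt.IHostMemory, which supports the
    # buffer protocol, so it can be written straight into a BytesIO.
    host_mem = result.engine.serialize()
    with io.BytesIO() as buf:
        buf.write(host_mem)
        return buf.getvalue()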

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 44 additions & 9 deletions

@@ -1,24 +1,34 @@
 from __future__ import annotations

+import io
 import logging
-from typing import Any, List, Optional, Sequence
+from typing import Any, List, NamedTuple, Optional, Sequence

 import torch
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._features import ENABLED_FEATURES
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
 from torch_tensorrt.dynamo._settings import CompilationSettings
-from torch_tensorrt.dynamo.conversion._TRTInterpreter import (
-    TRTInterpreter,
-    TRTInterpreterResult,
-)
+from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
-from torch_tensorrt.dynamo.utils import get_output_dtypes
+from torch_tensorrt.dynamo.utils import (
+    get_cpu_memory_usage,
+    get_output_dtypes,
+    release_memory,
+)

 logger = logging.getLogger(__name__)


+class SerializedInterpreterResult(NamedTuple):
+    serialized_engine: bytes
+    input_names: Sequence[str]
+    output_names: Sequence[str]
+    weight_name_map: Optional[dict[Any, Any]]
+    requires_output_allocator: bool
+
+
 def infer_module_output_dtypes(
     module: torch.fx.GraphModule,
     truncate_double: bool = False,

@@ -29,7 +39,7 @@ def infer_module_output_dtypes(
     """
     outputs = [node for node in module.graph.nodes if node.op == "output"]
     outputs = outputs[0].args
-    return get_output_dtypes(outputs, truncate_double)  # type: ignore[no-any-return]
+    return get_output_dtypes(outputs, truncate_double)


 def interpret_module_to_result(

@@ -39,7 +49,7 @@
     arg_inputs: Optional[Sequence[Input]] = None,
     kwarg_inputs: Optional[dict[str, Any]] = None,
     engine_cache: Optional[BaseEngineCache] = None,
-) -> TRTInterpreterResult:
+) -> SerializedInterpreterResult:
     """Interpret an FX module to a TRTInterpreterResult
     Args:
         module: FX GraphModule to interpret

@@ -65,7 +75,32 @@
     )

     interpreter_result = interpreter.run()
-    return interpreter_result
+    # Delete the frozen parameters from the module to release CPU memory
+    del interpreter
+    for attr in dir(module):
+        if attr.startswith("_frozen_param"):
+            delattr(module, attr)
+    release_memory()
+    logger.debug(
+        f"CPU memory usage after clearing frozen parameters in conversion: {get_cpu_memory_usage()} MB"
+    )
+
+    serialized_engine = interpreter_result.engine.serialize()
+    with io.BytesIO() as engine_bytes:
+        engine_bytes.write(serialized_engine)
+        serialized_engine = engine_bytes.getvalue()
+    logger.debug(
+        f"CPU memory usage after serializing engine: {get_cpu_memory_usage()} MB"
+    )
+    serialized_interpreter_result = SerializedInterpreterResult(
+        serialized_engine=serialized_engine,
+        input_names=interpreter_result.input_names,
+        output_names=interpreter_result.output_names,
+        weight_name_map=interpreter_result.weight_name_map,
+        requires_output_allocator=interpreter_result.requires_output_allocator,
+    )
+
+    return serialized_interpreter_result


 def convert_module(
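
release_memory() is likewise pulled from torch_tensorrt.dynamo.utils without its body appearing in this diff. A plausible sketch, assuming it forces a garbage-collection pass and returns cached CUDA blocks to the driver:

import gc

import torch

def release_memory() -> None:
    # Sketch only; the shipped helper may do more (e.g. trimming the malloc heap).
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()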

py/torch_tensorrt/dynamo/debug/_Debugger.py

Lines changed: 1 addition & 0 deletions

@@ -220,6 +220,7 @@ def get_logging_config(self, log_level: Optional[int] = None) -> dict[str, Any]:
                 "class": "logging.FileHandler",
                 "filename": f"{self.cfg.logging_dir}/torch_tensorrt_logging.log",
                 "formatter": "standard",
+                "mode": "w",  # This will clear the previous content
             }
             config["loggers"][""]["handlers"].append("file")
         return config
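
The added "mode": "w" is forwarded by logging.config.dictConfig to the logging.FileHandler constructor, so each debugging session starts from a truncated log file instead of appending to the previous one. A self-contained illustration (the path here is hypothetical):

import logging
import logging.config

logging.config.dictConfig(
    {
        "version": 1,
        "formatters": {"standard": {"format": "%(levelname)s %(message)s"}},
        "handlers": {
            "file": {
                "class": "logging.FileHandler",
                "filename": "/tmp/torch_tensorrt_logging.log",  # hypothetical path
                "formatter": "standard",
                "mode": "w",  # truncate at configuration time instead of appending
            }
        },
        "root": {"handlers": ["file"], "level": "DEBUG"},
    }
)
logging.getLogger(__name__).debug("fresh log file per run")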

py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

Lines changed: 3 additions & 1 deletion

@@ -37,7 +37,9 @@
     # For TRT INetwork construction the constants are moved to CPU in get_attr call.
     for node, constant in cf.node_replacements.items():
         replace_node_with_constant(
-            gm, node, torch.nn.Parameter(constant, requires_grad=False)
+            gm,
+            node,
+            torch.nn.Parameter(constant.cpu().contiguous(), requires_grad=False),
         )

     erased_params = []
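
The switch to constant.cpu().contiguous() matters because a folded constant can be a non-contiguous view into a much larger tensor, and keeping the view pins the entire backing storage in CPU memory. A small illustration of the effect:

import torch

base = torch.randn(8192, 8192)   # ~256 MB of backing storage
view = base[::2, ::2]            # non-contiguous view; shares base's storage
assert view.untyped_storage().data_ptr() == base.untyped_storage().data_ptr()

dense = view.cpu().contiguous()  # independent dense copy, ~64 MB
del base, view                   # the large backing storage can now be freed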

py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 68 additions & 0 deletions

@@ -1,6 +1,7 @@
 import logging
 from typing import Collection, Dict, List, Optional, Tuple

+import psutil
 import torch
 import torch.fx.passes.operator_support as ops
 from torch.fx.node import Target

@@ -225,13 +226,80 @@ def partition_graph(self) -> torch.fx.GraphModule:
         # Remove segments smaller than the block size (with exceptions)
         subgraphs = self.remove_small_acc_subgraphs(subgraphs)

+        num_of_break = self.calculate_num_of_break(subgraphs)
+        subgraphs = self.break_subgraphs(subgraphs, num_of_break=num_of_break)
+
         # Set the number of TRT engines to be generated
         self.num_trt_accelerated_subgraphs = len([s for s in subgraphs if s.is_acc])

         # Tag the accelerated nodes and split the graph accordingly
         self.tag(subgraphs)
         return self.split()

+    def calculate_num_of_break(self, subgraphs: List[Subgraph]) -> int:
+        """
+        Calculate how many times each TRT subgraph should be broken, based on current CPU memory pressure.
+        """
+        rss = psutil.Process().memory_info().rss
+        available_rss = psutil.virtual_memory().available
+        num_of_graphs = len(subgraphs)
+        if rss < available_rss * 0.3:
+            num_of_graphs = 1
+        elif rss < available_rss * 0.5:
+            num_of_graphs = 2
+        elif rss < available_rss:
+            num_of_graphs = 4
+        elif rss < available_rss * 1.5:
+            num_of_graphs = 8
+        elif rss < available_rss * 2:
+            num_of_graphs = 16
+        else:
+            num_of_graphs = 32
+
+        return max(
+            1, num_of_graphs // ((len(subgraphs) + 1) // 2)
+        )  # If the graph is already split, spread the breaks across the existing TRT subgraphs.
+
+    def break_subgraphs(
+        self, subgraphs: List[Subgraph], num_of_break: int = 1
+    ) -> List[Subgraph]:
+        """
+        Break the accelerated subgraphs into smaller subgraphs at SDPA node boundaries to save CPU memory.
+        """
+
+        num_of_sdpa_node = len(
+            [node for node in self.acc_nodes if "scaled_dot" in str(node.target)]
+        )
+        break_period = num_of_sdpa_node // num_of_break + 1
+        current_num_break = 0
+        new_subgraphs = []
+        for subgraph in subgraphs:
+            if subgraph.is_acc:
+                current_break_idx = 0  # Reset per subgraph so slices start at its first node
+                for i, node in enumerate(subgraph.nodes):
+                    if "scaled_dot" in str(node.target):
+                        current_num_break += 1
+                        if current_num_break % break_period != 0:
+                            continue
+                        new_subgraphs.append(
+                            Subgraph(
+                                is_acc=True,
+                                nodes=subgraph.nodes[current_break_idx : i + 1],
+                                device_ordinal=subgraph.device_ordinal,
+                            )
+                        )
+                        current_break_idx = i + 1
+                new_subgraphs.append(
+                    Subgraph(
+                        is_acc=True,
+                        nodes=subgraph.nodes[current_break_idx:],
+                        device_ordinal=subgraph.device_ordinal,
+                    )
+                )
+            else:
+                new_subgraphs.append(subgraph)
+        return new_subgraphs
+
     def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
         """Generates starter nodes for partitioning + segmentation"""
         # Starter accelerated nodes are all callable accelerated ops
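
To make the tiering concrete, here is a standalone rehearsal of the heuristic with hypothetical numbers (breaks_per_trt_subgraph mirrors calculate_num_of_break for illustration; it is not the shipped code):

def breaks_per_trt_subgraph(rss: int, available: int, num_subgraphs: int) -> int:
    # Mirror of the memory-pressure tiers above, for illustration only.
    for ratio, breaks in [(0.3, 1), (0.5, 2), (1.0, 4), (1.5, 8), (2.0, 16)]:
        if rss < available * ratio:
            n = breaks
            break
    else:
        n = 32
    return max(1, n // ((num_subgraphs + 1) // 2))

# With 8 GiB resident against 20 GiB available and a single accelerated
# subgraph: 8 GiB < 0.5 * 20 GiB, so each TRT subgraph is broken twice.
assert breaks_per_trt_subgraph(8 << 30, 20 << 30, 1) == 2

Breaking at scaled_dot_product_attention boundaries keeps each engine build small enough that the TensorRT builder's CPU-side working set stays within the available headroom.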
