Commit 8d9c413

Added torch_tensorrt.dynamo.Debugger; cleaned up settings.debug
1 parent 031267c commit 8d9c413

File tree: 13 files changed (+225, -57 lines)

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -14,6 +14,7 @@
     load_cross_compiled_exported_program,
     save_cross_compiled_exported_program,
 )
+from ._debugger import Debugger
 from ._exporter import export
 from ._refit import refit_module_weights
 from ._settings import CompilationSettings
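
With this re-export, the new context manager is importable directly from the public dynamo namespace:

    from torch_tensorrt.dynamo import Debugger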

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 16 additions & 11 deletions

@@ -422,7 +422,6 @@ def compile(
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
     tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
-    engine_vis_dir: Optional[str] = _defaults.ENGINE_VIS_DIR,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -506,7 +505,13 @@ def compile(
     """
 
     if debug:
-        set_log_level(logger.parent, logging.DEBUG)
+        warnings.warn(
+            "The 'debug' argument is deprecated and will be removed in a future release. "
+            "Please use the torch_tensorrt.dynamo.Debugger context manager for debugging and graph capture.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:
             raise ValueError(
@@ -637,7 +642,6 @@ def compile(
         "enabled_precisions": (
             enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
         ),
-        "debug": debug,
         "device": device,
         "assume_dynamic_shape_support": assume_dynamic_shape_support,
         "workspace_size": workspace_size,
@@ -676,7 +680,6 @@ def compile(
         "enable_weight_streaming": enable_weight_streaming,
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
-        "engine_vis_dir": engine_vis_dir,
     }
 
     settings = CompilationSettings(**compilation_options)
@@ -728,7 +731,7 @@ def compile_module(
 
     # Check the number of supported operations in the graph
     num_supported_ops, total_ops = partitioning.get_graph_converter_support(
-        gm, settings.debug, settings.torch_executed_ops
+        gm, settings.torch_executed_ops
     )
 
     dryrun_tracker.total_ops_in_graph = total_ops
@@ -780,7 +783,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
             logger.info("Partitioning the graph via the fast partitioner")
             partitioned_module, supported_ops = partitioning.fast_partition(
                 gm,
-                verbose=settings.debug,
                 min_block_size=settings.min_block_size,
                 torch_executed_ops=settings.torch_executed_ops,
                 require_full_compilation=settings.require_full_compilation,
@@ -801,7 +803,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
         logger.info("Partitioning the graph via the global partitioner")
         partitioned_module, supported_ops = partitioning.global_partition(
             gm,
-            verbose=settings.debug,
             min_block_size=settings.min_block_size,
             torch_executed_ops=settings.torch_executed_ops,
             require_full_compilation=settings.require_full_compilation,
@@ -906,17 +907,21 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
         )
 
         trt_modules[name] = trt_module
+        from torch_tensorrt.dynamo._debugger import (
+            DEBUG_FILE_DIR,
+            SAVE_ENGINE_PROFILE,
+        )
 
-        if settings.debug and settings.engine_vis_dir:
+        if SAVE_ENGINE_PROFILE:
             if settings.use_python_runtime:
                 logger.warning(
                     "Profiling can only be enabled when using the C++ runtime"
                 )
             else:
-                if not os.path.exists(settings.engine_vis_dir):
-                    os.makedirs(settings.engine_vis_dir)
+                path = os.path.join(DEBUG_FILE_DIR, "engine_visualization")
+                os.makedirs(path, exist_ok=True)
                 trt_module.enable_profiling(
-                    profiling_results_dir=settings.engine_vis_dir,
+                    profiling_results_dir=path,
                     profile_format="trex",
                 )
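
For existing callers, this changes debug=True from a log-level switch into a pure deprecation notice. A minimal sketch of the new behavior (the toy model, inputs, and exported program are hypothetical placeholders, not part of the commit):

    import warnings

    import torch
    import torch_tensorrt as torchtrt

    # Hypothetical toy model for illustration only.
    model = torch.nn.Linear(8, 4).eval().cuda()
    inputs = [torch.randn(2, 8).cuda()]
    exp_program = torch.export.export(model, tuple(inputs))

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # debug=True no longer changes log levels; it only warns and points
        # users at the torch_tensorrt.dynamo.Debugger context manager.
        trt_gm = torchtrt.dynamo.compile(exp_program, inputs=inputs, debug=True)

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)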

py/torch_tensorrt/dynamo/_debugger.py

Lines changed: 174 additions & 0 deletions

@@ -0,0 +1,174 @@
+import logging
+import os
+import tempfile
+from logging.config import dictConfig
+from typing import Any, List, Optional
+
+import torch
+from torch_tensorrt.dynamo.lowering import (
+    ATEN_POST_LOWERING_PASSES,
+    ATEN_PRE_LOWERING_PASSES,
+)
+
+_LOGGER = logging.getLogger("torch_tensorrt [TensorRT Conversion Context]")
+GRAPH_LEVEL = 5
+DEBUG_FILE_DIR = tempfile.TemporaryDirectory().name
+SAVE_ENGINE_PROFILE = False
+
+
+class Debugger:
+    def __init__(
+        self,
+        level: str,
+        capture_fx_graph_before: Optional[List[str]] = None,
+        capture_fx_graph_after: Optional[List[str]] = None,
+        save_engine_profile: bool = False,
+        logging_dir: Optional[str] = None,
+    ):
+        logging.addLevelName(GRAPH_LEVEL, "GRAPHS")
+        if level != "graphs" and (capture_fx_graph_after or save_engine_profile):
+            _LOGGER.warning(
+                "Capture FX Graph or Draw Engine Graph is only supported when level is 'graphs'"
+            )
+
+        if level == "debug":
+            self.level = logging.DEBUG
+        elif level == "info":
+            self.level = logging.INFO
+        elif level == "warning":
+            self.level = logging.WARNING
+        elif level == "error":
+            self.level = logging.ERROR
+        elif level == "internal_errors":
+            self.level = logging.CRITICAL
+        elif level == "graphs":
+            self.level = GRAPH_LEVEL
+
+        else:
+            raise ValueError(
+                f"Invalid level: {level}, allowed levels are: debug, info, warning, error, internal_errors, graphs"
+            )
+
+        self.capture_fx_graph_before = capture_fx_graph_before
+        self.capture_fx_graph_after = capture_fx_graph_after
+        global SAVE_ENGINE_PROFILE
+        SAVE_ENGINE_PROFILE = save_engine_profile
+
+        if logging_dir is not None:
+            global DEBUG_FILE_DIR
+            DEBUG_FILE_DIR = logging_dir
+        os.makedirs(DEBUG_FILE_DIR, exist_ok=True)
+
+    def __enter__(self) -> None:
+        self.original_lvl = _LOGGER.getEffectiveLevel()
+        self.rt_level = torch.ops.tensorrt.get_logging_level()
+        dictConfig(self.get_config())
+
+        if self.level == GRAPH_LEVEL:
+            self.old_pre_passes, self.old_post_passes = (
+                ATEN_PRE_LOWERING_PASSES.passes,
+                ATEN_POST_LOWERING_PASSES.passes,
+            )
+            pre_pass_names = [p.__name__ for p in self.old_pre_passes]
+            post_pass_names = [p.__name__ for p in self.old_post_passes]
+            path = os.path.join(DEBUG_FILE_DIR, "lowering_passes_visualization")
+            if self.capture_fx_graph_before is not None:
+                pre_vis_passes = [
+                    p for p in self.capture_fx_graph_before if p in pre_pass_names
+                ]
+                post_vis_passes = [
+                    p for p in self.capture_fx_graph_before if p in post_pass_names
+                ]
+                ATEN_PRE_LOWERING_PASSES.insert_debug_pass_before(pre_vis_passes, path)
+                ATEN_POST_LOWERING_PASSES.insert_debug_pass_before(
+                    post_vis_passes, path
+                )
+            if self.capture_fx_graph_after is not None:
+                pre_vis_passes = [
+                    p for p in self.capture_fx_graph_after if p in pre_pass_names
+                ]
+                post_vis_passes = [
+                    p for p in self.capture_fx_graph_after if p in post_pass_names
+                ]
+                ATEN_PRE_LOWERING_PASSES.insert_debug_pass_after(pre_vis_passes, path)
+                ATEN_POST_LOWERING_PASSES.insert_debug_pass_after(post_vis_passes, path)
+
+    def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
+
+        dictConfig(self.get_default_config())
+        torch.ops.tensorrt.set_logging_level(self.rt_level)
+        if self.level == GRAPH_LEVEL and self.capture_fx_graph_after:
+            ATEN_PRE_LOWERING_PASSES.passes, ATEN_POST_LOWERING_PASSES.passes = (
+                self.old_pre_passes,
+                self.old_post_passes,
+            )
+
+    def get_config(self) -> dict[str, Any]:
+        config = {
+            "version": 1,
+            "disable_existing_loggers": False,
+            "formatters": {
+                "brief": {
+                    "format": "%(asctime)s - %(levelname)s - %(message)s",
+                    "datefmt": "%H:%M:%S",
+                },
+                "standard": {
+                    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+                    "datefmt": "%Y-%m-%d %H:%M:%S",
+                },
+            },
+            "handlers": {
+                "file": {
+                    "level": self.level,
+                    "class": "logging.FileHandler",
+                    "filename": f"{DEBUG_FILE_DIR}/torch_tensorrt_logging.log",
+                    "formatter": "standard",
+                },
+                "console": {
+                    "level": self.level,
+                    "class": "logging.StreamHandler",
+                    "formatter": "brief",
+                },
+            },
+            "loggers": {
+                "": {  # root logger
+                    "handlers": ["file", "console"],
+                    "level": self.level,
+                    "propagate": True,
+                },
+            },
+            "force": True,
+        }
+        return config
+
+    def get_default_config(self) -> dict[str, Any]:
+        config = {
+            "version": 1,
+            "disable_existing_loggers": False,
+            "formatters": {
+                "brief": {
+                    "format": "%(asctime)s - %(levelname)s - %(message)s",
+                    "datefmt": "%H:%M:%S",
+                },
+                "standard": {
+                    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+                    "datefmt": "%Y-%m-%d %H:%M:%S",
+                },
+            },
+            "handlers": {
+                "console": {
+                    "level": self.original_lvl,
+                    "class": "logging.StreamHandler",
+                    "formatter": "brief",
+                },
+            },
+            "loggers": {
+                "": {  # root logger
+                    "handlers": ["console"],
+                    "level": self.original_lvl,
+                    "propagate": True,
+                },
+            },
+            "force": True,
+        }
+        return config
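
Putting the new module to use: a minimal sketch of the intended workflow, assuming a GPU build of Torch-TensorRT (the toy model and the chosen pass name are illustrative, not mandated by the commit):

    import torch
    import torch_tensorrt as torchtrt

    model = torch.nn.Linear(8, 4).eval().cuda()
    inputs = [torch.randn(2, 8).cuda()]
    exp_program = torch.export.export(model, tuple(inputs))

    # level="graphs" (a custom level below DEBUG) unlocks FX-graph capture and
    # engine profiling; the other levels just reconfigure logging for the block.
    with torchtrt.dynamo.Debugger(
        level="graphs",
        capture_fx_graph_before=["remove_detach"],  # illustrative pass name
        save_engine_profile=True,  # requires the C++ runtime, per _compiler.py
        logging_dir="/tmp/torchtrt_debug",
    ):
        trt_gm = torchtrt.dynamo.compile(exp_program, inputs=inputs)

    # Logs land in /tmp/torchtrt_debug/torch_tensorrt_logging.log; the TREX-format
    # engine profile is written under /tmp/torchtrt_debug/engine_visualization.

If no logging_dir is given, output goes to a fresh temporary directory (the DEBUG_FILE_DIR default above).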

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 0 additions & 1 deletion

@@ -15,7 +15,6 @@
 DLA_SRAM_SIZE = 1048576
 ENGINE_CAPABILITY = EngineCapability.STANDARD
 WORKSPACE_SIZE = 0
-ENGINE_VIS_DIR = None
 MIN_BLOCK_SIZE = 5
 PASS_THROUGH_BUILD_FAILURES = False
 MAX_AUX_STREAMS = None

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 0 additions & 7 deletions

@@ -39,7 +39,6 @@
     check_module_output,
     get_model_device,
     get_torch_inputs,
-    set_log_level,
     to_torch_device,
     to_torch_tensorrt_device,
 )
@@ -72,7 +71,6 @@ def construct_refit_mapping(
     interpreter = TRTInterpreter(
         module,
         inputs,
-        logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
         output_dtypes=output_dtypes,
         compilation_settings=settings,
     )
@@ -266,9 +264,6 @@ def refit_module_weights(
         not settings.immutable_weights
     ), "Refitting is not enabled. Please recompile the engine with immutable_weights=False."
 
-    if settings.debug:
-        set_log_level(logger.parent, logging.DEBUG)
-
     device = to_torch_tensorrt_device(settings.device)
     if arg_inputs:
         if not isinstance(arg_inputs, collections.abc.Sequence):
@@ -304,7 +299,6 @@ def refit_module_weights(
     try:
         new_partitioned_module, supported_ops = partitioning.fast_partition(
             new_gm,
-            verbose=settings.debug,
             min_block_size=settings.min_block_size,
             torch_executed_ops=settings.torch_executed_ops,
         )
@@ -320,7 +314,6 @@ def refit_module_weights(
     if not settings.use_fast_partitioner:
         new_partitioned_module, supported_ops = partitioning.global_partition(
             new_gm,
-            verbose=settings.debug,
             min_block_size=settings.min_block_size,
             torch_executed_ops=settings.torch_executed_ops,
         )

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 1 addition & 4 deletions

@@ -7,7 +7,6 @@
 from torch_tensorrt.dynamo._defaults import (
     ASSUME_DYNAMIC_SHAPE_SUPPORT,
     CACHE_BUILT_ENGINES,
-    DEBUG,
     DISABLE_TF32,
     DLA_GLOBAL_DRAM_SIZE,
     DLA_LOCAL_DRAM_SIZE,
@@ -18,7 +17,6 @@
     ENABLE_WEIGHT_STREAMING,
     ENABLED_PRECISIONS,
     ENGINE_CAPABILITY,
-    ENGINE_VIS_DIR,
     HARDWARE_COMPATIBLE,
     IMMUTABLE_WEIGHTS,
     L2_LIMIT_FOR_TILING,
@@ -101,7 +99,7 @@ class CompilationSettings:
     """
 
     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
-    debug: bool = DEBUG
+    # debug: bool = True,
     workspace_size: int = WORKSPACE_SIZE
     min_block_size: int = MIN_BLOCK_SIZE
     torch_executed_ops: Collection[Target] = field(default_factory=set)
@@ -141,7 +139,6 @@ class CompilationSettings:
     tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL
     l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
     use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
-    engine_vis_dir: Optional[str] = ENGINE_VIS_DIR
 
 
 _SETTINGS_TO_BE_ENGINE_INVARIANT = (
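
Because CompilationSettings is a dataclass, dropping the field also changes the constructor contract; a small sketch of the migration, under the assumption that no compatibility shim is added elsewhere:

    from torch_tensorrt.dynamo import CompilationSettings

    settings = CompilationSettings(min_block_size=3)  # still fine

    # The removed field now makes this raise at construction time:
    #   TypeError: __init__() got an unexpected keyword argument 'debug'
    # settings = CompilationSettings(debug=True)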

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 0 additions & 1 deletion

@@ -75,7 +75,6 @@ def __init__(
         self,
         module: torch.fx.GraphModule,
         input_specs: Sequence[Input],
-        logger_level: trt.ILogger.Severity = trt.ILogger.Severity.WARNING,
         output_dtypes: Optional[Sequence[dtype]] = None,
         compilation_settings: CompilationSettings = CompilationSettings(),
         engine_cache: Optional[BaseEngineCache] = None,

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 0 additions & 2 deletions

@@ -3,7 +3,6 @@
 import logging
 from typing import Any, List, Optional, Sequence
 
-import tensorrt as trt
 import torch
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._features import ENABLED_FEATURES
@@ -60,7 +59,6 @@ def interpret_module_to_result(
     interpreter = TRTInterpreter(
         module,
         inputs,
-        logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
         output_dtypes=output_dtypes,
         compilation_settings=settings,
         engine_cache=engine_cache,
