Commit 74bb32d

Added torchtrt.dynamo.debugger. Cleaning settings.debug

1 parent 031267c commit 74bb32d

File tree

13 files changed: +228 −57 lines changed

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@
     load_cross_compiled_exported_program,
     save_cross_compiled_exported_program,
 )
+from ._debugger import Debugger
 from ._exporter import export
 from ._refit import refit_module_weights
 from ._settings import CompilationSettings
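
Note: this export makes the new class importable from the dynamo namespace; a minimal check of the new path:

from torch_tensorrt.dynamo import Debugger  # re-exported from _debugger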

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 16 additions & 11 deletions
@@ -422,7 +422,6 @@ def compile(
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
     tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
-    engine_vis_dir: Optional[str] = _defaults.ENGINE_VIS_DIR,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -506,7 +505,13 @@ def compile(
     """

     if debug:
-        set_log_level(logger.parent, logging.DEBUG)
+        warnings.warn(
+            "The 'debug' argument is deprecated and will be removed in a future release. "
+            "Please use the torch_tensorrt.dynamo.Debugger context manager for debugging and graph capture.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:
             raise ValueError(
@@ -637,7 +642,6 @@ def compile(
         "enabled_precisions": (
             enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
         ),
-        "debug": debug,
         "device": device,
         "assume_dynamic_shape_support": assume_dynamic_shape_support,
         "workspace_size": workspace_size,
@@ -676,7 +680,6 @@ def compile(
         "enable_weight_streaming": enable_weight_streaming,
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
-        "engine_vis_dir": engine_vis_dir,
     }

     settings = CompilationSettings(**compilation_options)
@@ -728,7 +731,7 @@ def compile_module(

     # Check the number of supported operations in the graph
     num_supported_ops, total_ops = partitioning.get_graph_converter_support(
-        gm, settings.debug, settings.torch_executed_ops
+        gm, settings.torch_executed_ops
     )

     dryrun_tracker.total_ops_in_graph = total_ops
@@ -780,7 +783,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
             logger.info("Partitioning the graph via the fast partitioner")
             partitioned_module, supported_ops = partitioning.fast_partition(
                 gm,
-                verbose=settings.debug,
                 min_block_size=settings.min_block_size,
                 torch_executed_ops=settings.torch_executed_ops,
                 require_full_compilation=settings.require_full_compilation,
@@ -801,7 +803,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
         logger.info("Partitioning the graph via the global partitioner")
         partitioned_module, supported_ops = partitioning.global_partition(
             gm,
-            verbose=settings.debug,
             min_block_size=settings.min_block_size,
             torch_executed_ops=settings.torch_executed_ops,
             require_full_compilation=settings.require_full_compilation,
@@ -906,17 +907,21 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
         )

         trt_modules[name] = trt_module
+        from torch_tensorrt.dynamo._debugger import (
+            DEBUG_FILE_DIR,
+            SAVE_ENGINE_PROFILE,
+        )

-        if settings.debug and settings.engine_vis_dir:
+        if SAVE_ENGINE_PROFILE:
             if settings.use_python_runtime:
                 logger.warning(
                     "Profiling can only be enabled when using the C++ runtime"
                 )
             else:
-                if not os.path.exists(settings.engine_vis_dir):
-                    os.makedirs(settings.engine_vis_dir)
+                path = os.path.join(DEBUG_FILE_DIR, "engine_visualization")
+                os.makedirs(path, exist_ok=True)
                 trt_module.enable_profiling(
-                    profiling_results_dir=settings.engine_vis_dir,
+                    profiling_results_dir=path,
                     profile_format="trex",
                 )
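
Note: after this change, passing debug=True still compiles but only raises a DeprecationWarning; the replacement is the context manager added below. A minimal migration sketch, where exp_program and inputs are placeholders for an already-exported program and its example inputs:

import torch_tensorrt

# Before (now deprecated):
#   trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs=inputs, debug=True)

# After: wrap compilation in the Debugger context manager
with torch_tensorrt.dynamo.Debugger(level="debug"):
    trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs=inputs)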

py/torch_tensorrt/dynamo/_debugger.py

Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
+import logging
+import os
+import tempfile
+from logging.config import dictConfig
+from typing import Any, List, Optional
+
+import torch
+from torch_tensorrt.dynamo.lowering import (
+    ATEN_POST_LOWERING_PASSES,
+    ATEN_PRE_LOWERING_PASSES,
+)
+
+_LOGGER = logging.getLogger("torch_tensorrt [TensorRT Conversion Context]")
+GRAPH_LEVEL = 5
+logging.addLevelName(GRAPH_LEVEL, "GRAPHS")
+
+# Debugger States
+DEBUG_FILE_DIR = tempfile.TemporaryDirectory().name
+SAVE_ENGINE_PROFILE = False
+
+
+class Debugger:
+    def __init__(
+        self,
+        level: str,
+        capture_fx_graph_before: Optional[List[str]] = None,
+        capture_fx_graph_after: Optional[List[str]] = None,
+        save_engine_profile: bool = False,
+        logging_dir: Optional[str] = None,
+    ):
+
+        if level != "graphs" and (capture_fx_graph_after or save_engine_profile):
+            _LOGGER.warning(
+                "Capture FX Graph or Draw Engine Graph is only supported when level is 'graphs'"
+            )
+
+        if level == "debug":
+            self.level = logging.DEBUG
+        elif level == "info":
+            self.level = logging.INFO
+        elif level == "warning":
+            self.level = logging.WARNING
+        elif level == "error":
+            self.level = logging.ERROR
+        elif level == "internal_errors":
+            self.level = logging.CRITICAL
+        elif level == "graphs":
+            self.level = GRAPH_LEVEL
+
+        else:
+            raise ValueError(
+                f"Invalid level: {level}, allowed levels are: debug, info, warning, error, internal_errors, graphs"
+            )
+
+        self.capture_fx_graph_before = capture_fx_graph_before
+        self.capture_fx_graph_after = capture_fx_graph_after
+        global SAVE_ENGINE_PROFILE
+        SAVE_ENGINE_PROFILE = save_engine_profile
+
+        if logging_dir is not None:
+            global DEBUG_FILE_DIR
+            DEBUG_FILE_DIR = logging_dir
+        os.makedirs(DEBUG_FILE_DIR, exist_ok=True)
+
+    def __enter__(self) -> None:
+        self.original_lvl = _LOGGER.getEffectiveLevel()
+        self.rt_level = torch.ops.tensorrt.get_logging_level()
+        dictConfig(self.get_config())
+
+        if self.level == GRAPH_LEVEL:
+            self.old_pre_passes, self.old_post_passes = (
+                ATEN_PRE_LOWERING_PASSES.passes,
+                ATEN_POST_LOWERING_PASSES.passes,
+            )
+            pre_pass_names = [p.__name__ for p in self.old_pre_passes]
+            post_pass_names = [p.__name__ for p in self.old_post_passes]
+            path = os.path.join(DEBUG_FILE_DIR, "lowering_passes_visualization")
+            if self.capture_fx_graph_before is not None:
+                pre_vis_passes = [
+                    p for p in self.capture_fx_graph_before if p in pre_pass_names
+                ]
+                post_vis_passes = [
+                    p for p in self.capture_fx_graph_before if p in post_pass_names
+                ]
+                ATEN_PRE_LOWERING_PASSES.insert_debug_pass_before(pre_vis_passes, path)
+                ATEN_POST_LOWERING_PASSES.insert_debug_pass_before(
+                    post_vis_passes, path
+                )
+            if self.capture_fx_graph_after is not None:
+                pre_vis_passes = [
+                    p for p in self.capture_fx_graph_after if p in pre_pass_names
+                ]
+                post_vis_passes = [
+                    p for p in self.capture_fx_graph_after if p in post_pass_names
+                ]
+                ATEN_PRE_LOWERING_PASSES.insert_debug_pass_after(pre_vis_passes, path)
+                ATEN_POST_LOWERING_PASSES.insert_debug_pass_after(post_vis_passes, path)
+
+    def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
+
+        dictConfig(self.get_default_config())
+        torch.ops.tensorrt.set_logging_level(self.rt_level)
+        if self.level == GRAPH_LEVEL and self.capture_fx_graph_after:
+            ATEN_PRE_LOWERING_PASSES.passes, ATEN_POST_LOWERING_PASSES.passes = (
+                self.old_pre_passes,
+                self.old_post_passes,
+            )
+
+    def get_config(self) -> dict[str, Any]:
+        config = {
+            "version": 1,
+            "disable_existing_loggers": False,
+            "formatters": {
+                "brief": {
+                    "format": "%(asctime)s - %(levelname)s - %(message)s",
+                    "datefmt": "%H:%M:%S",
+                },
+                "standard": {
+                    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+                    "datefmt": "%Y-%m-%d %H:%M:%S",
+                },
+            },
+            "handlers": {
+                "file": {
+                    "level": self.level,
+                    "class": "logging.FileHandler",
+                    "filename": f"{DEBUG_FILE_DIR}/torch_tensorrt_logging.log",
+                    "formatter": "standard",
+                },
+                "console": {
+                    "level": self.level,
+                    "class": "logging.StreamHandler",
+                    "formatter": "brief",
+                },
+            },
+            "loggers": {
+                "": {  # root logger
+                    "handlers": ["file", "console"],
+                    "level": self.level,
+                    "propagate": True,
+                },
+            },
+            "force": True,
+        }
+        return config
+
+    def get_default_config(self) -> dict[str, Any]:
+        config = {
+            "version": 1,
+            "disable_existing_loggers": False,
+            "formatters": {
+                "brief": {
+                    "format": "%(asctime)s - %(levelname)s - %(message)s",
+                    "datefmt": "%H:%M:%S",
+                },
+                "standard": {
+                    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+                    "datefmt": "%Y-%m-%d %H:%M:%S",
+                },
+            },
+            "handlers": {
+                "console": {
+                    "level": self.original_lvl,
+                    "class": "logging.StreamHandler",
+                    "formatter": "brief",
+                },
+            },
+            "loggers": {
+                "": {  # root logger
+                    "handlers": ["console"],
+                    "level": self.original_lvl,
+                    "propagate": True,
+                },
+            },
+            "force": True,
+        }
+        return config
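
Note: putting the pieces together, a graph-capture plus engine-profiling session might look like the sketch below. The pass name "constant_fold" is illustrative (the capture lists are matched against the __name__ of registered ATen lowering passes, as in __enter__ above), and model/inputs are placeholders. Per compile_module above, the trex profile is only written when using the C++ runtime.

import torch
import torch_tensorrt

ep = torch.export.export(model, tuple(inputs))  # model/inputs assumed defined

with torch_tensorrt.dynamo.Debugger(
    level="graphs",                              # required for capture/profiling
    capture_fx_graph_before=["constant_fold"],   # illustrative pass name
    save_engine_profile=True,                    # flips SAVE_ENGINE_PROFILE
    logging_dir="/tmp/torchtrt_debug",           # overrides the temp-dir default
):
    trt_gm = torch_tensorrt.dynamo.compile(ep, inputs=inputs)

# Logs land in /tmp/torchtrt_debug/torch_tensorrt_logging.log; captured graphs
# go under lowering_passes_visualization/, engine profiles under engine_visualization/.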

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 0 additions & 1 deletion
@@ -15,7 +15,6 @@
 DLA_SRAM_SIZE = 1048576
 ENGINE_CAPABILITY = EngineCapability.STANDARD
 WORKSPACE_SIZE = 0
-ENGINE_VIS_DIR = None
 MIN_BLOCK_SIZE = 5
 PASS_THROUGH_BUILD_FAILURES = False
 MAX_AUX_STREAMS = None

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 0 additions & 7 deletions
@@ -39,7 +39,6 @@
     check_module_output,
     get_model_device,
     get_torch_inputs,
-    set_log_level,
     to_torch_device,
     to_torch_tensorrt_device,
 )
@@ -72,7 +71,6 @@ def construct_refit_mapping(
     interpreter = TRTInterpreter(
         module,
         inputs,
-        logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
         output_dtypes=output_dtypes,
         compilation_settings=settings,
     )
@@ -266,9 +264,6 @@ def refit_module_weights(
         not settings.immutable_weights
     ), "Refitting is not enabled. Please recompile the engine with immutable_weights=False."

-    if settings.debug:
-        set_log_level(logger.parent, logging.DEBUG)
-
     device = to_torch_tensorrt_device(settings.device)
     if arg_inputs:
         if not isinstance(arg_inputs, collections.abc.Sequence):
@@ -304,7 +299,6 @@ def refit_module_weights(
         try:
             new_partitioned_module, supported_ops = partitioning.fast_partition(
                 new_gm,
-                verbose=settings.debug,
                 min_block_size=settings.min_block_size,
                 torch_executed_ops=settings.torch_executed_ops,
             )
@@ -320,7 +314,6 @@ def refit_module_weights(
     if not settings.use_fast_partitioner:
         new_partitioned_module, supported_ops = partitioning.global_partition(
             new_gm,
-            verbose=settings.debug,
             min_block_size=settings.min_block_size,
             torch_executed_ops=settings.torch_executed_ops,
         )
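
Note: since refit no longer honors settings.debug, verbose refit logging would presumably also go through the new context manager; a sketch with placeholder names (compiled_module is a previously compiled result, new_exp_program an exported program with updated weights):

import torch_tensorrt
from torch_tensorrt.dynamo import refit_module_weights

with torch_tensorrt.dynamo.Debugger(level="debug"):
    refitted = refit_module_weights(compiled_module, new_exp_program)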

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 2 additions & 4 deletions
@@ -1,3 +1,4 @@
+import logging
 from dataclasses import dataclass, field
 from typing import Collection, Optional, Set, Tuple, Union

@@ -7,7 +8,6 @@
 from torch_tensorrt.dynamo._defaults import (
     ASSUME_DYNAMIC_SHAPE_SUPPORT,
     CACHE_BUILT_ENGINES,
-    DEBUG,
     DISABLE_TF32,
     DLA_GLOBAL_DRAM_SIZE,
     DLA_LOCAL_DRAM_SIZE,
@@ -18,7 +18,6 @@
     ENABLE_WEIGHT_STREAMING,
     ENABLED_PRECISIONS,
     ENGINE_CAPABILITY,
-    ENGINE_VIS_DIR,
     HARDWARE_COMPATIBLE,
     IMMUTABLE_WEIGHTS,
     L2_LIMIT_FOR_TILING,
@@ -101,7 +100,7 @@ class CompilationSettings:
     """

     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
-    debug: bool = DEBUG
+    debug: bool = logging.root.manager.root.level <= logging.DEBUG
     workspace_size: int = WORKSPACE_SIZE
     min_block_size: int = MIN_BLOCK_SIZE
     torch_executed_ops: Collection[Target] = field(default_factory=set)
@@ -141,7 +140,6 @@ class CompilationSettings:
     tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL
     l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
     use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
-    engine_vis_dir: Optional[str] = ENGINE_VIS_DIR

 _SETTINGS_TO_BE_ENGINE_INVARIANT = (
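
Note: the new default ties CompilationSettings.debug to the root logger instead of a DEBUG constant. logging.root.manager.root resolves to the root logger itself, so the expression reduces to the check sketched below; being a plain dataclass default, it is evaluated once, when the class body runs at import time.

import logging

# Equivalent to the new default expression:
debug_default = logging.root.level <= logging.DEBUG

# So configuring verbose logging before torch_tensorrt is imported, e.g.
#   logging.basicConfig(level=logging.DEBUG)
# would make CompilationSettings().debug default to True.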

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 0 additions & 1 deletion
@@ -75,7 +75,6 @@ def __init__(
         self,
         module: torch.fx.GraphModule,
         input_specs: Sequence[Input],
-        logger_level: trt.ILogger.Severity = trt.ILogger.Severity.WARNING,
         output_dtypes: Optional[Sequence[dtype]] = None,
         compilation_settings: CompilationSettings = CompilationSettings(),
         engine_cache: Optional[BaseEngineCache] = None,

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 0 additions & 2 deletions
@@ -3,7 +3,6 @@
 import logging
 from typing import Any, List, Optional, Sequence

-import tensorrt as trt
 import torch
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._features import ENABLED_FEATURES
@@ -60,7 +59,6 @@ def interpret_module_to_result(
     interpreter = TRTInterpreter(
         module,
         inputs,
-        logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
         output_dtypes=output_dtypes,
         compilation_settings=settings,
         engine_cache=engine_cache,
