Skip to content

Commit f6a3f86

Browse files
committed
Added pass name check
1 parent 1d34057 commit f6a3f86

File tree

1 file changed

+26
-8
lines changed

1 file changed

+26
-8
lines changed

py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
import os
12
import tempfile
2-
from types import new_class
3-
from typing import Any, Callable, List, Optional, Union
3+
from typing import Any, Callable, List, Optional
44

55
import torch
66
from torch.fx import passes
@@ -14,6 +14,8 @@ def _generate_draw_fx_graph_pass(
1414
def draw_fx_graph_pass(
1515
gm: torch.fx.GraphModule, settings: CompilationSettings
1616
) -> torch.fx.GraphModule:
17+
if not os.path.exists(f"{output_path_prefix}/"):
18+
os.makedirs(f"{output_path_prefix}/")
1719
path = f"{output_path_prefix}/{name}.svg"
1820
g = passes.graph_drawer.FxGraphDrawer(gm, name)
1921
with open(path, "wb") as f:
@@ -33,7 +35,7 @@ def __init__(
3335
]
3436
]
3537
] = None,
36-
constraints: Optional[List[Callable]] = None
38+
constraints: Optional[List[Callable]] = None,
3739
):
3840
super().__init__(passes, constraints)
3941

@@ -68,7 +70,7 @@ def remove_pass_with_index(self, index: int) -> None:
6870
del self.passes[index]
6971

7072
def insert_debug_pass_before(
71-
self, passes: List[str], output_path_prefix: str=tempfile.gettempdir()
73+
self, passes: List[str], output_path_prefix: str = tempfile.gettempdir()
7274
) -> None:
7375
"""Insert debug passes in the PassManager pass sequence prior to the execution of a particular pass.
7476
@@ -79,17 +81,22 @@ def insert_debug_pass_before(
7981
Debug passes generate SVG visualizations of the FX graph at specified points
8082
in the pass sequence.
8183
"""
84+
self.check_pass_names_valid(passes)
8285
new_pass_list = []
8386
for ps in self.passes:
8487
if ps.__name__ in passes:
85-
new_pass_list.append(_generate_draw_fx_graph_pass(output_path_prefix, f"before_{ps.__name__}"))
88+
new_pass_list.append(
89+
_generate_draw_fx_graph_pass(
90+
output_path_prefix, f"before_{ps.__name__}"
91+
)
92+
)
8693
new_pass_list.append(ps)
8794

8895
self.passes = new_pass_list
8996
self._validated = False
9097

9198
def insert_debug_pass_after(
92-
self, passes: List[str], output_path_prefix: str=tempfile.gettempdir()
99+
self, passes: List[str], output_path_prefix: str = tempfile.gettempdir()
93100
) -> None:
94101
"""Insert debug passes in the PassManager pass sequence after the execution of a particular pass.
95102
@@ -100,16 +107,27 @@ def insert_debug_pass_after(
100107
Debug passes generate SVG visualizations of the FX graph at specified points
101108
in the pass sequence.
102109
"""
110+
self.check_pass_names_valid(passes)
103111
new_pass_list = []
104112
for ps in self.passes:
105113
new_pass_list.append(ps)
106114
if ps.__name__ in passes:
107-
new_pass_list.append(_generate_draw_fx_graph_pass(output_path_prefix, f"after_{ps.__name__}"))
108-
115+
new_pass_list.append(
116+
_generate_draw_fx_graph_pass(
117+
output_path_prefix, f"after_{ps.__name__}"
118+
)
119+
)
109120

110121
self.passes = new_pass_list
111122
self._validated = False
112123

124+
def check_pass_names_valid(self, debug_pass_names: List[str]) -> None:
125+
pass_names_str = [p.__name__ for p in self.passes]
126+
for name in debug_pass_names:
127+
assert (
128+
name in pass_names_str
129+
), f"{name} is not a valid pass! Passes: {pass_names_str}"
130+
113131
def __call__(self, gm: Any, settings: CompilationSettings) -> Any:
114132
self.validate()
115133
out = gm

0 commit comments

Comments
 (0)