1
+ import os
1
2
import tempfile
2
- from types import new_class
3
- from typing import Any , Callable , List , Optional , Union
3
+ from typing import Any , Callable , List , Optional
4
4
5
5
import torch
6
6
from torch .fx import passes
@@ -14,6 +14,8 @@ def _generate_draw_fx_graph_pass(
14
14
def draw_fx_graph_pass (
15
15
gm : torch .fx .GraphModule , settings : CompilationSettings
16
16
) -> torch .fx .GraphModule :
17
+ if not os .path .exists (f"{ output_path_prefix } /" ):
18
+ os .makedirs (f"{ output_path_prefix } /" )
17
19
path = f"{ output_path_prefix } /{ name } .svg"
18
20
g = passes .graph_drawer .FxGraphDrawer (gm , name )
19
21
with open (path , "wb" ) as f :
@@ -33,7 +35,7 @@ def __init__(
33
35
]
34
36
]
35
37
] = None ,
36
- constraints : Optional [List [Callable ]] = None
38
+ constraints : Optional [List [Callable ]] = None ,
37
39
):
38
40
super ().__init__ (passes , constraints )
39
41
@@ -68,7 +70,7 @@ def remove_pass_with_index(self, index: int) -> None:
68
70
del self .passes [index ]
69
71
70
72
def insert_debug_pass_before (
71
- self , passes : List [str ], output_path_prefix : str = tempfile .gettempdir ()
73
+ self , passes : List [str ], output_path_prefix : str = tempfile .gettempdir ()
72
74
) -> None :
73
75
"""Insert debug passes in the PassManager pass sequence prior to the execution of a particular pass.
74
76
@@ -79,17 +81,22 @@ def insert_debug_pass_before(
79
81
Debug passes generate SVG visualizations of the FX graph at specified points
80
82
in the pass sequence.
81
83
"""
84
+ self .check_pass_names_valid (passes )
82
85
new_pass_list = []
83
86
for ps in self .passes :
84
87
if ps .__name__ in passes :
85
- new_pass_list .append (_generate_draw_fx_graph_pass (output_path_prefix , f"before_{ ps .__name__ } " ))
88
+ new_pass_list .append (
89
+ _generate_draw_fx_graph_pass (
90
+ output_path_prefix , f"before_{ ps .__name__ } "
91
+ )
92
+ )
86
93
new_pass_list .append (ps )
87
94
88
95
self .passes = new_pass_list
89
96
self ._validated = False
90
97
91
98
def insert_debug_pass_after (
92
- self , passes : List [str ], output_path_prefix : str = tempfile .gettempdir ()
99
+ self , passes : List [str ], output_path_prefix : str = tempfile .gettempdir ()
93
100
) -> None :
94
101
"""Insert debug passes in the PassManager pass sequence after the execution of a particular pass.
95
102
@@ -100,16 +107,27 @@ def insert_debug_pass_after(
100
107
Debug passes generate SVG visualizations of the FX graph at specified points
101
108
in the pass sequence.
102
109
"""
110
+ self .check_pass_names_valid (passes )
103
111
new_pass_list = []
104
112
for ps in self .passes :
105
113
new_pass_list .append (ps )
106
114
if ps .__name__ in passes :
107
- new_pass_list .append (_generate_draw_fx_graph_pass (output_path_prefix , f"after_{ ps .__name__ } " ))
108
-
115
+ new_pass_list .append (
116
+ _generate_draw_fx_graph_pass (
117
+ output_path_prefix , f"after_{ ps .__name__ } "
118
+ )
119
+ )
109
120
110
121
self .passes = new_pass_list
111
122
self ._validated = False
112
123
124
+ def check_pass_names_valid (self , debug_pass_names : List [str ]) -> None :
125
+ pass_names_str = [p .__name__ for p in self .passes ]
126
+ for name in debug_pass_names :
127
+ assert (
128
+ name in pass_names_str
129
+ ), f"{ name } is not a valid pass! Passes: { pass_names_str } "
130
+
113
131
def __call__ (self , gm : Any , settings : CompilationSettings ) -> Any :
114
132
self .validate ()
115
133
out = gm
0 commit comments