-
Notifications
You must be signed in to change notification settings - Fork 364
FX graph visualization #3528
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
FX graph visualization #3528
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py 2025-05-23 04:32:05.196604+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py 2025-05-23 04:32:28.295008+00:00
@@ -31,11 +31,11 @@
Callable[
[torch.fx.GraphModule, CompilationSettings], torch.fx.GraphModule
]
]
] = None,
- constraints: Optional[List[Callable]] = None
+ constraints: Optional[List[Callable]] = None,
):
super().__init__(passes, constraints)
@classmethod
def build_from_passlist(
@@ -66,11 +66,11 @@
def remove_pass_with_index(self, index: int) -> None:
del self.passes[index]
def insert_debug_pass_before(
- self, passes: List[str], output_path_prefix: str=tempfile.gettempdir()
+ self, passes: List[str], output_path_prefix: str = tempfile.gettempdir()
) -> None:
"""Insert debug passes in the PassManager pass sequence prior to the execution of a particular pass.
Args:
passes: List of pass names to insert debug passes before
@@ -80,18 +80,22 @@
in the pass sequence.
"""
new_pass_list = []
for ps in self.passes:
if ps.__name__ in passes:
- new_pass_list.append(_generate_draw_fx_graph_pass(output_path_prefix, f"before_{ps.__name__}"))
+ new_pass_list.append(
+ _generate_draw_fx_graph_pass(
+ output_path_prefix, f"before_{ps.__name__}"
+ )
+ )
new_pass_list.append(ps)
self.passes = new_pass_list
self._validated = False
def insert_debug_pass_after(
- self, passes: List[str], output_path_prefix: str=tempfile.gettempdir()
+ self, passes: List[str], output_path_prefix: str = tempfile.gettempdir()
) -> None:
"""Insert debug passes in the PassManager pass sequence after the execution of a particular pass.
Args:
passes: List of pass names to insert debug passes after
@@ -102,12 +106,15 @@
"""
new_pass_list = []
for ps in self.passes:
new_pass_list.append(ps)
if ps.__name__ in passes:
- new_pass_list.append(_generate_draw_fx_graph_pass(output_path_prefix, f"after_{ps.__name__}"))
-
+ new_pass_list.append(
+ _generate_draw_fx_graph_pass(
+ output_path_prefix, f"after_{ps.__name__}"
+ )
+ )
self.passes = new_pass_list
self._validated = False
def __call__(self, gm: Any, settings: CompilationSettings) -> Any:
2a91f9d
to
f6a3f86
Compare
@@ -15,6 +15,7 @@ | |||
DLA_SRAM_SIZE = 1048576 | |||
ENGINE_CAPABILITY = EngineCapability.STANDARD | |||
WORKSPACE_SIZE = 0 | |||
ENGINE_VIS_DIR = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just set to temp_dir/torch_tensorrt_debug or something
@cehongwang can you target the debugging branch and we can pull all those changes in at once? |
a6fd323
to
031267c
Compare
8d9c413
to
d3e3058
Compare
d3e3058
to
74bb32d
Compare
Description
Debugging FX graphs can be challenging due to the complexity of analyzing node connections directly from the FX table. Therefore, providing a clear visualization of the FX graph is essential to facilitate effective debugging.
Fixes # (issue)
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: