Skip to content

Commit 54b1e28

Browse files
Juntian777facebook-github-bot
authored andcommitted
Create the IntermediateOutputCapturer Class to Store the IntermediateOutput of the AOT Graph
Summary: This Diff introduces a new Python class, IntermediateOutputCapturer, which inherits from torch.fx.interpreter.Interpreter. The primary purpose of this class is to capture the output tensor(s) produced by each operator (node) for an EdgeProgramManager's GraphModule. We will use these stored outputs to compare with later runtime operator outputs to detect numerical discrepancies. The IntermediateOutputCapturer class overrides the run_node method to store the computed results in an instance dictionary. It checks for the presence of a debug_handle in the node's metadata and the type of the node and uses it as a key to store the result. Tensors are detached and cloned to prevent side effects, while non-tensor results are stored directly. A public method, run_and_capture, is implemented to call the base Interpreter's run method and return the dictionary containing the captured debug_handle -> output mappings. Additionally, an __init__ method is provided to accept an fx.GraphModule as input and a print_captured_outputs method is included for debugging purposes. Differential Revision: D75492919
1 parent b00a90a commit 54b1e28

File tree

4 files changed

+197
-0
lines changed

4 files changed

+197
-0
lines changed

devtools/inspector/TARGETS

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,15 @@ python_library(
4848
],
4949
)
5050

51+
python_library(
52+
name = "intermediate_output_capturer",
53+
srcs = [
54+
"_intermediate_output_capturer.py",
55+
],
56+
deps = [
57+
],
58+
)
59+
5160
python_library(
5261
name = "lib",
5362
srcs = ["__init__.py"],
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
10+
from typing import Any, Dict, Tuple
11+
12+
import torch
13+
from torch.fx import GraphModule
14+
from torch.fx.interpreter import Interpreter
15+
16+
17+
class IntermediateOutputCapturer(Interpreter):
18+
def __init__(self, module: GraphModule):
19+
super().__init__(module)
20+
21+
def run_and_capture(self, *args, **kwargs) -> Dict[Tuple[int, ...], Any]:
22+
captured_outputs = {}
23+
24+
def capture_run_node(n: torch.fx.Node) -> Any:
25+
result = super(IntermediateOutputCapturer, self).run_node(n)
26+
debug_handle = n.meta.get("debug_handle", None)
27+
if debug_handle is not None and (n.op == "call_function" or n.op == "output"):
28+
# Convert the debug handle to a tuple to use as a dictionary key
29+
key = (
30+
(debug_handle,)
31+
if isinstance(debug_handle, int)
32+
else tuple(debug_handle)
33+
)
34+
# Handle tensor results by detaching and cloning
35+
if isinstance(result, torch.Tensor):
36+
captured_outputs[key] = result.detach().clone()
37+
elif isinstance(result, (tuple, list)):
38+
captured_outputs[key] = [
39+
r.detach().clone() if isinstance(r, torch.Tensor) else r
40+
for r in result
41+
]
42+
else:
43+
captured_outputs[key] = result
44+
return result
45+
46+
original_run_node = self.run_node
47+
self.run_node = capture_run_node
48+
self.run(*args, **kwargs)
49+
self.run_node = original_run_node
50+
return captured_outputs

devtools/inspector/tests/TARGETS

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,14 @@ python_unittest(
3939
"//executorch/devtools/inspector:inspector_utils",
4040
],
4141
)
42+
43+
python_unittest(
44+
name = "intermediate_output_capturer_test",
45+
srcs = ["intermediate_output_capturer_test.py"],
46+
deps = [
47+
"//executorch/devtools/inspector:inspector",
48+
"//executorch/devtools/inspector:lib",
49+
"//executorch/devtools/inspector:intermediate_output_capturer",
50+
"//executorch/exir:lib",
51+
],
52+
)
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
10+
import unittest
11+
import torch
12+
import torch.nn as nn
13+
import torch.nn.functional as F
14+
from torch.fx import GraphModule
15+
from executorch.devtools.inspector._intermediate_output_capturer import IntermediateOutputCapturer
16+
from torch.export import export, ExportedProgram
17+
18+
from executorch.exir import (
19+
EdgeCompileConfig,
20+
EdgeProgramManager,
21+
to_edge,
22+
)
23+
24+
class TestIntermediateOutputCapturer(unittest.TestCase):
25+
@classmethod
26+
def setUpClass(cls):
27+
class TestModule(nn.Module):
28+
def __init__(self):
29+
super(TestModule, self).__init__()
30+
self.conv = nn.Conv2d(
31+
in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1
32+
)
33+
self.conv.weight = nn.Parameter(
34+
torch.tensor(
35+
[[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]]
36+
)
37+
)
38+
self.conv.bias = nn.Parameter(torch.tensor([0.0]))
39+
40+
self.linear = nn.Linear(in_features=4, out_features=2)
41+
self.linear.weight = nn.Parameter(
42+
torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]])
43+
)
44+
self.linear.bias = nn.Parameter(torch.tensor([0.0, 0.0]))
45+
self.bias = nn.Parameter(torch.tensor([0.5, -0.5]), requires_grad=False)
46+
self.scale = nn.Parameter(torch.tensor([2.0, 0.5]), requires_grad=False)
47+
48+
def forward(self, x):
49+
x = self.conv(x)
50+
x = x.view(x.size(0), -1)
51+
x = self.linear(x)
52+
x = x + self.bias
53+
x = x - 0.1
54+
x = x * self.scale
55+
x = x / (self.scale + 1.0)
56+
x = F.relu(x)
57+
x = torch.sigmoid(x)
58+
x1, x2 = torch.split(x, 1, dim=1)
59+
return x1, x2
60+
61+
cls.model = TestModule()
62+
cls.input = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]], requires_grad=True)
63+
cls.aten_model: ExportedProgram = export(cls.model, (cls.input,), strict=True)
64+
cls.edge_program_manager: EdgeProgramManager = to_edge(
65+
cls.aten_model, compile_config=EdgeCompileConfig(_check_ir_validity=True)
66+
)
67+
cls.graph_module: GraphModule = cls.edge_program_manager._edge_programs[
68+
"forward"
69+
].module()
70+
cls.capturer = IntermediateOutputCapturer(cls.graph_module)
71+
cls.intermediate_outputs = cls.capturer.run_and_capture(cls.input)
72+
73+
def test_keying_with_debug_handle_tuple(self):
74+
for key in self.intermediate_outputs.keys():
75+
self.assertIsInstance(key, tuple)
76+
77+
def test_tensor_cloning_and_detaching(self):
78+
for output in self.intermediate_outputs.values():
79+
if isinstance(output, torch.Tensor):
80+
self.assertFalse(output.requires_grad)
81+
self.assertTrue(output.is_leaf)
82+
83+
def test_placeholder_nodes_are_skipped(self):
84+
for node in self.graph_module.graph.nodes:
85+
if node.op == "placeholder":
86+
self.assertNotIn(
87+
node.meta.get("debug_handle"), self.intermediate_outputs
88+
)
89+
90+
def test_multiple_outputs_capture(self):
91+
outputs = self.capturer.run_and_capture(self.input)
92+
for output in outputs.values():
93+
if isinstance(output, tuple):
94+
self.assertEqual(len(output), 2)
95+
for part in output:
96+
self.assertIsInstance(part, torch.Tensor)
97+
98+
def test_capture_correct_outputs(self):
99+
expected_outputs_with_handles = {
100+
(10,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]),
101+
(11,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]),
102+
(12,): torch.tensor([[0.1000, 0.5000], [0.2000, 0.6000], [0.3000, 0.7000], [0.4000, 0.8000]]),
103+
(13,): torch.tensor([[5.0000, 14.1200]]),
104+
(14,): torch.tensor([[5.5000, 13.6200]]),
105+
(15,): torch.tensor([[5.4000, 13.5200]]),
106+
(16,): torch.tensor([[10.8000, 6.7600]]),
107+
(17,): torch.tensor([3.0000, 1.5000]),
108+
(18,): torch.tensor([[3.6000, 4.5067]]),
109+
(19,): torch.tensor([[3.6000, 4.5067]]),
110+
(20,): torch.tensor([[0.9734, 0.9891]]),
111+
(21,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
112+
(22,): torch.tensor([[0.9734]]),
113+
(23,): torch.tensor([[0.9891]]),
114+
(24,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
115+
}
116+
self.assertEqual(len(self.intermediate_outputs), len(expected_outputs_with_handles))
117+
118+
for debug_handle, expected_output in expected_outputs_with_handles.items():
119+
actual_output = self.intermediate_outputs.get(debug_handle)
120+
self.assertIsNotNone(actual_output)
121+
if isinstance(expected_output, list):
122+
self.assertIsInstance(actual_output, list)
123+
self.assertEqual(len(actual_output), len(expected_output))
124+
for actual, expected in zip(actual_output, expected_output):
125+
self.assertTrue(torch.allclose(actual, expected, rtol=1e-4, atol=1e-5))
126+
else:
127+
self.assertTrue(torch.allclose(actual_output, expected_output, rtol=1e-4, atol=1e-5))

0 commit comments

Comments
 (0)