Skip to content

Commit ef06235

Browse files
Juntian777facebook-github-bot
authored andcommitted
Create the IntermediateOutputCapturer Class to Store the IntermediateOutput of the AOT Graph (#11202)
Summary: Pull Request resolved: #11202 This Diff introduces a new Python class, IntermediateOutputCapturer, which inherits from torch.fx.interpreter.Interpreter. The primary purpose of this class is to capture the output tensor(s) produced by each operator (node) for an EdgeProgramManager's GraphModule. We will use these stored outputs to compare with later runtime operator outputs to detect numerical discrepancies. The IntermediateOutputCapturer class overrides the run_node method to store the computed results in an instance dictionary. It checks for the presence of a debug_handle in the node's metadata and the type of the node and uses it as a key to store the result. Tensors are detached and cloned to prevent side effects, while non-tensor results are stored directly. A public method, run_and_capture, is implemented to call the base Interpreter's run method and return the dictionary containing the captured debug_handle -> output mappings. Additionally, an __init__ method is provided to accept an fx.GraphModule as input and a print_captured_outputs method is included for debugging purposes. Differential Revision: D75492919
1 parent 45008d6 commit ef06235

File tree

4 files changed

+207
-0
lines changed

4 files changed

+207
-0
lines changed

devtools/inspector/TARGETS

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,15 @@ python_library(
4848
],
4949
)
5050

51+
python_library(
52+
name = "intermediate_output_capturer",
53+
srcs = [
54+
"_intermediate_output_capturer.py",
55+
],
56+
deps = [
57+
],
58+
)
59+
5160
python_library(
5261
name = "lib",
5362
srcs = ["__init__.py"],
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
10+
from typing import Any, Dict, Tuple
11+
12+
import torch
13+
from torch.fx import GraphModule
14+
from torch.fx.interpreter import Interpreter
15+
16+
17+
class IntermediateOutputCapturer(Interpreter):
18+
def __init__(self, module: GraphModule):
19+
super().__init__(module)
20+
21+
def run_and_capture(self, *args, **kwargs) -> Dict[Tuple[int, ...], Any]:
22+
captured_outputs = {}
23+
24+
def capture_run_node(n: torch.fx.Node) -> Any:
25+
result = super(IntermediateOutputCapturer, self).run_node(n)
26+
debug_handle = n.meta.get("debug_handle", None)
27+
if debug_handle is not None and (
28+
n.op == "call_function" or n.op == "output"
29+
):
30+
# Convert the debug handle to a tuple to use as a dictionary key
31+
key = (
32+
(debug_handle,)
33+
if isinstance(debug_handle, int)
34+
else tuple(debug_handle)
35+
)
36+
# Handle tensor results by detaching and cloning
37+
if isinstance(result, torch.Tensor):
38+
captured_outputs[key] = result.detach().clone()
39+
elif isinstance(result, (tuple, list)):
40+
captured_outputs[key] = [
41+
r.detach().clone() if isinstance(r, torch.Tensor) else r
42+
for r in result
43+
]
44+
else:
45+
captured_outputs[key] = result
46+
return result
47+
48+
original_run_node = self.run_node
49+
self.run_node = capture_run_node
50+
self.run(*args, **kwargs)
51+
self.run_node = original_run_node
52+
return captured_outputs

devtools/inspector/tests/TARGETS

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,14 @@ python_unittest(
3939
"//executorch/devtools/inspector:inspector_utils",
4040
],
4141
)
42+
43+
python_unittest(
44+
name = "intermediate_output_capturer_test",
45+
srcs = ["intermediate_output_capturer_test.py"],
46+
deps = [
47+
"//executorch/devtools/inspector:inspector",
48+
"//executorch/devtools/inspector:lib",
49+
"//executorch/devtools/inspector:intermediate_output_capturer",
50+
"//executorch/exir:lib",
51+
],
52+
)
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
10+
import unittest
11+
12+
import torch
13+
import torch.nn as nn
14+
import torch.nn.functional as F
15+
from executorch.devtools.inspector._intermediate_output_capturer import (
16+
IntermediateOutputCapturer,
17+
)
18+
19+
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
20+
from torch.export import export, ExportedProgram
21+
from torch.fx import GraphModule
22+
23+
24+
class TestIntermediateOutputCapturer(unittest.TestCase):
25+
@classmethod
26+
def setUpClass(cls):
27+
class TestModule(nn.Module):
28+
def __init__(self):
29+
super(TestModule, self).__init__()
30+
self.conv = nn.Conv2d(
31+
in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1
32+
)
33+
self.conv.weight = nn.Parameter(
34+
torch.tensor(
35+
[[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]]
36+
)
37+
)
38+
self.conv.bias = nn.Parameter(torch.tensor([0.0]))
39+
40+
self.linear = nn.Linear(in_features=4, out_features=2)
41+
self.linear.weight = nn.Parameter(
42+
torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]])
43+
)
44+
self.linear.bias = nn.Parameter(torch.tensor([0.0, 0.0]))
45+
self.bias = nn.Parameter(torch.tensor([0.5, -0.5]), requires_grad=False)
46+
self.scale = nn.Parameter(torch.tensor([2.0, 0.5]), requires_grad=False)
47+
48+
def forward(self, x):
49+
x = self.conv(x)
50+
x = x.view(x.size(0), -1)
51+
x = self.linear(x)
52+
x = x + self.bias
53+
x = x - 0.1
54+
x = x * self.scale
55+
x = x / (self.scale + 1.0)
56+
x = F.relu(x)
57+
x = torch.sigmoid(x)
58+
x1, x2 = torch.split(x, 1, dim=1)
59+
return x1, x2
60+
61+
cls.model = TestModule()
62+
cls.input = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]], requires_grad=True)
63+
cls.aten_model: ExportedProgram = export(cls.model, (cls.input,), strict=True)
64+
cls.edge_program_manager: EdgeProgramManager = to_edge(
65+
cls.aten_model, compile_config=EdgeCompileConfig(_check_ir_validity=True)
66+
)
67+
cls.graph_module: GraphModule = cls.edge_program_manager._edge_programs[
68+
"forward"
69+
].module()
70+
cls.capturer = IntermediateOutputCapturer(cls.graph_module)
71+
cls.intermediate_outputs = cls.capturer.run_and_capture(cls.input)
72+
73+
def test_keying_with_debug_handle_tuple(self):
74+
for key in self.intermediate_outputs.keys():
75+
self.assertIsInstance(key, tuple)
76+
77+
def test_tensor_cloning_and_detaching(self):
78+
for output in self.intermediate_outputs.values():
79+
if isinstance(output, torch.Tensor):
80+
self.assertFalse(output.requires_grad)
81+
self.assertTrue(output.is_leaf)
82+
83+
def test_placeholder_nodes_are_skipped(self):
84+
for node in self.graph_module.graph.nodes:
85+
if node.op == "placeholder":
86+
self.assertNotIn(
87+
node.meta.get("debug_handle"), self.intermediate_outputs
88+
)
89+
90+
def test_multiple_outputs_capture(self):
91+
outputs = self.capturer.run_and_capture(self.input)
92+
for output in outputs.values():
93+
if isinstance(output, tuple):
94+
self.assertEqual(len(output), 2)
95+
for part in output:
96+
self.assertIsInstance(part, torch.Tensor)
97+
98+
def test_capture_correct_outputs(self):
99+
expected_outputs_with_handles = {
100+
(10,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]),
101+
(11,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]),
102+
(12,): torch.tensor(
103+
[[0.1000, 0.5000], [0.2000, 0.6000], [0.3000, 0.7000], [0.4000, 0.8000]]
104+
),
105+
(13,): torch.tensor([[5.0000, 14.1200]]),
106+
(14,): torch.tensor([[5.5000, 13.6200]]),
107+
(15,): torch.tensor([[5.4000, 13.5200]]),
108+
(16,): torch.tensor([[10.8000, 6.7600]]),
109+
(17,): torch.tensor([3.0000, 1.5000]),
110+
(18,): torch.tensor([[3.6000, 4.5067]]),
111+
(19,): torch.tensor([[3.6000, 4.5067]]),
112+
(20,): torch.tensor([[0.9734, 0.9891]]),
113+
(21,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
114+
(22,): torch.tensor([[0.9734]]),
115+
(23,): torch.tensor([[0.9891]]),
116+
(24,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
117+
}
118+
self.assertEqual(
119+
len(self.intermediate_outputs), len(expected_outputs_with_handles)
120+
)
121+
122+
for debug_handle, expected_output in expected_outputs_with_handles.items():
123+
actual_output = self.intermediate_outputs.get(debug_handle)
124+
self.assertIsNotNone(actual_output)
125+
if isinstance(expected_output, list):
126+
self.assertIsInstance(actual_output, list)
127+
self.assertEqual(len(actual_output), len(expected_output))
128+
for actual, expected in zip(actual_output, expected_output):
129+
self.assertTrue(
130+
torch.allclose(actual, expected, rtol=1e-4, atol=1e-5)
131+
)
132+
else:
133+
self.assertTrue(
134+
torch.allclose(actual_output, expected_output, rtol=1e-4, atol=1e-5)
135+
)

0 commit comments

Comments
 (0)