Skip to content

Commit ef6393a

Browse files
authored
Create the IntermediateOutputCapturer Class to Store the IntermediateOutput of the AOT Graph
Differential Revision: D75492919 Pull Request resolved: #11202
1 parent 5200778 commit ef6393a

File tree

4 files changed

+204
-0
lines changed

4 files changed

+204
-0
lines changed

devtools/inspector/TARGETS

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,15 @@ python_library(
4848
],
4949
)
5050

51+
python_library(
52+
name = "intermediate_output_capturer",
53+
srcs = [
54+
"_intermediate_output_capturer.py",
55+
],
56+
deps = [
57+
],
58+
)
59+
5160
python_library(
5261
name = "lib",
5362
srcs = ["__init__.py"],
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
10+
from typing import Any, Dict, Tuple
11+
12+
import torch
13+
from torch.fx import GraphModule
14+
from torch.fx.interpreter import Interpreter
15+
16+
17+
class IntermediateOutputCapturer(Interpreter):
18+
def __init__(self, module: GraphModule):
19+
super().__init__(module)
20+
21+
def run_and_capture(self, *args, **kwargs) -> Dict[Tuple[int, ...], Any]:
22+
captured_outputs = {}
23+
24+
def capture_run_node(n: torch.fx.Node) -> Any:
25+
result = super(IntermediateOutputCapturer, self).run_node(n)
26+
debug_handle = n.meta.get("debug_handle", None)
27+
if debug_handle is not None and n.op == "call_function":
28+
# Convert the debug handle to a tuple to use as a dictionary key
29+
key = (
30+
(debug_handle,)
31+
if isinstance(debug_handle, int)
32+
else tuple(debug_handle)
33+
)
34+
# Handle tensor results by detaching and cloning
35+
if isinstance(result, torch.Tensor):
36+
captured_outputs[key] = result.detach().clone()
37+
elif isinstance(result, (tuple, list)):
38+
captured_outputs[key] = [
39+
r.detach().clone() if isinstance(r, torch.Tensor) else r
40+
for r in result
41+
]
42+
else:
43+
captured_outputs[key] = result
44+
return result
45+
46+
original_run_node = self.run_node
47+
self.run_node = capture_run_node
48+
self.run(*args, **kwargs)
49+
self.run_node = original_run_node
50+
return captured_outputs

devtools/inspector/tests/TARGETS

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,14 @@ python_unittest(
3939
"//executorch/devtools/inspector:inspector_utils",
4040
],
4141
)
42+
43+
python_unittest(
44+
name = "intermediate_output_capturer_test",
45+
srcs = ["intermediate_output_capturer_test.py"],
46+
deps = [
47+
"//executorch/devtools/inspector:inspector",
48+
"//executorch/devtools/inspector:lib",
49+
"//executorch/devtools/inspector:intermediate_output_capturer",
50+
"//executorch/exir:lib",
51+
],
52+
)
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
10+
import unittest
11+
12+
import torch
13+
import torch.nn as nn
14+
import torch.nn.functional as F
15+
from executorch.devtools.inspector._intermediate_output_capturer import (
16+
IntermediateOutputCapturer,
17+
)
18+
19+
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
20+
from torch.export import export, ExportedProgram
21+
from torch.fx import GraphModule
22+
23+
24+
class TestIntermediateOutputCapturer(unittest.TestCase):
25+
@classmethod
26+
def setUpClass(cls):
27+
class TestModule(nn.Module):
28+
def __init__(self):
29+
super(TestModule, self).__init__()
30+
self.conv = nn.Conv2d(
31+
in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1
32+
)
33+
self.conv.weight = nn.Parameter(
34+
torch.tensor(
35+
[[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]]
36+
)
37+
)
38+
self.conv.bias = nn.Parameter(torch.tensor([0.0]))
39+
40+
self.linear = nn.Linear(in_features=4, out_features=2)
41+
self.linear.weight = nn.Parameter(
42+
torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]])
43+
)
44+
self.linear.bias = nn.Parameter(torch.tensor([0.0, 0.0]))
45+
self.bias = nn.Parameter(torch.tensor([0.5, -0.5]), requires_grad=False)
46+
self.scale = nn.Parameter(torch.tensor([2.0, 0.5]), requires_grad=False)
47+
48+
def forward(self, x):
49+
x = self.conv(x)
50+
x = x.view(x.size(0), -1)
51+
x = self.linear(x)
52+
x = x + self.bias
53+
x = x - 0.1
54+
x = x * self.scale
55+
x = x / (self.scale + 1.0)
56+
x = F.relu(x)
57+
x = torch.sigmoid(x)
58+
x1, x2 = torch.split(x, 1, dim=1)
59+
return x1, x2
60+
61+
cls.model = TestModule()
62+
cls.input = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]], requires_grad=True)
63+
cls.aten_model: ExportedProgram = export(cls.model, (cls.input,), strict=True)
64+
cls.edge_program_manager: EdgeProgramManager = to_edge(
65+
cls.aten_model, compile_config=EdgeCompileConfig(_check_ir_validity=True)
66+
)
67+
cls.graph_module: GraphModule = cls.edge_program_manager._edge_programs[
68+
"forward"
69+
].module()
70+
cls.capturer = IntermediateOutputCapturer(cls.graph_module)
71+
cls.intermediate_outputs = cls.capturer.run_and_capture(cls.input)
72+
73+
def test_keying_with_debug_handle_tuple(self):
74+
for key in self.intermediate_outputs.keys():
75+
self.assertIsInstance(key, tuple)
76+
77+
def test_tensor_cloning_and_detaching(self):
78+
for output in self.intermediate_outputs.values():
79+
if isinstance(output, torch.Tensor):
80+
self.assertFalse(output.requires_grad)
81+
self.assertTrue(output.is_leaf)
82+
83+
def test_placeholder_nodes_are_skipped(self):
84+
for node in self.graph_module.graph.nodes:
85+
if node.op == "placeholder":
86+
self.assertNotIn(
87+
node.meta.get("debug_handle"), self.intermediate_outputs
88+
)
89+
90+
def test_multiple_outputs_capture(self):
91+
outputs = self.capturer.run_and_capture(self.input)
92+
for output in outputs.values():
93+
if isinstance(output, tuple):
94+
self.assertEqual(len(output), 2)
95+
for part in output:
96+
self.assertIsInstance(part, torch.Tensor)
97+
98+
def test_capture_correct_outputs(self):
99+
expected_outputs_with_handles = {
100+
(10,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]),
101+
(11,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]),
102+
(12,): torch.tensor(
103+
[[0.1000, 0.5000], [0.2000, 0.6000], [0.3000, 0.7000], [0.4000, 0.8000]]
104+
),
105+
(13,): torch.tensor([[5.0000, 14.1200]]),
106+
(14,): torch.tensor([[5.5000, 13.6200]]),
107+
(15,): torch.tensor([[5.4000, 13.5200]]),
108+
(16,): torch.tensor([[10.8000, 6.7600]]),
109+
(17,): torch.tensor([3.0000, 1.5000]),
110+
(18,): torch.tensor([[3.6000, 4.5067]]),
111+
(19,): torch.tensor([[3.6000, 4.5067]]),
112+
(20,): torch.tensor([[0.9734, 0.9891]]),
113+
(21,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
114+
(22,): torch.tensor([[0.9734]]),
115+
(23,): torch.tensor([[0.9891]]),
116+
}
117+
self.assertEqual(
118+
len(self.intermediate_outputs), len(expected_outputs_with_handles)
119+
)
120+
121+
for debug_handle, expected_output in expected_outputs_with_handles.items():
122+
actual_output = self.intermediate_outputs.get(debug_handle)
123+
self.assertIsNotNone(actual_output)
124+
if isinstance(expected_output, list):
125+
self.assertIsInstance(actual_output, list)
126+
self.assertEqual(len(actual_output), len(expected_output))
127+
for actual, expected in zip(actual_output, expected_output):
128+
self.assertTrue(
129+
torch.allclose(actual, expected, rtol=1e-4, atol=1e-5)
130+
)
131+
else:
132+
self.assertTrue(
133+
torch.allclose(actual_output, expected_output, rtol=1e-4, atol=1e-5)
134+
)

0 commit comments

Comments
 (0)