|
3 | 3 |
|
4 | 4 | import logging
|
5 | 5 | import tempfile
|
6 |
| -from typing import Mapping, Tuple |
7 | 6 |
|
8 | 7 | import onnx
|
9 | 8 | import onnx.inliner
|
@@ -111,75 +110,6 @@ def forward(self, x):
|
111 | 110 |
|
112 | 111 | _ = dynamo_export(TopKModel(), x, export_options=self.export_options)
|
113 | 112 |
|
114 |
| - def test_symbolic_shape_of_values_inside_function_is_exported_as_graph_value_info( |
115 |
| - self, |
116 |
| - ): |
117 |
| - class SubModule(torch.nn.Module): |
118 |
| - def forward(self, x, y, bias): |
119 |
| - output = x @ y |
120 |
| - return output + bias |
121 |
| - |
122 |
| - class Module(torch.nn.Module): |
123 |
| - def __init__(self) -> None: |
124 |
| - super().__init__() |
125 |
| - self.submodule = SubModule() |
126 |
| - |
127 |
| - def forward(self, x, y, bias): |
128 |
| - return self.submodule(x, y, bias) |
129 |
| - |
130 |
| - x = torch.randn(2, 3) |
131 |
| - y = torch.randn(3, 4) |
132 |
| - bias = torch.randn(4) |
133 |
| - onnx_program = torch.onnx.dynamo_export( |
134 |
| - Module(), |
135 |
| - x, |
136 |
| - y, |
137 |
| - bias, |
138 |
| - export_options=torch.onnx.ExportOptions(dynamic_shapes=True), |
139 |
| - ) |
140 |
| - model_proto = onnx_program.model_proto |
141 |
| - |
142 |
| - # Assert value_info for values inside local function can be retrieved |
143 |
| - def _assert_node_outputs_has_value_info( |
144 |
| - node: onnx.NodeProto, |
145 |
| - value_infos: Mapping[str, onnx.ValueInfoProto], |
146 |
| - local_functions: Mapping[Tuple[str, str], onnx.FunctionProto], |
147 |
| - exclude_names_in_value_info, |
148 |
| - function_id: str = "", |
149 |
| - ): |
150 |
| - for output in node.output: |
151 |
| - name = f"{function_id}/{output}" if function_id else output |
152 |
| - if name not in exclude_names_in_value_info: |
153 |
| - self.assertIn(name, value_infos) |
154 |
| - if node.domain.startswith("pkg.onnxscript.torch_lib"): |
155 |
| - # No shape info available for values inside torchlib functions. |
156 |
| - return |
157 |
| - if ( |
158 |
| - function := local_functions.get((node.domain, node.op_type)) |
159 |
| - ) is not None: |
160 |
| - for node in function.node: |
161 |
| - function_id = f"{function.domain}::{function.name}" |
162 |
| - _assert_node_outputs_has_value_info( |
163 |
| - node, |
164 |
| - value_infos, |
165 |
| - local_functions, |
166 |
| - exclude_names_in_value_info, |
167 |
| - function_id, |
168 |
| - ) |
169 |
| - |
170 |
| - type_infos = {vi.name: vi for vi in model_proto.graph.value_info} |
171 |
| - functions = {(f.domain, f.name): f for f in model_proto.functions} |
172 |
| - # NOTE: inputs, outputs, and initializers are not included in value_info spec |
173 |
| - exclude_names_in_value_info = ( |
174 |
| - [input.name for input in model_proto.graph.input] |
175 |
| - + [output.name for output in model_proto.graph.output] |
176 |
| - + [init.name for init in model_proto.graph.initializer] |
177 |
| - ) |
178 |
| - for node in model_proto.graph.node: |
179 |
| - _assert_node_outputs_has_value_info( |
180 |
| - node, type_infos, functions, exclude_names_in_value_info |
181 |
| - ) |
182 |
| - |
183 | 113 | def test_dynamo_export_retains_readable_parameter_and_buffer_names(self):
|
184 | 114 | class SubModule(torch.nn.Module):
|
185 | 115 | def __init__(self) -> None:
|
|
0 commit comments