Skip to content

Commit 865a7c5

Browse files
titaiwangmsjustinchuby
authored andcommitted
[ONNX] Improve the conversion of from dynamic axes to shapes (pytorch#140488)
Features: (1) Add support for tree structure. (2) Add user warning before axes to shapes conversion (3) Add suggestion of providing `dynamic_shapes` when conversion fails Notes: (1) `input_names` is crucial to the conversion, as we don't know the ONNX graph inputs. (2) min and max are set as default, so LLM has higher chance to fail if users use `dynamic_axes` in terms of the min/max constraints dependency between `attention_mask` and `sequence_length`, etc. (Found in llama-3.2-1B_Instruct) Pull Request resolved: pytorch#140488 Approved by: https://github.com/justinchuby Co-authored-by: Justin Chu <[email protected]>
1 parent 9482476 commit 865a7c5

File tree

3 files changed

+377
-40
lines changed

3 files changed

+377
-40
lines changed

test/onnx/exporter/test_api.py

+59
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,22 @@ def forward(self, x, b):
3333
return x.relu(), b.sigmoid()
3434

3535

36+
class NestedModelForDynamicShapes(torch.nn.Module):
37+
def forward(
38+
self,
39+
x: torch.Tensor,
40+
ys: list[torch.Tensor],
41+
zs: dict[str, torch.Tensor],
42+
c: torch.Tensor,
43+
):
44+
y = ys[0] + ys[1] + zs["a"] + zs["b"]
45+
w = 5
46+
if x.shape[0] < 3 and c.shape[0] != 4:
47+
return x + w, x + y, c
48+
else:
49+
return x - w, x - y, c
50+
51+
3652
class TestExportAPIDynamo(common_utils.TestCase):
3753
"""Tests for the ONNX exporter API when dynamo=True."""
3854

@@ -71,6 +87,7 @@ def test_dynamic_axes_supports_partial_dynamic_shapes(self):
7187
self.assert_export(
7288
SampleModelForDynamicShapes(),
7389
(torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
90+
input_names=["x", "b"],
7491
dynamic_axes={
7592
"b": [0, 1, 2],
7693
},
@@ -80,6 +97,7 @@ def test_dynamic_axes_supports_output_names(self):
8097
self.assert_export(
8198
SampleModelForDynamicShapes(),
8299
(torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
100+
input_names=["x", "b"],
83101
dynamic_axes={
84102
"b": [0, 1, 2],
85103
},
@@ -181,6 +199,47 @@ def forward(self, x):
181199
assert onnx_program is not None
182200
onnx_testing.assert_onnx_program(onnx_program)
183201

202+
def test_dynamic_shapes_supports_nested_input_model_with_input_names_assigned(self):
203+
# kwargs can still be renamed as long as it's in order
204+
input_names = ["input_x", "input_y", "input_z", "d", "e", "f"]
205+
206+
dynamic_axes = {
207+
"input_x": {0: "dim"},
208+
"input_y": {0: "dim"},
209+
"input_z": {0: "dim"},
210+
"d": {0: "dim"},
211+
"e": {0: "dim"},
212+
}
213+
214+
model = NestedModelForDynamicShapes()
215+
input = (
216+
torch.ones(5),
217+
[torch.zeros(5), torch.ones(5)],
218+
{"a": torch.zeros(5), "b": torch.ones(5)},
219+
torch.ones(4),
220+
)
221+
222+
self.assert_export(
223+
model, input, dynamic_axes=dynamic_axes, input_names=input_names
224+
)
225+
226+
# Check whether inputs are dynamically shaped
227+
onnx_program = torch.onnx.export(
228+
model,
229+
input,
230+
dynamic_axes=dynamic_axes,
231+
input_names=input_names,
232+
dynamo=True,
233+
)
234+
self.assertTrue(
235+
all(
236+
[
237+
input.type.tensor_type.shape.dim[0].dim_param
238+
for input in onnx_program.model_proto.graph.input
239+
][:-1]
240+
)
241+
)
242+
184243
def test_refine_dynamic_shapes_with_onnx_export(self):
185244
# NOTE: From test/export/test_export.py
186245

test/onnx/exporter/test_compat.py

+249
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
# Owner(s): ["module: onnx"]
2+
"""Unit tests for the _compat module."""
3+
4+
from __future__ import annotations
5+
6+
import torch
7+
from torch.onnx._internal.exporter import _compat
8+
from torch.testing._internal import common_utils
9+
from torch.utils import _pytree
10+
11+
12+
class SingnatureOnlyLlamaModel(torch.nn.Module):
13+
def forward(
14+
self,
15+
input_ids: torch.LongTensor = None,
16+
attention_mask: torch.Tensor | None = None,
17+
position_ids: torch.LongTensor | None = None,
18+
past_key_values: list[torch.FloatTensor] | None = None,
19+
inputs_embeds: torch.FloatTensor | None = None,
20+
labels: torch.LongTensor | None = None,
21+
use_cache: bool | None = None,
22+
output_attentions: bool | None = None,
23+
output_hidden_states: bool | None = None,
24+
return_dict: bool | None = None,
25+
cache_position: torch.LongTensor | None = None,
26+
num_logits_to_keep: int = 0,
27+
**kwargs,
28+
):
29+
pass
30+
31+
32+
@common_utils.instantiate_parametrized_tests
33+
class TestPyTreeDynamicAxesShapes(common_utils.TestCase):
34+
# The test can't be parametrized because the torch.export.Dim generates objects,
35+
# and we need the exact same object to compare them.
36+
def test__unflatten_dynamic_shapes_with_inputs_tree_succeeds_on_tuple(self):
37+
inputs = (torch.randn(1, 2, 3), torch.randn(1, 2, 3))
38+
x_dim = torch.export.Dim("x_dim_0")
39+
y_dim = torch.export.Dim("y_dim_1")
40+
dynamic_shapes = {
41+
"x": {0: x_dim},
42+
"y": {1: y_dim},
43+
}
44+
unflatten_dynamic_shapes = _compat._unflatten_dynamic_shapes_with_inputs_tree(
45+
inputs, dynamic_shapes
46+
)
47+
48+
expected_dynamic_shapes = (
49+
{0: x_dim},
50+
{1: y_dim},
51+
)
52+
self.assertEqual(unflatten_dynamic_shapes, expected_dynamic_shapes)
53+
54+
def test__unflatten_dynamic_shapes_with_inputs_tree_succeeds_on_dict(self):
55+
inputs = {"x": torch.randn(1, 2, 3), "y": torch.randn(1, 2, 3)}
56+
x_dim = torch.export.Dim("x_dim_0")
57+
y_dim = torch.export.Dim("y_dim_1")
58+
dynamic_shapes = {
59+
"x": {0: x_dim},
60+
"y": {1: y_dim},
61+
}
62+
unflatten_dynamic_shapes = _compat._unflatten_dynamic_shapes_with_inputs_tree(
63+
inputs, dynamic_shapes
64+
)
65+
66+
expected_dynamic_shapes = {
67+
"x": {0: x_dim},
68+
"y": {1: y_dim},
69+
}
70+
self.assertEqual(unflatten_dynamic_shapes, expected_dynamic_shapes)
71+
72+
def test__unflatten_dynamic_shapes_with_inputs_tree_succeeds_on_tuple_of_mixed_structure(
73+
self,
74+
):
75+
inputs = (
76+
torch.randn(1, 2, 3),
77+
({"x0": torch.randn(1, 2, 3)}, {"x1": torch.randn(1, 2, 3)}),
78+
(torch.randn(1, 2, 3), torch.randn(1, 2, 3)),
79+
[torch.randn(1, 2, 3), torch.randn(1, 2, 3)],
80+
)
81+
w_dim_0 = torch.export.Dim("w_dim_0")
82+
x0_dim_1 = torch.export.Dim("x0_dim_1")
83+
x0_dim_2 = torch.export.Dim("x0_dim_2")
84+
x1_dim_1 = torch.export.Dim("x1_dim_1")
85+
y0_dim_0 = torch.export.Dim("y0_dim_0")
86+
y0_dim_1 = torch.export.Dim("y0_dim_1")
87+
y1_dim_2 = torch.export.Dim("y1_dim_2")
88+
z0_dim_2 = torch.export.Dim("z0_dim_2")
89+
z1_dim_1 = torch.export.Dim("z1_dim_1")
90+
dynamic_shapes = {
91+
"w": {0: w_dim_0},
92+
"x0": {1: x0_dim_1, 2: x0_dim_2},
93+
"x1": {1: x1_dim_1},
94+
"y0": {0: y0_dim_0, 1: y0_dim_1},
95+
"y1": {2: y1_dim_2},
96+
"z0": {2: z0_dim_2},
97+
"z1": {1: z1_dim_1},
98+
}
99+
unflatten_dynamic_shapes = _compat._unflatten_dynamic_shapes_with_inputs_tree(
100+
inputs, dynamic_shapes
101+
)
102+
expected_dynamic_shapes = (
103+
{0: w_dim_0},
104+
({"x0": {1: x0_dim_1, 2: x0_dim_2}}, {"x1": {1: x1_dim_1}}),
105+
({0: y0_dim_0, 1: y0_dim_1}, {2: y1_dim_2}),
106+
[{2: z0_dim_2}, {1: z1_dim_1}],
107+
)
108+
self.assertEqual(unflatten_dynamic_shapes, expected_dynamic_shapes)
109+
110+
@common_utils.parametrize(
111+
"model, args, kwargs,input_names, output_names, dynamic_axes, expected_dynamic_shapes",
112+
[
113+
# llama-3.2-1B-Instruct (trimmed)
114+
(
115+
SingnatureOnlyLlamaModel(),
116+
(),
117+
{
118+
"input_ids": torch.randn(2, 16),
119+
"attention_mask": torch.randn(2, 32),
120+
"position_ids": torch.randn(2, 16),
121+
"past_key_values": [
122+
(torch.randn(2, 8, 16, 64), torch.randn(2, 8, 16, 64)),
123+
(torch.randn(2, 8, 16, 64), torch.randn(2, 8, 16, 64)),
124+
],
125+
},
126+
[
127+
"input_ids",
128+
"attention_mask",
129+
"position_ids",
130+
"past_key_values.0.key",
131+
"past_key_values.0.value",
132+
"past_key_values.1.key",
133+
"past_key_values.1.value",
134+
],
135+
[
136+
"logits",
137+
"present.0.key",
138+
"present.0.value",
139+
"present.1.key",
140+
"present.1.value",
141+
],
142+
{
143+
"input_ids": {0: "batch_size", 1: "sequence_length"},
144+
"attention_mask": {
145+
0: "batch_size",
146+
1: "past_sequence_length + sequence_length",
147+
},
148+
"position_ids": {0: "batch_size", 1: "sequence_length"},
149+
"past_key_values.0.key": {
150+
0: "batch_size",
151+
2: "past_sequence_length",
152+
},
153+
"past_key_values.0.value": {
154+
0: "batch_size",
155+
2: "past_sequence_length",
156+
},
157+
"past_key_values.1.key": {
158+
0: "batch_size",
159+
2: "past_sequence_length",
160+
},
161+
"past_key_values.1.value": {
162+
0: "batch_size",
163+
2: "past_sequence_length",
164+
},
165+
"logits": {0: "batch_size", 1: "sequence_length"},
166+
"present.0.key": {
167+
0: "batch_size",
168+
2: "past_sequence_length + sequence_length",
169+
},
170+
"present.0.value": {
171+
0: "batch_size",
172+
2: "past_sequence_length + sequence_length",
173+
},
174+
"present.1.key": {
175+
0: "batch_size",
176+
2: "past_sequence_length + sequence_length",
177+
},
178+
"present.1.value": {
179+
0: "batch_size",
180+
2: "past_sequence_length + sequence_length",
181+
},
182+
},
183+
[
184+
{
185+
0: torch.export.Dim("batch_size"),
186+
1: torch.export.Dim("sequence_length"),
187+
},
188+
{
189+
0: torch.export.Dim("batch_size"),
190+
1: torch.export.Dim("past_sequence_lengthsequence_length"),
191+
},
192+
{
193+
0: torch.export.Dim("batch_size"),
194+
1: torch.export.Dim("sequence_length"),
195+
},
196+
[
197+
(
198+
{
199+
0: torch.export.Dim("batch_size"),
200+
2: torch.export.Dim("past_sequence_length"),
201+
},
202+
{
203+
0: torch.export.Dim("batch_size"),
204+
2: torch.export.Dim("past_sequence_length"),
205+
},
206+
),
207+
(
208+
{
209+
0: torch.export.Dim("batch_size"),
210+
2: torch.export.Dim("past_sequence_length"),
211+
},
212+
{
213+
0: torch.export.Dim("batch_size"),
214+
2: torch.export.Dim("past_sequence_length"),
215+
},
216+
),
217+
],
218+
],
219+
)
220+
],
221+
)
222+
def test__from_dynamic_axes_to_dynamic_shapes_succeeds_on_llm(
223+
self,
224+
model,
225+
args,
226+
kwargs,
227+
input_names,
228+
output_names,
229+
dynamic_axes,
230+
expected_dynamic_shapes,
231+
):
232+
dynamic_shapes = _compat._from_dynamic_axes_to_dynamic_shapes(
233+
model,
234+
args,
235+
kwargs,
236+
input_names=input_names,
237+
output_names=output_names,
238+
dynamic_axes=dynamic_axes,
239+
)
240+
241+
# NOTE: torch.export.Dim being an object makes it impossible to compare the objects directly.
242+
# And it's unrealistic to test whole model, so we are testing the structure of the dynamic_shapes.
243+
_, tree1 = _pytree.tree_flatten(dynamic_shapes)
244+
_, tree2 = _pytree.tree_flatten(expected_dynamic_shapes)
245+
self.assertEqual(tree1, tree2)
246+
247+
248+
if __name__ == "__main__":
249+
common_utils.run_tests()

0 commit comments

Comments
 (0)