1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import unittest
8
+
9
+ import torch
10
+ from executorch .backends .xnnpack .test .tester import Export , Tester
11
+ from executorch .exir .dialects ._ops import ops as exir_ops
12
+ from torch .export import Dim
13
+
14
+
15
+ class TestSqueeze (unittest .TestCase ):
16
+ class Squeeze (torch .nn .Module ):
17
+ def __init__ (self , dims ):
18
+ super ().__init__ ()
19
+ self .dims = dims
20
+
21
+ def forward (self , x ):
22
+ return torch .squeeze (x , self .dims )
23
+
24
+ def test_fp32_squeeze (self ):
25
+ inputs = (torch .randn (1 ,2 ,1 ,4 ,1 ),)
26
+ squeeze_dims = (0 , 2 , 4 )
27
+
28
+ for dims in squeeze_dims :
29
+ (
30
+ Tester (self .Squeeze (dims ), inputs )
31
+ .export ()
32
+ .check_node_count ({
33
+ torch .ops .aten .squeeze .dim : 1 ,
34
+ })
35
+ .to_edge_transform_and_lower ()
36
+ .check_node_count ({
37
+ exir_ops .edge .aten .squeeze_copy .dim : 0 ,
38
+ exir_ops .edge .aten .view_copy .default : 0 ,
39
+ torch .ops .higher_order .executorch_call_delegate : 1 ,
40
+ })
41
+ .run_method_and_compare_outputs ()
42
+ )
43
+
44
+ def test_fp16_squeeze (self ):
45
+ inputs = (torch .randn (1 ,2 ,1 ,4 ,1 ).to (torch .float16 ),)
46
+ squeeze_dims = (0 , 2 , 4 )
47
+
48
+ for dims in squeeze_dims :
49
+ (
50
+ Tester (self .Squeeze (dims ), inputs )
51
+ .export ()
52
+ .check_node_count ({
53
+ torch .ops .aten .squeeze .dim : 1 ,
54
+ })
55
+ .to_edge_transform_and_lower ()
56
+ .check_node_count ({
57
+ exir_ops .edge .aten .squeeze_copy .dim : 0 ,
58
+ exir_ops .edge .aten .view_copy .default : 0 ,
59
+ torch .ops .higher_order .executorch_call_delegate : 1 ,
60
+ })
61
+ .run_method_and_compare_outputs ()
62
+ )
63
+
64
+ def test_fp32_squeeze_dynamic (self ):
65
+ inputs = (torch .randn (1 ,2 ,1 ,4 ,1 ),)
66
+ squeeze_dims = (0 , 2 , 4 )
67
+ dynamic_shapes = { "x" : { 1 : Dim ("x_1" , min = 1 , max = 10 ) } }
68
+
69
+ for dims in squeeze_dims :
70
+ (
71
+ Tester (self .Squeeze (dims ), inputs )
72
+ .export (Export (dynamic_shapes = dynamic_shapes ))
73
+ .check_node_count ({
74
+ torch .ops .aten .squeeze .dim : 1 ,
75
+ })
76
+ .to_edge_transform_and_lower ()
77
+ .check_node_count ({
78
+ exir_ops .edge .aten .squeeze_copy .dim : 0 ,
79
+ exir_ops .edge .aten .view_copy .default : 0 ,
80
+ torch .ops .higher_order .executorch_call_delegate : 1 ,
81
+ })
82
+ .run_method_and_compare_outputs ()
83
+ )
84
+
85
+ def test_fp32_squeeze_unsupported_dynamism (self ):
86
+ inputs = (torch .randn (1 ,2 ,1 ,4 ,1 ),)
87
+ squeeze_dims = (0 , 2 , 4 )
88
+ # Only one dynamic dimension is supported.
89
+ dynamic_shapes = { "x" : {
90
+ 1 : Dim ("x_1" , min = 1 , max = 10 ),
91
+ 3 : Dim ("x_3" , min = 1 , max = 10 ),
92
+ } }
93
+
94
+ for dims in squeeze_dims :
95
+ (
96
+ Tester (self .Squeeze (dims ), inputs )
97
+ .export (Export (dynamic_shapes = dynamic_shapes ))
98
+ .check_node_count ({
99
+ torch .ops .aten .squeeze .dim : 1 ,
100
+ })
101
+ .to_edge_transform_and_lower ()
102
+ .check_node_count ({
103
+ exir_ops .edge .aten .squeeze_copy .dims : 1 ,
104
+ torch .ops .higher_order .executorch_call_delegate : 0 ,
105
+ })
106
+ .run_method_and_compare_outputs ()
107
+ )
0 commit comments