77
88
99class TestUtils :
10- """Utility class for testing Conv2D layers with and without bias, profiling FLOPs and parameters using THOP."""
10+ """Utility class for testing Conv2D layers with and without bias, profiling MACs and parameters using THOP."""
1111
1212 def test_conv2d_no_bias (self ):
1313 """Tests a 2D Conv layer without bias using THOP profiling on predefined input dimensions and parameters."""
@@ -20,12 +20,12 @@ def test_conv2d_no_bias(self):
2020 out = net (data )
2121
2222 _ , _ , oh , ow = out .shape
23-
24- flops , params = profile (net , inputs = (data ,))
25- assert flops == 810000 , f"{ flops } v.s. 810000"
23+ print ( f"Conv2d: in= { ih } x { iw } , kernel= { kh } x { kw } , stride= { s } , padding= { p } , out= { oh } x { ow } " )
24+ macs , params = profile (net , inputs = (data ,))
25+ assert macs == 810000 , f"{ macs } v.s. 810000"
2626
2727 def test_conv2d (self ):
28- """Tests Conv2D layer with bias, profiling FLOPs and params for specific input dimensions and layer configs."""
28+ """Tests Conv2D layer with bias, profiling MACs and params for specific input dimensions and layer configs."""
2929 n , in_c , ih , iw = 1 , 3 , 32 , 32 # torch.randint(1, 10, (4,)).tolist()
3030 out_c , kh , kw = 12 , 5 , 5
3131 s , p , d , g = 1 , 1 , 1 , 1
@@ -35,12 +35,12 @@ def test_conv2d(self):
3535 out = net (data )
3636
3737 _ , _ , oh , ow = out .shape
38-
39- flops , params = profile (net , inputs = (data ,))
40- assert flops == 810000 , f"{ flops } v.s. 810000"
38+ print ( f"Conv2d: in= { ih } x { iw } , kernel= { kh } x { kw } , stride= { s } , padding= { p } , out= { oh } x { ow } " )
39+ macs , params = profile (net , inputs = (data ,))
40+ assert macs == 810000 , f"{ macs } v.s. 810000"
4141
4242 def test_conv2d_random (self ):
43- """Test Conv2D layer with random parameters and validate the computed FLOPs and parameters using 'profile'."""
43+ """Test Conv2D layer with random parameters and validate the computed MACs and parameters using 'profile'."""
4444 for _ in range (10 ):
4545 out_c , kh , kw = torch .randint (1 , 20 , (3 ,)).tolist ()
4646 n , in_c , ih , iw = torch .randint (1 , 20 , (4 ,)).tolist () # torch.randint(1, 10, (4,)).tolist()
@@ -52,10 +52,168 @@ def test_conv2d_random(self):
5252 data = torch .randn (n , in_c , ih , iw )
5353 out = net (data )
5454
55+ _ , _ , oh , ow = out .shape
56+ print (f"Conv2d: in={ ih } x{ iw } , kernel={ kh } x{ kw } , stride={ s } , padding={ p } , out={ oh } x{ ow } " )
57+ macs , params = profile (net , inputs = (data ,))
58+ assert macs == n * out_c * oh * ow // g * in_c * kh * kw , (
59+ f"{ macs } v.s. { n * out_c * oh * ow // g * in_c * kh * kw } "
60+ )
61+
62+ def test_convtranspose2d_no_bias (self ):
63+ """Tests a 2D ConvTranspose layer without bias using THOP profiling on predefined input dimensions and
64+ parameters.
65+ """
66+ n , in_c , ih , iw = 1 , 3 , 2 , 2
67+ out_c , kh , kw = 1 , 2 , 2
68+ s , p , d , g = 2 , 0 , 1 , 1
69+
70+ net = nn .ConvTranspose2d (
71+ in_c , out_c , kernel_size = (kh , kw ), stride = s , padding = p , dilation = d , groups = g , bias = False
72+ )
73+ data = torch .randn (n , in_c , ih , iw )
74+ out = net (data )
75+
76+ _ , _ , oh , ow = out .shape
77+
78+ profile_result = profile (net , inputs = (data ,))
79+ macs = profile_result [0 ]
80+ profile_result [1 ]
81+ # For ConvTranspose: MACs = input_size * (output_channels / groups) * kernel_size
82+ print (f"ConvTranspose2d: in={ ih } x{ iw } , kernel={ kh } x{ kw } , stride={ s } , padding={ p } , out={ oh } x{ ow } " )
83+ expected_macs = n * in_c * ih * iw * (out_c // g ) * kh * kw
84+ assert macs == expected_macs , f"{ macs } v.s. { expected_macs } "
85+
86+ def test_convtranspose2d (self ):
87+ """Tests ConvTranspose2D layer with bias, profiling MACs and params for specific input dimensions and layer
88+ configs.
89+ """
90+ n , in_c , ih , iw = 1 , 3 , 2 , 2
91+ out_c , kh , kw = 1 , 2 , 2
92+ s , p , d , g = 2 , 0 , 1 , 1
93+
94+ net = nn .ConvTranspose2d (
95+ in_c , out_c , kernel_size = (kh , kw ), stride = s , padding = p , dilation = d , groups = g , bias = True
96+ )
97+ data = torch .randn (n , in_c , ih , iw )
98+ out = net (data )
99+
100+ _ , _ , oh , ow = out .shape
101+
102+ profile_result = profile (net , inputs = (data ,))
103+ macs = profile_result [0 ]
104+ profile_result [1 ]
105+ # For ConvTranspose: MACs = input_size * (output_channels / groups) * kernel_size
106+ print (f"ConvTranspose2d: in={ ih } x{ iw } , kernel={ kh } x{ kw } , stride={ s } , padding={ p } , out={ oh } x{ ow } " )
107+ expected_macs = n * in_c * ih * iw * (out_c // g ) * kh * kw
108+ assert macs == expected_macs , f"{ macs } v.s. { expected_macs } "
109+
110+ def test_convtranspose2d_groups (self ):
111+ """Tests ConvTranspose2D layer with groups, validating MAC calculation for grouped transposed convolution."""
112+ n , in_c , ih , iw = 1 , 8 , 4 , 4
113+ out_c , kh , kw = 4 , 3 , 3
114+ s , p , d , g = 1 , 1 , 1 , 2
115+
116+ net = nn .ConvTranspose2d (
117+ in_c , out_c , kernel_size = (kh , kw ), stride = s , padding = p , dilation = d , groups = g , bias = False
118+ )
119+ data = torch .randn (n , in_c , ih , iw )
120+ out = net (data )
121+
122+ _ , _ , oh , ow = out .shape
123+
124+ profile_result = profile (net , inputs = (data ,))
125+ macs = profile_result [0 ]
126+ profile_result [1 ]
127+ # For ConvTranspose with groups: MACs = input_size * (output_channels / groups) * kernel_size
128+ print (f"ConvTranspose2d: in={ ih } x{ iw } , kernel={ kh } x{ kw } , stride={ s } , padding={ p } , out={ oh } x{ ow } " )
129+ expected_macs = n * in_c * ih * iw * (out_c // g ) * kh * kw
130+ assert macs == expected_macs , f"{ macs } v.s. { expected_macs } "
131+
132+ def test_convtranspose2d_random (self ):
133+ """Test ConvTranspose2D layer with random parameters and validate the computed MACs and parameters using
134+ 'profile'.
135+ """
136+ for _ in range (10 ):
137+ # Generate random parameters ensuring valid ConvTranspose configurations
138+ out_c , kh , kw = torch .randint (1 , 10 , (3 ,)).tolist ()
139+ n , in_c = torch .randint (1 , 5 , (2 ,)).tolist ()
140+ stride = int (torch .randint (1 , 3 , (1 ,)).item ()) # stride
141+ padding = int (torch .randint (0 , 2 , (1 ,)).item ()) # padding
142+ dilation , groups = 1 , 1 # Keep dilation=1 and groups=1 for simplicity
143+
144+ # Ensure input size is large enough to produce valid output
145+ # ConvTranspose output size formula: (input_size - 1) * stride - 2 * padding + kernel_size
146+ # To ensure positive output, we need: input_size >= (2 * padding + 1) / stride + 1
147+ min_input_size = max (3 , (2 * padding + kh ) // stride + 1 , (2 * padding + kw ) // stride + 1 )
148+ ih , iw = torch .randint (min_input_size , min_input_size + 10 , (2 ,)).tolist ()
149+
150+ net = nn .ConvTranspose2d (
151+ in_c ,
152+ out_c ,
153+ kernel_size = (kh , kw ),
154+ stride = stride ,
155+ padding = padding ,
156+ dilation = dilation ,
157+ groups = groups ,
158+ bias = False ,
159+ )
160+ data = torch .randn (n , in_c , ih , iw )
161+ out = net (data )
162+
55163 _ , _ , oh , ow = out .shape
56164
57- flops , params = profile (net , inputs = (data ,))
58- print (flops , params )
59- assert flops == n * out_c * oh * ow // g * in_c * kh * kw , (
60- f"{ flops } v.s. { n * out_c * oh * ow // g * in_c * kh * kw } "
165+ profile_result = profile (net , inputs = (data ,))
166+ macs = profile_result [0 ]
167+ profile_result [1 ]
168+ # For ConvTranspose: MACs = input_size * (output_channels / groups) * kernel_size
169+ expected_macs = n * in_c * ih * iw * (out_c // groups ) * kh * kw
170+ print (
171+ f"ConvTranspose2d: in={ ih } x{ iw } , kernel={ kh } x{ kw } , stride={ stride } , padding={ padding } , out={ oh } x{ ow } , macs={ macs } "
61172 )
173+ assert macs == expected_macs , f"ConvTranspose2d MACs: { macs } v.s. { expected_macs } "
174+
175+ def test_conv_vs_convtranspose_symmetry (self ):
176+ """
177+ Test that Conv2d and ConvTranspose2d with symmetric configurations have equal MAC counts.
178+
179+ Test case: Conv2d downsamples 4x4 -> 2x2, ConvTranspose2d upsamples 2x2 -> 4x4.
180+ """
181+ # Conv2d: 4x4 -> 2x2
182+ conv_net = nn .Conv2d (1 , 3 , kernel_size = 2 , stride = 2 , bias = False )
183+ conv_data = torch .randn (1 , 1 , 4 , 4 )
184+ conv_out = conv_net (conv_data )
185+ conv_profile_result = profile (conv_net , inputs = (conv_data ,))
186+ conv_macs = conv_profile_result [0 ]
187+ conv_params = conv_profile_result [1 ]
188+
189+ # ConvTranspose2d: 2x2 -> 4x4 (symmetric operation)
190+ convt_net = nn .ConvTranspose2d (3 , 1 , kernel_size = 2 , stride = 2 , bias = False )
191+ convt_data = torch .randn (1 , 3 , 2 , 2 )
192+ convt_out = convt_net (convt_data )
193+ convt_profile_result = profile (convt_net , inputs = (convt_data ,))
194+ convt_macs = convt_profile_result [0 ]
195+ convt_params = convt_profile_result [1 ]
196+
197+ # Verify symmetric operations have equal MAC counts
198+ assert conv_macs == convt_macs , f"Symmetric operations should have equal MACs: { conv_macs } != { convt_macs } "
199+
200+ # Manual verification
201+ # Conv: output_size * (input_channels / groups) * kernel_size
202+ conv_expected = (
203+ conv_out .numel ()
204+ * (conv_net .in_channels // conv_net .groups )
205+ * (conv_net .kernel_size [0 ] * conv_net .kernel_size [1 ])
206+ )
207+ # ConvTranspose: input_size * (output_channels / groups) * kernel_size
208+ convt_expected = (
209+ convt_data .numel ()
210+ * (convt_net .out_channels // convt_net .groups )
211+ * (convt_net .kernel_size [0 ] * convt_net .kernel_size [1 ])
212+ )
213+ print (f"Conv2d: { conv_data .shape } -> { conv_out .shape } , MACs: { conv_macs } , Params: { conv_params } " )
214+ print (f"ConvTranspose2d: { convt_data .shape } -> { convt_out .shape } , MACs: { convt_macs } , Params: { convt_params } " )
215+ print (f"Conv2d expected MACs: { conv_expected } , ConvTranspose2d expected MACs: { convt_expected } " )
216+ assert conv_macs == conv_expected , f"Conv2d MAC calculation incorrect: { conv_macs } != { conv_expected } "
217+ assert convt_macs == convt_expected , (
218+ f"ConvTranspose2d MAC calculation incorrect: { convt_macs } != { convt_expected } "
219+ )
0 commit comments