@@ -156,87 +156,6 @@ def forward(self, x):
         torch._dynamo.reset()


-class TestFP32Accumulation(TestCase):
-    def test_fp32_acc(self):
-        class FP32Acc(torch.nn.Module):
-            def forward(self, input, weight):
-                out = torch.ops.aten.mm.default(input, weight)
-                return out
-
-        inputs = [
-            torch.rand((3, 4)).cuda(),
-            torch.rand((4, 5)).cuda(),
-        ]
-
-        fx_graph = torch.fx.symbolic_trace(FP32Acc())
-        expected_ops = {torch.ops.aten._to_copy.default, torch.ops.aten.mm.default}
-        unexpected_ops = {}
-
-        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
-            fx_graph,
-            inputs,
-            expected_ops=expected_ops,
-            unexpected_ops=unexpected_ops,
-            min_block_size=1,
-            use_fp32_acc=True,
-        )
-
-        self.assertEqual(
-            len(unexpected_ops_seen),
-            0,
-            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
-        )
-
-        self.assertEqual(
-            len(expected_ops_unseen),
-            0,
-            f"The following expected ops were not encountered: {expected_ops_unseen}",
-        )
-        torch._dynamo.reset()
-
-    def test_fp32_acc_for_addmm(self):
-        class FP32Acc(torch.nn.Module):
-            def forward(self, input, mat1, mat2):
-                out = torch.ops.aten.addmm.default(input, mat1, mat2, beta=20, alpha=2)
-                return out
-
-        inputs = [
-            torch.rand((3, 5)).cuda(),
-            torch.rand((3, 4)).cuda(),
-            torch.rand((4, 5)).cuda(),
-        ]
-
-        fx_graph = torch.fx.symbolic_trace(FP32Acc())
-        expected_ops = {
-            torch.ops.aten._to_copy.default,
-            torch.ops.aten.mm.default,
-            torch.ops.aten.add.Tensor,
-        }
-        unexpected_ops = {}
-
-        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
-            fx_graph,
-            inputs,
-            expected_ops=expected_ops,
-            unexpected_ops=unexpected_ops,
-            min_block_size=1,
-            use_fp32_acc=True,
-        )
-
-        self.assertEqual(
-            len(unexpected_ops_seen),
-            0,
-            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
-        )
-
-        self.assertEqual(
-            len(expected_ops_unseen),
-            0,
-            f"The following expected ops were not encountered: {expected_ops_unseen}",
-        )
-        torch._dynamo.reset()
-
-
 class TestComplexSubgraph(TestCase):
     def test_complex_subgraph(self):
         BATCH = 1