-
Notifications
You must be signed in to change notification settings - Fork 663
Expand file tree
/
Copy pathdecompose-complex-ops.mlir
More file actions
1039 lines (944 loc) · 89.1 KB
/
decompose-complex-ops.mlir
File metadata and controls
1039 lines (944 loc) · 89.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
// RUN: torch-mlir-opt -torch-decompose-complex-ops -split-input-file %s | FileCheck %s
// Negative test: a 5-D x 3-D matmul must NOT be decomposed; the pass keeps the
// torch.aten.matmul op intact.
// CHECK-LABEL: func.func @matmul_no_decompose
// CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
func.func @matmul_no_decompose(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
return %0 : !torch.tensor
}
// -----
// A 2-D x 2-D matmul decomposes to torch.aten.mm.
// CHECK-LABEL: func.func @matmul_decompose_2d
// CHECK: torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.tensor
func.func @matmul_decompose_2d(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.tensor
return %0 : !torch.tensor
}
// -----
// Negative test: a 3-D x 3-D matmul with fully dynamic batch dims is NOT
// decomposed to bmm; the matmul is kept.
// CHECK-LABEL: func.func @matmul_no_decompose_3d_dynamic(
// CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
func.func @matmul_no_decompose_3d_dynamic(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
return %0 : !torch.tensor
}
// -----
// A 3-D x 3-D matmul with matching static batch dims (4) decomposes to
// torch.aten.bmm.
// CHECK-LABEL: func.func @matmul_decompose_3d_static(
// CHECK: torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[4,?,?],f32>, !torch.vtensor<[4,?,?],f32> -> !torch.tensor
func.func @matmul_decompose_3d_static(%arg0: !torch.vtensor<[4,?,?],f32>, %arg1: !torch.vtensor<[4,?,?],f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[4,?,?],f32>, !torch.vtensor<[4,?,?],f32> -> !torch.tensor
return %0 : !torch.tensor
}
// -----
// Negative test: mismatched static batch dims (4 vs 1) would require
// broadcasting, so the matmul is NOT decomposed to bmm.
// CHECK-LABEL: func.func @matmul_no_decompose_3d_broadcast(
// CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[4,?,?],f32>, !torch.vtensor<[1,?,?],f32> -> !torch.tensor
func.func @matmul_no_decompose_3d_broadcast(%arg0: !torch.vtensor<[4,?,?],f32>, %arg1: !torch.vtensor<[1,?,?],f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[4,?,?],f32>, !torch.vtensor<[1,?,?],f32> -> !torch.tensor
return %0 : !torch.tensor
}
// -----
// argmax with dim=None on a rank-1 input decomposes to torch.aten.max.dim
// along dim 0; only the indices result is returned.
// CHECK-LABEL: func.func @argmax_rank_1
// CHECK: %[[I0:.*]] = torch.constant.int 0
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[VALUES:.*]], %[[INDICES:.*]] = torch.aten.max.dim %arg0, %[[I0]], %[[FALSE]] : !torch.vtensor<[20],si32>, !torch.int, !torch.bool -> !torch.vtensor<[],si32>, !torch.vtensor<[],si64>
// CHECK: return %[[INDICES]] : !torch.vtensor<[],si64>
func.func @argmax_rank_1(%arg0: !torch.vtensor<[20],si32>) -> !torch.vtensor<[],si64> {
%none = torch.constant.none
%false = torch.constant.bool false
%7 = torch.aten.argmax %arg0, %none, %false : !torch.vtensor<[20],si32>, !torch.none, !torch.bool -> !torch.vtensor<[],si64>
return %7 : !torch.vtensor<[],si64>
}
// -----
// type_as decomposes to prim.dtype on the reference tensor followed by
// aten.to.dtype on the input.
// CHECK-LABEL: func.func @torch.aten.type_as$basic(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false
// CHECK-DAG: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[DTYPE:.*]] = torch.prim.dtype %[[ARG_1]] : !torch.tensor -> !torch.int
// CHECK: %[[VAR:.*]] = torch.aten.to.dtype %[[ARG_0]], %[[DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.tensor, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor
// CHECK: return %[[VAR]] : !torch.tensor
func.func @torch.aten.type_as$basic(%arg0: !torch.tensor, %arg1: !torch.tensor) -> !torch.tensor {
%0 = torch.aten.type_as %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tensor
return %0 : !torch.tensor
}
// -----
// When both operands already have the same dtype (f16), type_as folds away
// entirely and the input is returned unchanged.
// CHECK-LABEL: func.func @torch.aten.type_as$fold(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor<[?],f16>, %[[ARG_1:.*]]: !torch.tensor<[?,?],f16>) -> !torch.tensor<[?],f16> {
// CHECK: return %[[ARG_0]] : !torch.tensor<[?],f16>
func.func @torch.aten.type_as$fold(%arg0: !torch.tensor<[?], f16>, %arg1: !torch.tensor<[?,?],f16>) -> !torch.tensor<[?],f16> {
%0 = torch.aten.type_as %arg0, %arg1 : !torch.tensor<[?], f16>, !torch.tensor<[?,?],f16> -> !torch.tensor<[?], f16>
return %0 : !torch.tensor<[?], f16>
}
// -----
// one_hot decomposes to arange over num_classes, unsqueeze of the indices,
// an element-wise eq.Tensor comparison, and a to.dtype cast back to si64.
// Note: the final CHECK uses %[[RESULT]] (a *use* of the variable bound on the
// to.dtype line) rather than %[[RESULT:.*]], which would silently re-bind the
// variable and match any returned value instead of verifying the data flow.
// CHECK-LABEL: func.func @torch.aten.one_hot$fold(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3],si64>, %arg1: !torch.int) -> !torch.vtensor<[3,?],si64> {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[INT4:.*]] = torch.constant.int 4
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[ARANGE:.*]] = torch.aten.arange.start_step %[[INT0]], %arg1, %[[INT1]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64>
// CHECK: %[[UNSQUEEZE:.*]] = torch.aten.unsqueeze %[[ARG_0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int -> !torch.vtensor<[3,1],si64>
// CHECK: %[[EQ:.*]] = torch.aten.eq.Tensor %[[UNSQUEEZE]], %[[ARANGE]] : !torch.vtensor<[3,1],si64>, !torch.vtensor<[?],si64> -> !torch.vtensor<[3,?],i1>
// CHECK: %[[RESULT:.*]] = torch.aten.to.dtype %[[EQ]], %[[INT4]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,?],si64>
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,?],si64>
func.func @torch.aten.one_hot$fold(%arg0: !torch.vtensor<[3],si64>, %arg1: !torch.int) -> !torch.vtensor<[3,?],si64> {
%0 = torch.aten.one_hot %arg0, %arg1 : !torch.vtensor<[3],si64>, !torch.int -> !torch.vtensor<[3,?],si64>
return %0 : !torch.vtensor<[3,?],si64>
}
// -----
// When only the first (values) result is used, the cachemask variant
// decomposes to the plain fake_quantize_per_tensor_affine.tensor_qparams op,
// dropping the mask computation.
// CHECK-LABEL: func.func @torch.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG_1:.*]]: !torch.vtensor<[1],f32>,
// CHECK-SAME: %[[ARG_2:.*]]: !torch.vtensor<[1],si32>, %[[ARG_3:.*]]: !torch.vtensor<[1],si64>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK: %[[CONST1:.*]] = torch.constant.int 127
// CHECK: %[[CONST2:.*]] = torch.constant.int -128
// CHECK: %[[RESULT:.*]] = torch.aten.fake_quantize_per_tensor_affine.tensor_qparams %[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[CONST2]], %[[CONST1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],si32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],si32>, %arg3: !torch.vtensor<[1],si64>) -> !torch.vtensor<[?,?,?,?],f32> {
%int127 = torch.constant.int 127
%int-128 = torch.constant.int -128
%0:2 = torch.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams %arg0, %arg1, %arg2, %arg3, %int-128, %int127 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],si32>, !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],i1>
return %0#0 : !torch.vtensor<[?,?,?,?],f32>
}
// -----
// Same pattern as the per-tensor test above: with the mask result unused, the
// per-channel cachemask op decomposes to plain fake_quantize_per_channel_affine.
// CHECK-LABEL: func.func @torch.aten.fake_quantize_per_channel_affine_cachemask(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG_1:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[ARG_2:.*]]: !torch.vtensor<[?],si32>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK: %[[CONST0:.*]] = torch.constant.int 0
// CHECK: %[[CONST1:.*]] = torch.constant.int 127
// CHECK: %[[CONST2:.*]] = torch.constant.int -128
// CHECK: %[[RESULT:.*]] = torch.aten.fake_quantize_per_channel_affine %[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[CONST0]], %[[CONST2]], %[[CONST1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?],f32>, !torch.vtensor<[?],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.fake_quantize_per_channel_affine_cachemask(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?],f32>, %arg2: !torch.vtensor<[?],si32>) -> !torch.vtensor<[?,?,?,?],f32> {
%int0 = torch.constant.int 0
%int127 = torch.constant.int 127
%int-128 = torch.constant.int -128
%0:2 = torch.aten.fake_quantize_per_channel_affine_cachemask %arg0, %arg1, %arg2, %int0, %int-128, %int127 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?],f32>, !torch.vtensor<[?],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],i1>
return %0#0 : !torch.vtensor<[?,?,?,?],f32>
}
// -----
// einsum "i,i" (inner product of two length-5 vectors) decomposes into the
// generic permute / view-to-3D / bmm / view-back / permute-back sequence.
// CHECK-LABEL: test_einsum_inner_prod
func.func @test_einsum_inner_prod(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[5],f64>) -> !torch.vtensor<[],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64} {
// CHECK-DAG: %[[INT5:.+]] = torch.constant.int 5
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
// CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[LHS_LIST:.+]] = torch.prim.ListConstruct %[[INT0]]
// CHECK: %[[LHS_PERM:.+]] = torch.aten.permute %arg0, %[[LHS_LIST]]
// CHECK: %[[RHS_LIST:.+]] = torch.prim.ListConstruct %[[INT0]]
// CHECK: %[[RHS_PERM:.+]] = torch.aten.permute %arg1, %[[RHS_LIST]]
// CHECK: %[[LHS_SHP:.+]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]], %[[INT5]]
// CHECK: %[[LHS_VIEW:.+]] = torch.aten.view %[[LHS_PERM]], %[[LHS_SHP]]
// CHECK: %[[RHS_SHP:.+]] = torch.prim.ListConstruct %[[INT1]], %[[INT5]], %[[INT1]]
// CHECK: %[[RHS_VIEW:.+]] = torch.aten.view %[[RHS_PERM]], %[[RHS_SHP]]
// CHECK: %[[BMM:.+]] = torch.aten.bmm %[[LHS_VIEW]], %[[RHS_VIEW]]
// CHECK: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[OUT_VIEW:.+]] = torch.aten.view %[[BMM]], %[[EMPTY]]
// CHECK: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[OUT_PERM:.+]] = torch.aten.permute %[[OUT_VIEW]], %[[EMPTY]]
// CHECK: return %[[OUT_PERM]]
%0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[5],f64>, !torch.vtensor<[5],f64>) -> !torch.list<vtensor>
%str = torch.constant.str "i,i"
%none_0 = torch.constant.none
%1 = torch.aten.einsum %str, %0, %none_0 : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[],f64>
return %1 : !torch.vtensor<[],f64>
}
// -----
// Integer fmod.Tensor decomposes to a - (a / b) * b, using truncating integer
// division (so no sign-correction tensors are needed, unlike the float case).
// CHECK: func.func @torch.aten.fmod_int(%[[ARG0:.+]]: !torch.vtensor<[?],si32>, %[[ARG1:.+]]: !torch.vtensor<[1],si32>) -> !torch.vtensor<[?],si32> {
// CHECK: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00
// CHECK: %[[V0:.+]] = torch.aten.div.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[?],si32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[?],si32>
// CHECK: %[[V1:.+]] = torch.aten.mul.Tensor %[[V0]], %[[ARG1]] : !torch.vtensor<[?],si32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[?],si32>
// CHECK: %[[V2:.+]] = torch.aten.sub.Tensor %[[ARG0]], %[[V1]], %[[FLOAT1]] : !torch.vtensor<[?],si32>, !torch.vtensor<[?],si32>, !torch.float -> !torch.vtensor<[?],si32>
// CHECK: return %[[V2]] : !torch.vtensor<[?],si32>
func.func @torch.aten.fmod_int(%arg0: !torch.vtensor<[?],si32>, %arg1: !torch.vtensor<[1],si32>) -> !torch.vtensor<[?],si32> {
%0 = torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[?],si32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[?],si32>
return %0 : !torch.vtensor<[?],si32>
}
// -----
// Float fmod.Tensor decomposes to a - sign(a/b) * floor(|a/b|) * b: the
// quotient's sign is reconstructed via gt/lt-scalar selects over literal
// -1/0/+1 tensors, then combined with floor(abs(quotient)) to truncate
// toward zero before the final multiply-subtract.
// CHECK: func.func @torch.aten.fmod_float(%[[ARG0:.+]]: !torch.vtensor<[?],f16>, %[[ARG1:.+]]: !torch.vtensor<[1],f16>) -> !torch.vtensor<[?],f16> {
// CHECK: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00
// CHECK: %[[V0:.+]] = torch.vtensor.literal(dense<-1.0{{.*}}> : tensor<f16>) : !torch.vtensor<[],f16>
// CHECK: %[[V1:.+]] = torch.vtensor.literal(dense<0.0{{.*}}> : tensor<f16>) : !torch.vtensor<[],f16>
// CHECK: %[[V2:.+]] = torch.vtensor.literal(dense<1.0{{.*}}> : tensor<f16>) : !torch.vtensor<[],f16>
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[V3:.+]] = torch.aten.div.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[?],f16>, !torch.vtensor<[1],f16> -> !torch.vtensor<[?],f16>
// CHECK: %[[V4:.+]] = torch.aten.gt.Scalar %[[V3]], %[[INT0]] : !torch.vtensor<[?],f16>, !torch.int -> !torch.vtensor<[?],i1>
// CHECK: %[[V5:.+]] = torch.aten.lt.Scalar %[[V3]], %[[INT0]] : !torch.vtensor<[?],f16>, !torch.int -> !torch.vtensor<[?],i1>
// CHECK: %[[V8:.+]] = torch.aten.where.self %[[V4]], %[[V2]], %[[V1]] : !torch.vtensor<[?],i1>, !torch.vtensor<[],f16>, !torch.vtensor<[],f16> -> !torch.vtensor<[?],f16>
// CHECK: %[[V10:.+]] = torch.aten.where.self %[[V5]], %[[V0]], %[[V8]] : !torch.vtensor<[?],i1>, !torch.vtensor<[],f16>, !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16>
// CHECK: %[[V11:.+]] = torch.aten.abs %[[V3]] : !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16>
// CHECK: %[[V12:.+]] = torch.aten.floor %[[V11]] : !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16>
// CHECK: %[[V13:.+]] = torch.aten.mul.Tensor %[[V10]], %[[V12]] : !torch.vtensor<[?],f16>, !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16>
// CHECK: %[[V14:.+]] = torch.aten.mul.Tensor %[[V13]], %[[ARG1]] : !torch.vtensor<[?],f16>, !torch.vtensor<[1],f16> -> !torch.vtensor<[?],f16>
// CHECK: %[[V15:.+]] = torch.aten.sub.Tensor %[[ARG0]], %[[V14]], %[[FLOAT1]] : !torch.vtensor<[?],f16>, !torch.vtensor<[?],f16>, !torch.float -> !torch.vtensor<[?],f16>
// CHECK: return %[[V15]] : !torch.vtensor<[?],f16>
func.func @torch.aten.fmod_float(%arg0: !torch.vtensor<[?],f16>, %arg1: !torch.vtensor<[1],f16>) -> !torch.vtensor<[?],f16> {
%0 = torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[?],f16>, !torch.vtensor<[1],f16> -> !torch.vtensor<[?],f16>
return %0 : !torch.vtensor<[?],f16>
}
// -----
// rfft along the last dim (dim=-1, n=9) decomposes to a matmul against a
// precomputed 9x10 DFT-coefficient literal (interleaved real/imag columns),
// then a view to [16,5,2] and view_as_complex to form the [16,5] complex result.
// CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_last_dim(
// CHECK-SAME: %arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex<f32>> {
// CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2
// CHECK-DAG: %[[INT5:.*]] = torch.constant.int 5
// CHECK-DAG: %[[INT16:.*]] = torch.constant.int 16
// CHECK: %[[VAR0:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<9x10xf32>) : !torch.vtensor<[9,10],f32>
// CHECK: %[[VAR1:.*]] = torch.aten.mm %arg0, %[[VAR0]] : !torch.vtensor<[16,9],f32>, !torch.vtensor<[9,10],f32> -> !torch.vtensor<[16,10],f32>
// CHECK: %[[VAR2:.*]] = torch.prim.ListConstruct %[[INT16]], %[[INT5]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR3:.*]] = torch.aten.view %[[VAR1]], %[[VAR2]] : !torch.vtensor<[16,10],f32>, !torch.list<int> -> !torch.vtensor<[16,5,2],f32>
// CHECK: %[[VAR4:.*]] = torch.aten.view_as_complex %[[VAR3]] : !torch.vtensor<[16,5,2],f32> -> !torch.vtensor<[16,5],complex<f32>>
// CHECK: return %[[VAR4]] : !torch.vtensor<[16,5],complex<f32>>
func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex<f32>> {
%int-1 = torch.constant.int -1
%none = torch.constant.none
%out = torch.aten.fft_rfft %arg0, %none, %int-1, %none : !torch.vtensor<[16,9],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[16,5],complex<f32>>
return %out : !torch.vtensor<[16,5],complex<f32>>
}
// -----
// scaled_dot_product_attention with no mask and zero dropout decomposes into
// the explicit attention sequence: transpose K, bmm(Q, K^T), scale, a
// numerically-stable softmax (max / sub / exp / sum / div), and bmm with V.
// CHECK-LABEL: func.func @sdpa_decomposes_single_head(
// CHECK: %[[KEY_T:.*]] = torch.aten.transpose.int %arg1
// CHECK: %[[KEY_VIEW:.*]] = torch.aten.view %[[KEY_T]]
// CHECK: %[[SCORES:.*]] = torch.aten.bmm %arg0, %[[KEY_VIEW]]
// CHECK: %[[SCORES_VIEW:.*]] = torch.aten.view %[[SCORES]]
// CHECK: %[[SCALED:.*]] = torch.aten.mul.Scalar %[[SCORES_VIEW]]
// CHECK: %[[MAX:.*]], %[[INDICES:.*]] = torch.aten.max.dim %[[SCALED]]
// CHECK: %[[CENTERED:.*]] = torch.aten.sub.Tensor %[[SCALED]], %[[MAX]]
// CHECK: %[[EXP:.*]] = torch.aten.exp %[[CENTERED]]
// CHECK: %[[SUM_DIM:.*]] = torch.prim.ListConstruct
// CHECK: %[[DENOM:.*]] = torch.aten.sum.dim_IntList %[[EXP]], %[[SUM_DIM]]
// CHECK: %[[SOFTMAX:.*]] = torch.aten.div.Tensor %[[EXP]], %[[DENOM]]
// CHECK: %[[RESULT:.*]] = torch.aten.bmm %[[SOFTMAX]], %arg2
// CHECK: return %[[RESULT]] : !torch.vtensor<[1,4,8],f32>
func.func @sdpa_decomposes_single_head(
%query: !torch.vtensor<[1,4,8],f32>,
%key: !torch.vtensor<[1,4,8],f32>,
%value: !torch.vtensor<[1,4,8],f32>) -> !torch.vtensor<[1,4,8],f32> {
%none = torch.constant.none
%zero = torch.constant.float 0.000000e+00
%false = torch.constant.bool false
%result = torch.aten.scaled_dot_product_attention %query, %key, %value, %none, %zero, %false, %none, %false :
!torch.vtensor<[1,4,8],f32>, !torch.vtensor<[1,4,8],f32>, !torch.vtensor<[1,4,8],f32>, !torch.none, !torch.float, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,4,8],f32>
return %result : !torch.vtensor<[1,4,8],f32>
}
// -----
// Negative test: when an attention mask is supplied, the SDPA op is NOT
// decomposed and is kept with its mask operand.
// CHECK-LABEL: func.func @sdpa_keeps_mask(
// CHECK: torch.aten.scaled_dot_product_attention %arg0, %arg1, %arg2, %[[MASK:.*]]
// CHECK: return %[[RES:.*]] : !torch.vtensor<[1,4,8],f32>
func.func @sdpa_keeps_mask(
%query: !torch.vtensor<[1,4,8],f32>,
%key: !torch.vtensor<[1,4,8],f32>,
%value: !torch.vtensor<[1,4,8],f32>) -> !torch.vtensor<[1,4,8],f32> {
%mask = torch.vtensor.literal(dense<0.0> : tensor<1x4x4xf32>) : !torch.vtensor<[1,4,4],f32>
%zero = torch.constant.float 0.000000e+00
%false = torch.constant.bool false
%none = torch.constant.none
%result = torch.aten.scaled_dot_product_attention %query, %key, %value, %mask, %zero, %false, %none, %false :
!torch.vtensor<[1,4,8],f32>, !torch.vtensor<[1,4,8],f32>, !torch.vtensor<[1,4,8],f32>, !torch.vtensor<[1,4,4],f32>, !torch.float, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,4,8],f32>
return %result : !torch.vtensor<[1,4,8],f32>
}
// -----
// Negative test: with a non-zero dropout probability (0.1), the SDPA op is
// NOT decomposed.
// CHECK-LABEL: func.func @sdpa_keeps_dropout(
// CHECK: torch.aten.scaled_dot_product_attention %arg0, %arg1, %arg2, %[[NONE:.*]], %[[P:.*]]
// CHECK: return %[[RES:.*]] : !torch.vtensor<[1,4,8],f32>
func.func @sdpa_keeps_dropout(
%query: !torch.vtensor<[1,4,8],f32>,
%key: !torch.vtensor<[1,4,8],f32>,
%value: !torch.vtensor<[1,4,8],f32>) -> !torch.vtensor<[1,4,8],f32> {
%none = torch.constant.none
%p = torch.constant.float 1.000000e-01
%false = torch.constant.bool false
%result = torch.aten.scaled_dot_product_attention %query, %key, %value, %none, %p, %false, %none, %false :
!torch.vtensor<[1,4,8],f32>, !torch.vtensor<[1,4,8],f32>, !torch.vtensor<[1,4,8],f32>, !torch.none, !torch.float, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,4,8],f32>
return %result : !torch.vtensor<[1,4,8],f32>
}
// -----
// rfft along dim 0 (n=36): the input is transposed so the FFT dim is last,
// multiplied by a precomputed 36x38 DFT-coefficient literal, reshaped to
// [23,19,2], converted via view_as_complex, and transposed back to [19,23].
// CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_first_dim(
// CHECK-SAME: %arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex<f32>> {
// CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2
// CHECK-DAG: %[[INT19:.*]] = torch.constant.int 19
// CHECK-DAG: %[[INT23:.*]] = torch.constant.int 23
// CHECK-DAG: %[[VAR0:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<36x38xf32>) : !torch.vtensor<[36,38],f32>
// CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0
// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[VAR1:.*]] = torch.aten.transpose.int %arg0, %[[INT0]], %[[INT1]] : !torch.vtensor<[36,23],f32>, !torch.int, !torch.int -> !torch.vtensor<[23,36],f32>
// CHECK: %[[VAR2:.*]] = torch.aten.mm %[[VAR1]], %[[VAR0]] : !torch.vtensor<[23,36],f32>, !torch.vtensor<[36,38],f32> -> !torch.vtensor<[23,38],f32>
// CHECK: %[[VAR3:.*]] = torch.prim.ListConstruct %[[INT23]], %[[INT19]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR4:.*]] = torch.aten.view %[[VAR2]], %[[VAR3]] : !torch.vtensor<[23,38],f32>, !torch.list<int> -> !torch.vtensor<[23,19,2],f32>
// CHECK: %[[VAR5:.*]] = torch.aten.view_as_complex %[[VAR4]] : !torch.vtensor<[23,19,2],f32> -> !torch.vtensor<[23,19],complex<f32>>
// CHECK: %[[VAR6:.*]] = torch.aten.transpose.int %[[VAR5]], %[[INT0]], %[[INT1]] : !torch.vtensor<[23,19],complex<f32>>, !torch.int, !torch.int -> !torch.vtensor<[19,23],complex<f32>>
// CHECK: return %[[VAR6]] : !torch.vtensor<[19,23],complex<f32>>
func.func @torch.aten.fft_rfft$2d_first_dim(%arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex<f32>> {
%int0 = torch.constant.int 0
%none = torch.constant.none
%out = torch.aten.fft_rfft %arg0, %none, %int0, %none : !torch.vtensor<[36,23],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[19,23],complex<f32>>
return %out : !torch.vtensor<[19,23],complex<f32>>
}
// -----
// sym_constrain_range_for_size decomposes to sym_constrain_range; a none min
// becomes the default lower bound 0, and the explicit (0, 7) range passes
// through unchanged.
// CHECK-LABEL: func.func @torch.aten.sym_constrain_range_for_size(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.int) -> !torch.int {
// CHECK: %[[VAL_1:.*]] = torch.constant.int 7
// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
// CHECK: %[[VAL_3:.*]] = torch.constant.none
// CHECK: torch.aten.sym_constrain_range %[[VAL_0]], %[[VAL_2]], %[[VAL_3]] : !torch.int, !torch.int, !torch.none
// CHECK: torch.aten.sym_constrain_range %[[VAL_0]], %[[VAL_2]], %[[VAL_1]] : !torch.int, !torch.int, !torch.int
// CHECK: return %[[VAL_0]] : !torch.int
// CHECK: }
func.func @torch.aten.sym_constrain_range_for_size(%arg0: !torch.int) -> !torch.int {
%int7 = torch.constant.int 7
%int0 = torch.constant.int 0
%none = torch.constant.none
torch.aten.sym_constrain_range_for_size %arg0, %none, %none : !torch.int, !torch.none, !torch.none
torch.aten.sym_constrain_range_for_size %arg0, %int0, %int7 : !torch.int, !torch.int, !torch.int
return %arg0 : !torch.int
}
// -----
// _assert_scalar decomposes to Bool.int on the condition followed by
// torch.runtime.assert carrying the original message string; both assertions
// in this function are lowered.
// CHECK-LABEL: func.func @torch.aten._assert_scalar(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.int) -> !torch.int {
// CHECK: %[[VAL_1:.*]] = torch.constant.int 2
// CHECK: %[[VAL_2:.*]] = torch.constant.int 3
// CHECK: %[[VAL_3:.*]] = torch.aten.ge.int %[[VAL_0]], %[[VAL_2]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[VAL_4:.*]] = torch.aten.Int.bool %[[VAL_3]] : !torch.bool -> !torch.int
// CHECK: %[[VAL_5:.*]] = torch.aten.Bool.int %[[VAL_4]] : !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_5]], "Runtime assertion failed for expression u0 >= 3 on node 'ge_1'"
// CHECK: %[[VAL_6:.*]] = torch.aten.gt.int %[[VAL_0]], %[[VAL_1]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[VAL_7:.*]] = torch.aten.Int.bool %[[VAL_6]] : !torch.bool -> !torch.int
// CHECK: %[[VAL_8:.*]] = torch.aten.Bool.int %[[VAL_7]] : !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_8]], "Runtime assertion failed for expression 2 < u0 on node 'gt_1'"
// CHECK: return %[[VAL_0]] : !torch.int
// CHECK: }
func.func @torch.aten._assert_scalar(%arg0: !torch.int) -> !torch.int {
%str = torch.constant.str "Runtime assertion failed for expression 2 < u0 on node 'gt_1'"
%int2 = torch.constant.int 2
%str_0 = torch.constant.str "Runtime assertion failed for expression u0 >= 3 on node 'ge_1'"
%int3 = torch.constant.int 3
%0 = torch.aten.ge.int %arg0, %int3 : !torch.int, !torch.int -> !torch.bool
%1 = torch.aten.Int.bool %0 : !torch.bool -> !torch.int
torch.aten._assert_scalar %1, %str_0 : !torch.int, !torch.str
%2 = torch.aten.gt.int %arg0, %int2 : !torch.int, !torch.int -> !torch.bool
%3 = torch.aten.Int.bool %2 : !torch.bool -> !torch.int
torch.aten._assert_scalar %3, %str : !torch.int, !torch.str
return %arg0 : !torch.int
}
// -----
// convolution_backward with output_mask = [false, true, true] (grad_input not
// requested) decomposes to: the weight gradient computed as a convolution of
// transposed input with transposed grad_output (transposed back afterwards),
// and the bias gradient as a sum of grad_output over dims [0, 2, 3].
// CHECK-LABEL: func.func @convolution_backward_none_result(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,3,3],f32>, %[[VAL_1:.*]]: !torch.vtensor<[1,1,5,5],f32>,
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1,1,3,3],f32>,
// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>) {
func.func @convolution_backward_none_result(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,1,5,5],f32>, %arg2: !torch.vtensor<[1,1,3,3],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>) {
// CHECK: %[[VAL_4:.*]] = torch.constant.int 3
// CHECK: %[[VAL_5:.*]] = torch.constant.int 2
// CHECK: %[[VAL_6:.*]] = torch.constant.none
// CHECK: %[[VAL_7:.*]] = torch.constant.int 0
// CHECK: %[[VAL_8:.*]] = torch.constant.bool false
// CHECK: %[[VAL_9:.*]] = torch.constant.int 1
// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_9]], %[[VAL_9]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_11:.*]] = torch.prim.ListConstruct %[[VAL_7]], %[[VAL_7]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_12:.*]] = torch.aten.transpose.int %[[VAL_1]], %[[VAL_7]], %[[VAL_9]] : !torch.vtensor<[1,1,5,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1,5,5],f32>
// CHECK: %[[VAL_13:.*]] = torch.aten.transpose.int %[[VAL_0]], %[[VAL_7]], %[[VAL_9]] : !torch.vtensor<[1,1,3,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1,3,3],f32>
// CHECK: %[[VAL_14:.*]] = torch.aten.convolution %[[VAL_12]], %[[VAL_13]], %[[VAL_6]], %[[VAL_10]], %[[VAL_11]], %[[VAL_10]], %[[VAL_8]], %[[VAL_11]], %[[VAL_9]] : !torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,3,3],f32>
// CHECK: %[[VAL_15:.*]] = torch.aten.transpose.int %[[VAL_14]], %[[VAL_7]], %[[VAL_9]] : !torch.vtensor<[1,1,3,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1,3,3],f32>
// CHECK: %[[VAL_16:.*]] = torch.prim.ListConstruct %[[VAL_7]], %[[VAL_5]], %[[VAL_4]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_17:.*]] = torch.aten.sum.dim_IntList %[[VAL_0]], %[[VAL_16]], %[[VAL_8]], %[[VAL_6]] : !torch.vtensor<[1,1,3,3],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1],f32>
// CHECK: return %[[VAL_15]], %[[VAL_17]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>
%true = torch.constant.bool true
%int0 = torch.constant.int 0
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.prim.ListConstruct %false, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
%result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %2, %1, %false, %2, %int1, %3 : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.list<bool> -> !torch.none, !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>
return %result1, %result2 : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>
}
// -----
// empty_like with dtype=None decomposes to empty.memory_format with the
// source's shape as an explicit size list and its element type (f64 -> torch
// dtype code 7) as an explicit int.
// CHECK-LABEL: func.func @emptyLikeNoneDtype(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> {
// CHECK: %[[DTYPE:.*]] = torch.constant.int 7
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[C200:.*]] = torch.constant.int 200
// CHECK: %[[C26:.*]] = torch.constant.int 26
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[C200]], %[[C200]], %[[C26]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[MEM_FMT:.*]] = torch.aten.empty.memory_format %[[LIST]], %[[DTYPE]], %[[NONE]], %[[NONE]], %[[FALSE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[200,200,26],f64>
func.func @emptyLikeNoneDtype(%arg0: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> {
%none = torch.constant.none
%none_0 = torch.constant.none
%none_1 = torch.constant.none
%false = torch.constant.bool false
%none_2 = torch.constant.none
%0 = torch.aten.empty_like %arg0, %none, %none_0, %none_1, %false, %none_2 : !torch.vtensor<[200,200,26],f64>, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[200,200,26],f64>
return %0 : !torch.vtensor<[200,200,26],f64>
}
// -----
// Test: aten.rand with dtype=None decomposes into
// aten.empty.memory_format followed by aten.uniform over [0.0, 1.0).
// The dtype is made explicit as torch.constant.int 7 (matching the f64
// result type) and the "cpu" device constant is forwarded unchanged.
// CHECK-LABEL: func.func @randNoneDtype(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> {
// CHECK: %[[DTYPE:.*]] = torch.constant.int 7
// CHECK: %[[C1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[C0:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[C200:.*]] = torch.constant.int 200
// CHECK: %[[C26:.*]] = torch.constant.int 26
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[C200]], %[[C200]], %[[C26]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[CPU:.*]] = torch.constant.device "cpu"
// CHECK: %[[MEM_FMT:.*]] = torch.aten.empty.memory_format %[[LIST]], %[[DTYPE]], %[[NONE]], %[[CPU]], %[[FALSE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[200,200,26],f64>
// CHECK: %[[UNIFORM:.*]] = torch.aten.uniform %[[MEM_FMT]], %[[C0]], %[[C1]], %[[NONE]] : !torch.vtensor<[200,200,26],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[200,200,26],f64>
// CHECK: return %[[UNIFORM]] : !torch.vtensor<[200,200,26],f64>
func.func @randNoneDtype(%arg0: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> {
  // Static size list [200, 200, 26] for the random tensor.
  %int200 = torch.constant.int 200
  %int200_0 = torch.constant.int 200
  %int26 = torch.constant.int 26
  %0 = torch.prim.ListConstruct %int200, %int200_0, %int26 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  // dtype and layout are None; explicit "cpu" device; pin_memory false.
  %none = torch.constant.none
  %none_1 = torch.constant.none
  %cpu = torch.constant.device "cpu"
  %false = torch.constant.bool false
  %1 = torch.aten.rand %0, %none, %none_1, %cpu, %false : !torch.list<int>, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[200,200,26],f64>
  return %1 : !torch.vtensor<[200,200,26],f64>
}
// -----
// Test: aten.stft.center on a 1-D signal (length 40, n_fft=4, hop=1,
// win_length=4, onesided=true) decomposes into a torch.prim.Loop over the
// 37 frames. Each iteration slices one frame from the signal, zero-pads it
// to n_fft via constant_pad_nd, multiplies by the window, applies fft_fft,
// unsqueezes a trailing frame axis, and slice_scatters the column into the
// [3,37] complex accumulator created with empty.memory_format (dtype 9).
// CHECK-LABEL: func.func @torch.aten.stft.center_1D(
// CHECK-SAME: %arg0: !torch.vtensor<[40],f32>, %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,37],complex<f32>> {
// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
// CHECK-DAG: %[[INT9:.*]] = torch.constant.int 9
// CHECK-DAG: %[[INT4:.*]] = torch.constant.int 4
// CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0
// CHECK-DAG: %[[INTM1:.*]] = torch.constant.int -1
// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[NONE:.*]] = torch.constant.none
// CHECK-DAG: %float0.000000e00 = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[INT40:.*]] = torch.constant.int 40
// CHECK-DAG: %[[INT37:.*]] = torch.constant.int 37
// CHECK-DAG: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[VAR0:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT37]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR1:.*]] = torch.aten.empty.memory_format %[[VAR0]], %[[INT9]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,37],complex<f32>>
// CHECK: %[[VAR2:.*]] = torch.prim.Loop %[[INT37]], %[[TRUE]], init(%[[VAR1]]) {
// CHECK: ^bb0(%arg2: !torch.int, %arg3: !torch.vtensor<[3,37],complex<f32>>):
// CHECK: %[[VAR3:.*]] = torch.aten.add.int %arg2, %[[INT4]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR4:.*]] = torch.prim.min.int %[[VAR3]], %[[INT40]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR5:.*]] = torch.aten.sub.int %[[VAR4]], %arg2 : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR6:.*]] = torch.aten.sub.int %[[INT4]], %[[VAR5]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR7:.*]] = torch.aten.add.int %arg2, %[[VAR5]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR8:.*]] = torch.aten.slice.Tensor %arg0, %[[INTM1]], %arg2, %[[VAR7]], %[[INT1]] : !torch.vtensor<[40],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],f32>
// CHECK: %[[VAR9:.*]] = torch.prim.ListConstruct %[[INT0]], %[[VAR6]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR10:.*]] = torch.aten.constant_pad_nd %[[VAR8]], %[[VAR9]], %float0.000000e00 : !torch.vtensor<[?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[?],f32>
// CHECK: %[[VAR11:.*]] = torch.tensor_static_info_cast %[[VAR10]] : !torch.vtensor<[?],f32> to !torch.vtensor<[4],f32>
// CHECK: %[[VAR12:.*]] = torch.aten.mul.Tensor %[[VAR11]], %arg1 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[4],f32>
// CHECK: %[[VAR13:.*]] = torch.aten.fft_fft %[[VAR12]], %[[NONE]], %[[INTM1]], %[[NONE]] : !torch.vtensor<[4],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[3],complex<f32>>
// CHECK: %[[VAR14:.*]] = torch.aten.unsqueeze %[[VAR13]], %[[INTM1]] : !torch.vtensor<[3],complex<f32>>, !torch.int -> !torch.vtensor<[3,1],complex<f32>>
// CHECK: %[[VAR15:.*]] = torch.aten.slice_scatter %arg3, %[[VAR14]], %[[INTM1]], %arg2, %[[NONE]], %[[INT1]] : !torch.vtensor<[3,37],complex<f32>>, !torch.vtensor<[3,1],complex<f32>>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[3,37],complex<f32>>
// CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[VAR15]] : !torch.vtensor<[3,37],complex<f32>>)
// CHECK: } : (!torch.int, !torch.bool, !torch.vtensor<[3,37],complex<f32>>) -> !torch.vtensor<[3,37],complex<f32>>
// CHECK: return %[[VAR2]] : !torch.vtensor<[3,37],complex<f32>>
func.func @torch.aten.stft.center_1D(%arg0: !torch.vtensor<[40],f32>, %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,37],complex<f32>> {
  // n_fft = win_length = 4, hop_length = 1; normalized/center false,
  // onesided true, return_complex false (last bool operand).
  %padmode = torch.constant.str "reflect"
  %nfft = torch.constant.int 4
  %hoplen = torch.constant.int 1
  %winlen = torch.constant.int 4
  %cstfalse = torch.constant.bool false
  %csttrue = torch.constant.bool true
  %0 = torch.aten.stft.center %arg0, %nfft, %hoplen, %winlen, %arg1, %cstfalse, %padmode, %cstfalse, %cstfalse, %csttrue, %cstfalse : !torch.vtensor<[40],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[4],f32>, !torch.bool, !torch.str, !torch.bool, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[3,37],complex<f32>>
  return %0 : !torch.vtensor<[3,37],complex<f32>>
}
// -----
// Test: same stft.center decomposition as the 1-D case above, but the
// signal length is dynamic ([?]). The frame count is therefore computed at
// runtime as 1 + (size(-1) - n_fft) // hop via aten.size.int / sub /
// floordiv / add, and that value drives both the empty accumulator's
// dynamic trailing dim ([6,?]) and the prim.Loop trip count.
// CHECK-LABEL: func.func @torch.aten.stft.center_1D_unk_sig_len(
// CHECK-SAME: %arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[10],f32>) -> !torch.vtensor<[6,?],complex<f32>> {
// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
// CHECK-DAG: %[[INT9:.*]] = torch.constant.int 9
// CHECK-DAG: %[[INT6:.*]] = torch.constant.int 6
// CHECK-DAG: %[[INT10:.*]] = torch.constant.int 10
// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0
// CHECK-DAG: %[[INTM1:.*]] = torch.constant.int -1
// CHECK-DAG: %[[NONE:.*]] = torch.constant.none
// CHECK-DAG: %float0.000000e00 = torch.constant.float 0.000000e+00
// CHECK: %[[VAR0:.*]] = torch.aten.size.int %arg0, %[[INTM1]] : !torch.vtensor<[?],f32>, !torch.int -> !torch.int
// CHECK: %[[VAR1:.*]] = torch.aten.sub.int %[[VAR0]], %[[INT10]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR2:.*]] = torch.aten.floordiv.int %[[VAR1]], %[[INT1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR3:.*]] = torch.aten.add.int %[[INT1]], %[[VAR2]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR4:.*]] = torch.prim.ListConstruct %[[INT6]], %[[VAR3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR5:.*]] = torch.aten.empty.memory_format %[[VAR4]], %[[INT9]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[6,?],complex<f32>>
// CHECK: %[[VAR6:.*]] = torch.prim.Loop %[[VAR3]], %[[TRUE]], init(%[[VAR5]]) {
// CHECK: ^bb0(%arg2: !torch.int, %arg3: !torch.vtensor<[6,?],complex<f32>>):
// CHECK: %[[VAR7:.*]] = torch.aten.add.int %arg2, %[[INT10]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR8:.*]] = torch.prim.min.int %[[VAR7]], %[[VAR0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR9:.*]] = torch.aten.sub.int %[[VAR8]], %arg2 : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR10:.*]] = torch.aten.sub.int %[[INT10]], %[[VAR9]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR11:.*]] = torch.aten.add.int %arg2, %[[VAR9]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR12:.*]] = torch.aten.slice.Tensor %arg0, %[[INTM1]], %arg2, %[[VAR11]], %[[INT1]] : !torch.vtensor<[?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],f32>
// CHECK: %[[VAR13:.*]] = torch.prim.ListConstruct %[[INT0]], %[[VAR10]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR14:.*]] = torch.aten.constant_pad_nd %[[VAR12]], %[[VAR13]], %float0.000000e00 : !torch.vtensor<[?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[?],f32>
// CHECK: %[[VAR15:.*]] = torch.tensor_static_info_cast %[[VAR14]] : !torch.vtensor<[?],f32> to !torch.vtensor<[10],f32>
// CHECK: %[[VAR16:.*]] = torch.aten.mul.Tensor %[[VAR15]], %arg1 : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK: %[[VAR17:.*]] = torch.aten.fft_fft %[[VAR16]], %[[NONE]], %[[INTM1]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[6],complex<f32>>
// CHECK: %[[VAR18:.*]] = torch.aten.unsqueeze %[[VAR17]], %[[INTM1]] : !torch.vtensor<[6],complex<f32>>, !torch.int -> !torch.vtensor<[6,1],complex<f32>>
// CHECK: %[[VAR19:.*]] = torch.aten.slice_scatter %arg3, %[[VAR18]], %[[INTM1]], %arg2, %[[NONE]], %[[INT1]] : !torch.vtensor<[6,?],complex<f32>>, !torch.vtensor<[6,1],complex<f32>>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[6,?],complex<f32>>
// CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[VAR19]] : !torch.vtensor<[6,?],complex<f32>>)
// CHECK: } : (!torch.int, !torch.bool, !torch.vtensor<[6,?],complex<f32>>) -> !torch.vtensor<[6,?],complex<f32>>
// CHECK: return %[[VAR6]] : !torch.vtensor<[6,?],complex<f32>>
func.func @torch.aten.stft.center_1D_unk_sig_len(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[10],f32>) -> !torch.vtensor<[6,?],complex<f32>> {
  // n_fft = win_length = 10, hop_length = 1; signal length unknown.
  %padmode = torch.constant.str "reflect"
  %nfft = torch.constant.int 10
  %hoplen = torch.constant.int 1
  %winlen = torch.constant.int 10
  %cstfalse = torch.constant.bool false
  %csttrue = torch.constant.bool true
  %0 = torch.aten.stft.center %arg0, %nfft, %hoplen, %winlen, %arg1, %cstfalse, %padmode, %cstfalse, %cstfalse, %csttrue, %cstfalse : !torch.vtensor<[?],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[10],f32>, !torch.bool, !torch.str, !torch.bool, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[6,?],complex<f32>>
  return %0 : !torch.vtensor<[6,?],complex<f32>>
}
// -----
// Test: stft.center on a batched (2-D) signal [4,46]. The decomposition is
// the same per-frame loop as the 1-D case, except the [7] window is first
// reshaped to [1,7] via aten.view so the mul.Tensor broadcasts across the
// batch dimension, and the accumulator / slice_scatter operate on the
// rank-3 result [4,4,40] (batch x freq x frames).
// CHECK-LABEL: func.func @torch.aten.stft.center_2D(
// CHECK-SAME: %arg0: !torch.vtensor<[4,46],f32>, %arg1: !torch.vtensor<[7],f32>) -> !torch.vtensor<[4,4,40],complex<f32>> {
// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
// CHECK-DAG: %[[INT9:.*]] = torch.constant.int 9
// CHECK-DAG: %[[INT4:.*]] = torch.constant.int 4
// CHECK-DAG: %[[INT7:.*]] = torch.constant.int 7
// CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0
// CHECK-DAG: %[[INTM1:.*]] = torch.constant.int -1
// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[NONE:.*]] = torch.constant.none
// CHECK-DAG: %float0.000000e00 = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[INT46:.*]] = torch.constant.int 46
// CHECK-DAG: %[[INT40:.*]] = torch.constant.int 40
// CHECK: %[[VAR0:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT7]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR1:.*]] = torch.aten.view %arg1, %[[VAR0]] : !torch.vtensor<[7],f32>, !torch.list<int> -> !torch.vtensor<[1,7],f32>
// CHECK: %[[VAR2:.*]] = torch.prim.ListConstruct %[[INT4]], %[[INT4]], %[[INT40]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR3:.*]] = torch.aten.empty.memory_format %[[VAR2]], %[[INT9]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[4,4,40],complex<f32>>
// CHECK: %[[VAR4:.*]] = torch.prim.Loop %[[INT40]], %[[TRUE]], init(%[[VAR3]]) {
// CHECK: ^bb0(%arg2: !torch.int, %arg3: !torch.vtensor<[4,4,40],complex<f32>>):
// CHECK: %[[VAR5:.*]] = torch.aten.add.int %arg2, %[[INT7]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR6:.*]] = torch.prim.min.int %[[VAR5]], %[[INT46]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR7:.*]] = torch.aten.sub.int %[[VAR6]], %arg2 : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR8:.*]] = torch.aten.sub.int %[[INT7]], %[[VAR7]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR9:.*]] = torch.aten.add.int %arg2, %[[VAR7]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR10:.*]] = torch.aten.slice.Tensor %arg0, %[[INTM1]], %arg2, %[[VAR9]], %[[INT1]] : !torch.vtensor<[4,46],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?],f32>
// CHECK: %[[VAR11:.*]] = torch.prim.ListConstruct %[[INT0]], %[[VAR8]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR12:.*]] = torch.aten.constant_pad_nd %[[VAR10]], %[[VAR11]], %float0.000000e00 : !torch.vtensor<[4,?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[4,?],f32>
// CHECK: %[[VAR13:.*]] = torch.tensor_static_info_cast %[[VAR12]] : !torch.vtensor<[4,?],f32> to !torch.vtensor<[4,7],f32>
// CHECK: %[[VAR14:.*]] = torch.aten.mul.Tensor %[[VAR13]], %[[VAR1]] : !torch.vtensor<[4,7],f32>, !torch.vtensor<[1,7],f32> -> !torch.vtensor<[4,7],f32>
// CHECK: %[[VAR15:.*]] = torch.aten.fft_fft %[[VAR14]], %[[NONE]], %[[INTM1]], %[[NONE]] : !torch.vtensor<[4,7],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[4,4],complex<f32>>
// CHECK: %[[VAR16:.*]] = torch.aten.unsqueeze %[[VAR15]], %[[INTM1]] : !torch.vtensor<[4,4],complex<f32>>, !torch.int -> !torch.vtensor<[4,4,1],complex<f32>>
// CHECK: %[[VAR17:.*]] = torch.aten.slice_scatter %arg3, %[[VAR16]], %[[INTM1]], %arg2, %[[NONE]], %[[INT1]] : !torch.vtensor<[4,4,40],complex<f32>>, !torch.vtensor<[4,4,1],complex<f32>>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[4,4,40],complex<f32>>
// CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[VAR17]] : !torch.vtensor<[4,4,40],complex<f32>>)
// CHECK: } : (!torch.int, !torch.bool, !torch.vtensor<[4,4,40],complex<f32>>) -> !torch.vtensor<[4,4,40],complex<f32>>
// CHECK: return %[[VAR4]] : !torch.vtensor<[4,4,40],complex<f32>>
func.func @torch.aten.stft.center_2D(%arg0: !torch.vtensor<[4,46],f32>, %arg1: !torch.vtensor<[7],f32>) -> !torch.vtensor<[4,4,40],complex<f32>> {
  // n_fft = win_length = 7, hop_length = 1; batch of 4 signals of length 46.
  %padmode = torch.constant.str "reflect"
  %nfft = torch.constant.int 7
  %hoplen = torch.constant.int 1
  %winlen = torch.constant.int 7
  %cstfalse = torch.constant.bool false
  %csttrue = torch.constant.bool true
  %0 = torch.aten.stft.center %arg0, %nfft, %hoplen, %winlen, %arg1, %cstfalse, %padmode, %cstfalse, %cstfalse, %csttrue, %cstfalse : !torch.vtensor<[4,46],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[7],f32>, !torch.bool, !torch.str, !torch.bool, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[4,4,40],complex<f32>>
  return %0 : !torch.vtensor<[4,4,40],complex<f32>>
}
// -----
// Test: stft.center where the window tensor has unknown size ([?]) and
// win_length (6) is smaller than n_fft (7). The decomposition first emits a
// torch.runtime.assert that the window's runtime size equals win_length,
// then zero-pads the window on the right (pad list [0,1]) up to n_fft and
// views it as [1,7] for broadcasting, before the usual per-frame loop.
// CHECK-LABEL: func.func @torch.aten.stft.center_2D_win_unk_size(
// CHECK-SAME: %arg0: !torch.vtensor<[3,38],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[3,4,32],complex<f32>> {
// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
// CHECK-DAG: %[[INT9:.*]] = torch.constant.int 9
// CHECK-DAG: %[[INT3:.*]] = torch.constant.int 3
// CHECK-DAG: %[[INT4:.*]] = torch.constant.int 4
// CHECK-DAG: %[[INT32:.*]] = torch.constant.int 32
// CHECK-DAG: %[[INT38:.*]] = torch.constant.int 38
// CHECK-DAG: %float0.000000e00 = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[NONE:.*]] = torch.constant.none
// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[INTM1:.*]] = torch.constant.int -1
// CHECK-DAG: %[[INT6:.*]] = torch.constant.int 6
// CHECK-DAG: %[[INT7:.*]] = torch.constant.int 7
// CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[VAR0:.*]] = torch.aten.size.int %arg1, %[[INT0]] : !torch.vtensor<[?],f32>, !torch.int -> !torch.int
// CHECK: %[[VAR1:.*]] = torch.aten.eq.int %[[VAR0]], %[[INT6]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[VAR1]], "window size should be equal to win_length"
// CHECK: %[[VAR2:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR3:.*]] = torch.aten.constant_pad_nd %arg1, %[[VAR2]], %float0.000000e00 : !torch.vtensor<[?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[7],f32>
// CHECK: %[[VAR4:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT7]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR5:.*]] = torch.aten.view %[[VAR3]], %[[VAR4]] : !torch.vtensor<[7],f32>, !torch.list<int> -> !torch.vtensor<[1,7],f32>
// CHECK: %[[VAR6:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT4]], %[[INT32]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR7:.*]] = torch.aten.empty.memory_format %[[VAR6]], %[[INT9]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4,32],complex<f32>>
// CHECK: %[[VAR8:.*]] = torch.prim.Loop %[[INT32]], %[[TRUE]], init(%[[VAR7]]) {
// CHECK: ^bb0(%arg2: !torch.int, %arg3: !torch.vtensor<[3,4,32],complex<f32>>):
// CHECK: %[[VAR9:.*]] = torch.aten.add.int %arg2, %[[INT7]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR10:.*]] = torch.prim.min.int %[[VAR9]], %[[INT38]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR11:.*]] = torch.aten.sub.int %[[VAR10]], %arg2 : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR12:.*]] = torch.aten.sub.int %[[INT7]], %[[VAR11]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR13:.*]] = torch.aten.add.int %arg2, %[[VAR11]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR14:.*]] = torch.aten.slice.Tensor %arg0, %[[INTM1]], %arg2, %[[VAR13]], %[[INT1]] : !torch.vtensor<[3,38],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,?],f32>
// CHECK: %[[VAR15:.*]] = torch.prim.ListConstruct %[[INT0]], %[[VAR12]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR16:.*]] = torch.aten.constant_pad_nd %[[VAR14]], %[[VAR15]], %float0.000000e00 : !torch.vtensor<[3,?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[3,?],f32>
// CHECK: %[[VAR17:.*]] = torch.tensor_static_info_cast %[[VAR16]] : !torch.vtensor<[3,?],f32> to !torch.vtensor<[3,7],f32>
// CHECK: %[[VAR18:.*]] = torch.aten.mul.Tensor %[[VAR17]], %[[VAR5]] : !torch.vtensor<[3,7],f32>, !torch.vtensor<[1,7],f32> -> !torch.vtensor<[3,7],f32>
// CHECK: %[[VAR19:.*]] = torch.aten.fft_fft %[[VAR18]], %[[NONE]], %[[INTM1]], %[[NONE]] : !torch.vtensor<[3,7],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[3,4],complex<f32>>
// CHECK: %[[VAR20:.*]] = torch.aten.unsqueeze %[[VAR19]], %[[INTM1]] : !torch.vtensor<[3,4],complex<f32>>, !torch.int -> !torch.vtensor<[3,4,1],complex<f32>>
// CHECK: %[[VAR21:.*]] = torch.aten.slice_scatter %arg3, %[[VAR20]], %[[INTM1]], %arg2, %[[NONE]], %[[INT1]] : !torch.vtensor<[3,4,32],complex<f32>>, !torch.vtensor<[3,4,1],complex<f32>>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[3,4,32],complex<f32>>
// CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[VAR21]] : !torch.vtensor<[3,4,32],complex<f32>>)
// CHECK: } : (!torch.int, !torch.bool, !torch.vtensor<[3,4,32],complex<f32>>) -> !torch.vtensor<[3,4,32],complex<f32>>
// CHECK: return %[[VAR8]] : !torch.vtensor<[3,4,32],complex<f32>>
func.func @torch.aten.stft.center_2D_win_unk_size(%arg0: !torch.vtensor<[3,38],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[3,4,32],complex<f32>> {
  // win_length (6) < n_fft (7): window must be padded up to n_fft.
  %padmode = torch.constant.str "reflect"
  %nfft = torch.constant.int 7
  %hoplen = torch.constant.int 1
  %winlen = torch.constant.int 6
  %cstfalse = torch.constant.bool false
  %csttrue = torch.constant.bool true
  %0 = torch.aten.stft.center %arg0, %nfft, %hoplen, %winlen, %arg1, %cstfalse, %padmode, %cstfalse, %cstfalse, %csttrue, %cstfalse : !torch.vtensor<[3,38],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[?],f32>, !torch.bool, !torch.str, !torch.bool, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[3,4,32],complex<f32>>
  return %0 : !torch.vtensor<[3,4,32],complex<f32>>
}
// -----
// Test: stft.center with window=None (and win_length=None). The per-frame
// loop is emitted without any window handling: no view/broadcast and no
// mul.Tensor — each padded frame goes straight into fft_fft before being
// scattered into the [2,5,25] accumulator.
// CHECK-LABEL: func.func @torch.aten.stft.center_2D_no_window(
// CHECK-SAME: %arg0: !torch.vtensor<[2,32],f32>) -> !torch.vtensor<[2,5,25],complex<f32>> {
// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
// CHECK-DAG: %[[INT9:.*]] = torch.constant.int 9
// CHECK-DAG: %[[INT8:.*]] = torch.constant.int 8
// CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0
// CHECK-DAG: %[[INTM1:.*]] = torch.constant.int -1
// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[NONE:.*]] = torch.constant.none
// CHECK-DAG: %float0.000000e00 = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[INT32:.*]] = torch.constant.int 32
// CHECK-DAG: %[[INT25:.*]] = torch.constant.int 25
// CHECK-DAG: %[[INT5:.*]] = torch.constant.int 5
// CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[VAR0:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT5]], %[[INT25]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR1:.*]] = torch.aten.empty.memory_format %[[VAR0]], %[[INT9]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,5,25],complex<f32>>
// CHECK: %[[VAR2:.*]] = torch.prim.Loop %[[INT25]], %[[TRUE]], init(%[[VAR1]]) {
// CHECK: ^bb0(%arg1: !torch.int, %arg2: !torch.vtensor<[2,5,25],complex<f32>>):
// CHECK: %[[VAR3:.*]] = torch.aten.add.int %arg1, %[[INT8]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR4:.*]] = torch.prim.min.int %[[VAR3]], %[[INT32]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR5:.*]] = torch.aten.sub.int %[[VAR4]], %arg1 : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR6:.*]] = torch.aten.sub.int %[[INT8]], %[[VAR5]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR7:.*]] = torch.aten.add.int %arg1, %[[VAR5]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR8:.*]] = torch.aten.slice.Tensor %arg0, %[[INTM1]], %arg1, %[[VAR7]], %[[INT1]] : !torch.vtensor<[2,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,?],f32>
// CHECK: %[[VAR9:.*]] = torch.prim.ListConstruct %[[INT0]], %[[VAR6]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR10:.*]] = torch.aten.constant_pad_nd %[[VAR8]], %[[VAR9]], %float0.000000e00 : !torch.vtensor<[2,?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[2,?],f32>
// CHECK: %[[VAR11:.*]] = torch.tensor_static_info_cast %[[VAR10]] : !torch.vtensor<[2,?],f32> to !torch.vtensor<[2,8],f32>
// CHECK: %[[VAR12:.*]] = torch.aten.fft_fft %[[VAR11]], %[[NONE]], %[[INTM1]], %[[NONE]] : !torch.vtensor<[2,8],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[2,5],complex<f32>>
// CHECK: %[[VAR13:.*]] = torch.aten.unsqueeze %[[VAR12]], %[[INTM1]] : !torch.vtensor<[2,5],complex<f32>>, !torch.int -> !torch.vtensor<[2,5,1],complex<f32>>
// CHECK: %[[VAR14:.*]] = torch.aten.slice_scatter %arg2, %[[VAR13]], %[[INTM1]], %arg1, %[[NONE]], %[[INT1]] : !torch.vtensor<[2,5,25],complex<f32>>, !torch.vtensor<[2,5,1],complex<f32>>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[2,5,25],complex<f32>>
// CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[VAR14]] : !torch.vtensor<[2,5,25],complex<f32>>)
// CHECK: } : (!torch.int, !torch.bool, !torch.vtensor<[2,5,25],complex<f32>>) -> !torch.vtensor<[2,5,25],complex<f32>>
// CHECK: return %[[VAR2]] : !torch.vtensor<[2,5,25],complex<f32>>
func.func @torch.aten.stft.center_2D_no_window(%arg0: !torch.vtensor<[2,32],f32>) -> !torch.vtensor<[2,5,25],complex<f32>> {
  // n_fft = 8, hop_length = 1; win_length and window are both None.
  %padmode = torch.constant.str "reflect"
  %nfft = torch.constant.int 8
  %hoplen = torch.constant.int 1
  %cstfalse = torch.constant.bool false
  %csttrue = torch.constant.bool true
  %cstnone = torch.constant.none
  %0 = torch.aten.stft.center %arg0, %nfft, %hoplen, %cstnone, %cstnone, %cstfalse, %padmode, %cstfalse, %cstfalse, %csttrue, %cstfalse : !torch.vtensor<[2,32],f32>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.bool, !torch.str, !torch.bool, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[2,5,25],complex<f32>>
  return %0 : !torch.vtensor<[2,5,25],complex<f32>>
}
// -----
// Test: stft.center with hop_length = 2. Unlike the hop=1 cases, each loop
// iteration first computes the frame's start offset as loop_index * 2 via
// aten.mul.int; the rest of the per-frame slice/pad/window/fft/scatter
// sequence is unchanged. Result has 27 frames for a length-61 signal with
// n_fft = 8.
// CHECK-LABEL: func.func @torch.aten.stft.center_2D_hop_length_2(
// CHECK-SAME: %arg0: !torch.vtensor<[2,61],f32>, %arg1: !torch.vtensor<[8],f32>) -> !torch.vtensor<[2,5,27],complex<f32>> {
// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
// CHECK-DAG: %[[INT9:.*]] = torch.constant.int 9
// CHECK-DAG: %[[INT5:.*]] = torch.constant.int 5
// CHECK-DAG: %[[INT8:.*]] = torch.constant.int 8
// CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2
// CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0
// CHECK-DAG: %[[INTM1:.*]] = torch.constant.int -1
// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[NONE:.*]] = torch.constant.none
// CHECK-DAG: %float0.000000e00 = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[INT61:.*]] = torch.constant.int 61
// CHECK-DAG: %[[INT27:.*]] = torch.constant.int 27
// CHECK: %[[VAR0:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT8]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR1:.*]] = torch.aten.view %arg1, %[[VAR0]] : !torch.vtensor<[8],f32>, !torch.list<int> -> !torch.vtensor<[1,8],f32>
// CHECK: %[[VAR2:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT5]], %[[INT27]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR3:.*]] = torch.aten.empty.memory_format %[[VAR2]], %[[INT9]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,5,27],complex<f32>>
// CHECK: %[[VAR4:.*]] = torch.prim.Loop %[[INT27]], %[[TRUE]], init(%[[VAR3]]) {
// CHECK: ^bb0(%arg2: !torch.int, %arg3: !torch.vtensor<[2,5,27],complex<f32>>):
// CHECK: %[[VAR5:.*]] = torch.aten.mul.int %arg2, %[[INT2]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR6:.*]] = torch.aten.add.int %[[VAR5]], %[[INT8]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR7:.*]] = torch.prim.min.int %[[VAR6]], %[[INT61]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR8:.*]] = torch.aten.sub.int %[[VAR7]], %[[VAR5]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR9:.*]] = torch.aten.sub.int %[[INT8]], %[[VAR8]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR10:.*]] = torch.aten.add.int %[[VAR5]], %[[VAR8]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR11:.*]] = torch.aten.slice.Tensor %arg0, %[[INTM1]], %[[VAR5]], %[[VAR10]], %[[INT1]] : !torch.vtensor<[2,61],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,?],f32>
// CHECK: %[[VAR12:.*]] = torch.prim.ListConstruct %[[INT0]], %[[VAR9]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR13:.*]] = torch.aten.constant_pad_nd %[[VAR11]], %[[VAR12]], %float0.000000e00 : !torch.vtensor<[2,?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[2,?],f32>
// CHECK: %[[VAR14:.*]] = torch.tensor_static_info_cast %[[VAR13]] : !torch.vtensor<[2,?],f32> to !torch.vtensor<[2,8],f32>
// CHECK: %[[VAR15:.*]] = torch.aten.mul.Tensor %[[VAR14]], %[[VAR1]] : !torch.vtensor<[2,8],f32>, !torch.vtensor<[1,8],f32> -> !torch.vtensor<[2,8],f32>
// CHECK: %[[VAR16:.*]] = torch.aten.fft_fft %[[VAR15]], %[[NONE]], %[[INTM1]], %[[NONE]] : !torch.vtensor<[2,8],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[2,5],complex<f32>>
// CHECK: %[[VAR17:.*]] = torch.aten.unsqueeze %[[VAR16]], %[[INTM1]] : !torch.vtensor<[2,5],complex<f32>>, !torch.int -> !torch.vtensor<[2,5,1],complex<f32>>
// CHECK: %[[VAR18:.*]] = torch.aten.slice_scatter %arg3, %[[VAR17]], %[[INTM1]], %arg2, %[[NONE]], %[[INT1]] : !torch.vtensor<[2,5,27],complex<f32>>, !torch.vtensor<[2,5,1],complex<f32>>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[2,5,27],complex<f32>>
// CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[VAR18]] : !torch.vtensor<[2,5,27],complex<f32>>)
// CHECK: } : (!torch.int, !torch.bool, !torch.vtensor<[2,5,27],complex<f32>>) -> !torch.vtensor<[2,5,27],complex<f32>>
// CHECK: return %[[VAR4]] : !torch.vtensor<[2,5,27],complex<f32>>
func.func @torch.aten.stft.center_2D_hop_length_2(%arg0: !torch.vtensor<[2,61],f32>, %arg1: !torch.vtensor<[8],f32>) -> !torch.vtensor<[2,5,27],complex<f32>> {
  // n_fft = win_length = 8, hop_length = 2.
  %padmode = torch.constant.str "reflect"
  %nfft = torch.constant.int 8
  %hoplen = torch.constant.int 2
  %winlen = torch.constant.int 8
  %cstfalse = torch.constant.bool false
  %csttrue = torch.constant.bool true
  %0 = torch.aten.stft.center %arg0, %nfft, %hoplen, %winlen, %arg1, %cstfalse, %padmode, %cstfalse, %cstfalse, %csttrue, %cstfalse : !torch.vtensor<[2,61],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[8],f32>, !torch.bool, !torch.str, !torch.bool, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[2,5,27],complex<f32>>
  return %0 : !torch.vtensor<[2,5,27],complex<f32>>
}
// -----
// CHECK-LABEL: func.func @torch.aten.stft.center_2D_window_pad_left(
// CHECK-SAME: %arg0: !torch.vtensor<[2,68],f32>, %arg1: !torch.vtensor<[6],f32>) -> !torch.vtensor<[2,4,31],complex<f32>> {
// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
// CHECK-DAG: %[[INT9:.*]] = torch.constant.int 9
// CHECK-DAG: %[[INT4:.*]] = torch.constant.int 4
// CHECK-DAG: %[[INT7:.*]] = torch.constant.int 7
// CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2
// CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0
// CHECK-DAG: %[[INTM1:.*]] = torch.constant.int -1
// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[NONE:.*]] = torch.constant.none
// CHECK-DAG: %float0.000000e00 = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[INT68:.*]] = torch.constant.int 68
// CHECK-DAG: %[[INT31:.*]] = torch.constant.int 31
// CHECK: %[[VAR0:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR1:.*]] = torch.aten.constant_pad_nd %arg1, %[[VAR0]], %float0.000000e00 : !torch.vtensor<[6],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[7],f32>
// CHECK: %[[VAR2:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT7]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR3:.*]] = torch.aten.view %[[VAR1]], %[[VAR2]] : !torch.vtensor<[7],f32>, !torch.list<int> -> !torch.vtensor<[1,7],f32>
// CHECK: %[[VAR4:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT4]], %[[INT31]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR5:.*]] = torch.aten.empty.memory_format %[[VAR4]], %[[INT9]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,4,31],complex<f32>>
// CHECK: %[[VAR6:.*]] = torch.prim.Loop %[[INT31]], %[[TRUE]], init(%[[VAR5]]) {
// CHECK: ^bb0(%arg2: !torch.int, %arg3: !torch.vtensor<[2,4,31],complex<f32>>):
// CHECK: %[[VAR7:.*]] = torch.aten.mul.int %arg2, %[[INT2]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR8:.*]] = torch.aten.add.int %[[VAR7]], %[[INT7]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR9:.*]] = torch.prim.min.int %[[VAR8]], %[[INT68]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR10:.*]] = torch.aten.sub.int %[[VAR9]], %[[VAR7]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR11:.*]] = torch.aten.sub.int %[[INT7]], %[[VAR10]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR12:.*]] = torch.aten.add.int %[[VAR7]], %[[VAR10]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR13:.*]] = torch.aten.slice.Tensor %arg0, %[[INTM1]], %[[VAR7]], %[[VAR12]], %[[INT1]] : !torch.vtensor<[2,68],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,?],f32>
// CHECK: %[[VAR14:.*]] = torch.prim.ListConstruct %[[INT0]], %[[VAR11]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR15:.*]] = torch.aten.constant_pad_nd %[[VAR13]], %[[VAR14]], %float0.000000e00 : !torch.vtensor<[2,?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[2,?],f32>
// CHECK: %[[VAR16:.*]] = torch.tensor_static_info_cast %[[VAR15]] : !torch.vtensor<[2,?],f32> to !torch.vtensor<[2,7],f32>
// CHECK: %[[VAR17:.*]] = torch.aten.mul.Tensor %[[VAR16]], %[[VAR3]] : !torch.vtensor<[2,7],f32>, !torch.vtensor<[1,7],f32> -> !torch.vtensor<[2,7],f32>
// CHECK: %[[VAR18:.*]] = torch.aten.fft_fft %[[VAR17]], %[[NONE]], %[[INTM1]], %[[NONE]] : !torch.vtensor<[2,7],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[2,4],complex<f32>>
// CHECK: %[[VAR19:.*]] = torch.aten.unsqueeze %[[VAR18]], %[[INTM1]] : !torch.vtensor<[2,4],complex<f32>>, !torch.int -> !torch.vtensor<[2,4,1],complex<f32>>
// CHECK: %[[VAR20:.*]] = torch.aten.slice_scatter %arg3, %[[VAR19]], %[[INTM1]], %arg2, %[[NONE]], %[[INT1]] : !torch.vtensor<[2,4,31],complex<f32>>, !torch.vtensor<[2,4,1],complex<f32>>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[2,4,31],complex<f32>>
// CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[VAR20]] : !torch.vtensor<[2,4,31],complex<f32>>)
// CHECK: } : (!torch.int, !torch.bool, !torch.vtensor<[2,4,31],complex<f32>>) -> !torch.vtensor<[2,4,31],complex<f32>>
// CHECK: return %[[VAR6]] : !torch.vtensor<[2,4,31],complex<f32>>
// Decomposition of aten.stft.center for a 2-D input with n_fft=7,
// hop_length=2, and a win_length=6 window: the window is zero-padded on the
// right (pad list [0, 1]) up to n_fft, and each frame is sliced from the
// input, zero-padded, multiplied by the window, and run through fft_fft
// inside a prim.Loop whose iterations write columns via slice_scatter.
func.func @torch.aten.stft.center_2D_window_pad_left(%arg0: !torch.vtensor<[2,68],f32>, %arg1: !torch.vtensor<[6],f32>) -> !torch.vtensor<[2,4,31],complex<f32>> {
  %padmode = torch.constant.str "reflect"
  %nfft = torch.constant.int 7
  %hoplen = torch.constant.int 2
  %winlen = torch.constant.int 6
  %cstfalse = torch.constant.bool false
  %csttrue = torch.constant.bool true
  %0 = torch.aten.stft.center %arg0, %nfft, %hoplen, %winlen, %arg1, %cstfalse, %padmode, %cstfalse, %cstfalse, %csttrue, %cstfalse : !torch.vtensor<[2,68],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[6],f32>, !torch.bool, !torch.str, !torch.bool, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,31],complex<f32>>
  return %0 : !torch.vtensor<[2,4,31],complex<f32>>
}
// -----
// CHECK-LABEL: func.func @torch.aten.stft.center_2D_hop_length_3_window_pad_both(
// CHECK-SAME: %arg0: !torch.vtensor<[3,90],f32>, %arg1: !torch.vtensor<[8],f32>) -> !torch.vtensor<[3,6,27],complex<f32>> {
// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
// CHECK-DAG: %[[INT9:.*]] = torch.constant.int 9
// CHECK-DAG: %[[INT6:.*]] = torch.constant.int 6
// CHECK-DAG: %[[INT10:.*]] = torch.constant.int 10
// CHECK-DAG: %[[INT3:.*]] = torch.constant.int 3
// CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0
// CHECK-DAG: %[[INTM1:.*]] = torch.constant.int -1
// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[NONE:.*]] = torch.constant.none
// CHECK-DAG: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[INT90:.*]] = torch.constant.int 90
// CHECK-DAG: %[[INT27:.*]] = torch.constant.int 27
// CHECK: %[[VAR0:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR1:.*]] = torch.aten.constant_pad_nd %arg1, %[[VAR0]], %[[FLOAT0]] : !torch.vtensor<[8],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK: %[[VAR2:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT10]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR3:.*]] = torch.aten.view %[[VAR1]], %[[VAR2]] : !torch.vtensor<[10],f32>, !torch.list<int> -> !torch.vtensor<[1,10],f32>
// CHECK: %[[VAR4:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT6]], %[[INT27]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR5:.*]] = torch.aten.empty.memory_format %[[VAR4]], %[[INT9]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,6,27],complex<f32>>
// CHECK: %[[VAR6:.*]] = torch.prim.Loop %[[INT27]], %[[TRUE]], init(%[[VAR5]]) {
// CHECK: ^bb0(%arg2: !torch.int, %arg3: !torch.vtensor<[3,6,27],complex<f32>>):
// CHECK: %[[VAR7:.*]] = torch.aten.mul.int %arg2, %[[INT3]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR8:.*]] = torch.aten.add.int %[[VAR7]], %[[INT10]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR9:.*]] = torch.prim.min.int %[[VAR8]], %[[INT90]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR10:.*]] = torch.aten.sub.int %[[VAR9]], %[[VAR7]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR11:.*]] = torch.aten.sub.int %[[INT10]], %[[VAR10]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR12:.*]] = torch.aten.add.int %[[VAR7]], %[[VAR10]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR13:.*]] = torch.aten.slice.Tensor %arg0, %[[INTM1]], %[[VAR7]], %[[VAR12]], %[[INT1]] : !torch.vtensor<[3,90],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,?],f32>
// CHECK: %[[VAR14:.*]] = torch.prim.ListConstruct %[[INT0]], %[[VAR11]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR15:.*]] = torch.aten.constant_pad_nd %[[VAR13]], %[[VAR14]], %[[FLOAT0]] : !torch.vtensor<[3,?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[3,?],f32>
// CHECK: %[[VAR16:.*]] = torch.tensor_static_info_cast %[[VAR15]] : !torch.vtensor<[3,?],f32> to !torch.vtensor<[3,10],f32>
// CHECK: %[[VAR17:.*]] = torch.aten.mul.Tensor %[[VAR16]], %[[VAR3]] : !torch.vtensor<[3,10],f32>, !torch.vtensor<[1,10],f32> -> !torch.vtensor<[3,10],f32>
// CHECK: %[[VAR18:.*]] = torch.aten.fft_fft %[[VAR17]], %[[NONE]], %[[INTM1]], %[[NONE]] : !torch.vtensor<[3,10],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[3,6],complex<f32>>
// CHECK: %[[VAR19:.*]] = torch.aten.unsqueeze %[[VAR18]], %[[INTM1]] : !torch.vtensor<[3,6],complex<f32>>, !torch.int -> !torch.vtensor<[3,6,1],complex<f32>>
// CHECK: %[[VAR20:.*]] = torch.aten.slice_scatter %arg3, %[[VAR19]], %[[INTM1]], %arg2, %[[NONE]], %[[INT1]] : !torch.vtensor<[3,6,27],complex<f32>>, !torch.vtensor<[3,6,1],complex<f32>>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[3,6,27],complex<f32>>
// CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[VAR20]] : !torch.vtensor<[3,6,27],complex<f32>>)
// CHECK: } : (!torch.int, !torch.bool, !torch.vtensor<[3,6,27],complex<f32>>) -> !torch.vtensor<[3,6,27],complex<f32>>
// CHECK: return %[[VAR6]] : !torch.vtensor<[3,6,27],complex<f32>>
// Decomposition of aten.stft.center for a 2-D input with n_fft=10,
// hop_length=3, and a win_length=8 window: the window is zero-padded on both
// sides (pad list [1, 1]) up to n_fft, and each frame is sliced, padded,
// windowed, and run through fft_fft inside a prim.Loop writing columns via
// slice_scatter. The f32 zero pad value is captured as %[[FLOAT0]] instead of
// being matched by its raw SSA name, so the test is robust to asm renaming.
func.func @torch.aten.stft.center_2D_hop_length_3_window_pad_both(%arg0: !torch.vtensor<[3,90],f32>, %arg1: !torch.vtensor<[8],f32>) -> !torch.vtensor<[3,6,27],complex<f32>> {
  %padmode = torch.constant.str "reflect"
  %nfft = torch.constant.int 10
  %hoplen = torch.constant.int 3
  %winlen = torch.constant.int 8
  %cstfalse = torch.constant.bool false
  %csttrue = torch.constant.bool true
  %0 = torch.aten.stft.center %arg0, %nfft, %hoplen, %winlen, %arg1, %cstfalse, %padmode, %cstfalse, %cstfalse, %csttrue, %cstfalse : !torch.vtensor<[3,90],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[8],f32>, !torch.bool, !torch.str, !torch.bool, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[3,6,27],complex<f32>>
  return %0 : !torch.vtensor<[3,6,27],complex<f32>>
}
// -----
// CHECK-LABEL: func.func @native_layer_norm(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,56,56,96],f32>, %[[ARG1:.*]]: !torch.list<int>, %[[ARG2:.*]]: !torch.vtensor<[96],f32>, %[[ARG3:.*]]: !torch.vtensor<[96],f32>, %[[ARG4:.*]]: !torch.float) -> (!torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>) {
// CHECK-DAG: %[[INT96:.*]] = torch.constant.int 96
// CHECK-DAG: %[[INT56:.*]] = torch.constant.int 56
// CHECK-DAG: %[[NONE:.*]] = torch.constant.none
// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[VAR0:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[VAR1:.*]] = torch.aten.sum.dim_IntList %[[ARG0]], %[[VAR0]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[1,56,56,96],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,56,56,1],f32>
// CHECK: %[[VAR2:.*]] = torch.aten.numel %[[ARG0]] : !torch.vtensor<[1,56,56,96],f32> -> !torch.int
// CHECK: %[[VAR3:.*]] = torch.aten.div.Scalar %[[VAR1]], %[[VAR2]] : !torch.vtensor<[1,56,56,1],f32>, !torch.int -> !torch.vtensor<[1,56,56,1],f32>
// CHECK: %[[VAR4:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT56]], %[[INT56]], %[[INT96]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR5:.*]] = torch.aten.broadcast_to %[[VAR3]], %[[VAR4]] : !torch.vtensor<[1,56,56,1],f32>, !torch.list<int> -> !torch.vtensor<[1,56,56,96],f32>
// CHECK: %[[VAR6:.*]] = torch.aten.sub.Tensor %[[ARG0]], %[[VAR5]], %[[INT1]] : !torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[1,56,56,96],f32>, !torch.int -> !torch.vtensor<[1,56,56,96],f32>
// CHECK: %[[VAR7:.*]] = torch.aten.mul.Tensor %[[VAR6]], %[[VAR6]] : !torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[1,56,56,96],f32> -> !torch.vtensor<[1,56,56,96],f32>
// CHECK: %[[VAR8:.*]] = torch.aten.sum.dim_IntList %[[VAR7]], %[[VAR0]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[1,56,56,96],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,56,56,1],f32>
// CHECK: %[[VAR9:.*]] = torch.aten.numel %[[VAR7]] : !torch.vtensor<[1,56,56,96],f32> -> !torch.int
// CHECK: %[[VAR10:.*]] = torch.aten.div.Scalar %[[VAR8]], %[[VAR9]] : !torch.vtensor<[1,56,56,1],f32>, !torch.int -> !torch.vtensor<[1,56,56,1],f32>
// CHECK: %[[VAR11:.*]] = torch.aten.add.Scalar %[[VAR10]], %[[ARG4]], %[[INT1]] : !torch.vtensor<[1,56,56,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,56,56,1],f32>
// CHECK: %[[VAR12:.*]] = torch.aten.rsqrt %[[VAR11]] : !torch.vtensor<[1,56,56,1],f32> -> !torch.vtensor<[1,56,56,1],f32>
// CHECK: %[[VAR13:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT56]], %[[INT56]], %[[INT96]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR14:.*]] = torch.aten.broadcast_to %[[VAR12]], %[[VAR13]] : !torch.vtensor<[1,56,56,1],f32>, !torch.list<int> -> !torch.vtensor<[1,56,56,96],f32>
// CHECK: %[[VAR15:.*]] = torch.aten.mul.Tensor %[[VAR6]], %[[VAR14]] : !torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[1,56,56,96],f32> -> !torch.vtensor<[1,56,56,96],f32>
// CHECK: %[[VAR16:.*]] = torch.aten.mul.Tensor %[[VAR15]], %[[ARG2]] : !torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[96],f32> -> !torch.vtensor<[1,56,56,96],f32>
// CHECK: %[[VAR17:.*]] = torch.aten.add.Tensor %[[VAR16]], %[[ARG3]], %[[INT1]] : !torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[96],f32>, !torch.int -> !torch.vtensor<[1,56,56,96],f32>
// CHECK: return %[[VAR17]], %[[VAR3]], %[[VAR12]] : !torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>
// Decomposition of aten.native_layer_norm: the mean is computed via
// sum.dim_IntList/numel/div.Scalar, the variance from the squared deviation,
// and the result is (x - mean) * rsqrt(var + eps) * weight + bias. All
// operands in the CHECK lines use FileCheck captures (no raw SSA names such
// as %0/%true/%int1), so the test does not depend on printer numbering.
func.func @native_layer_norm(%input: !torch.vtensor<[1,56,56,96],f32>, %normalized_shape: !torch.list<int>, %weight: !torch.vtensor<[96],f32>, %bias: !torch.vtensor<[96],f32>, %eps: !torch.float) -> (!torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>) {
  %result, %mean, %rstd = torch.aten.native_layer_norm %input, %normalized_shape, %weight, %bias, %eps : !torch.vtensor<[1,56,56,96],f32>, !torch.list<int>, !torch.vtensor<[96],f32>, !torch.vtensor<[96],f32>, !torch.float -> !torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>
  return %result, %mean, %rstd : !torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>
}
// -----
// CHECK-LABEL: func.func @native_layer_norm_mixed_dtypes(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,56,56,96],bf16>, %[[ARG1:.*]]: !torch.list<int>, %[[ARG2:.*]]: !torch.vtensor<[96],bf16>, %[[ARG3:.*]]: !torch.vtensor<[96],bf16>, %[[ARG4:.*]]: !torch.float) -> (!torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>) {
// CHECK-DAG: %[[INT96:.*]] = torch.constant.int 96
// CHECK-DAG: %[[INT56:.*]] = torch.constant.int 56
// CHECK-DAG: %[[INT15:.*]] = torch.constant.int 15
// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false
// CHECK-DAG: %[[NONE:.*]] = torch.constant.none
// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[VAR0:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[VAR1:.*]] = torch.aten.sum.dim_IntList %[[ARG0]], %[[VAR0]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[1,56,56,96],bf16>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,56,56,1],f32>
// CHECK: %[[VAR2:.*]] = torch.aten.numel %[[ARG0]] : !torch.vtensor<[1,56,56,96],bf16> -> !torch.int
// CHECK: %[[VAR3:.*]] = torch.aten.div.Scalar %[[VAR1]], %[[VAR2]] : !torch.vtensor<[1,56,56,1],f32>, !torch.int -> !torch.vtensor<[1,56,56,1],f32>
// CHECK: %[[VAR4:.*]] = torch.aten.to.dtype %[[VAR3]], %[[INT15]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[1,56,56,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,56,56,1],bf16>
// CHECK: %[[VAR5:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT56]], %[[INT56]], %[[INT96]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR6:.*]] = torch.aten.broadcast_to %[[VAR4]], %[[VAR5]] : !torch.vtensor<[1,56,56,1],bf16>, !torch.list<int> -> !torch.vtensor<[1,56,56,96],bf16>
// CHECK: %[[VAR7:.*]] = torch.aten.sub.Tensor %[[ARG0]], %[[VAR6]], %[[INT1]] : !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,96],bf16>, !torch.int -> !torch.vtensor<[1,56,56,96],bf16>
// CHECK: %[[VAR8:.*]] = torch.aten.mul.Tensor %[[VAR7]], %[[VAR7]] : !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,96],bf16> -> !torch.vtensor<[1,56,56,96],bf16>
// CHECK: %[[VAR9:.*]] = torch.aten.sum.dim_IntList %[[VAR8]], %[[VAR0]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[1,56,56,96],bf16>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,56,56,1],f32>
// CHECK: %[[VAR10:.*]] = torch.aten.numel %[[VAR8]] : !torch.vtensor<[1,56,56,96],bf16> -> !torch.int
// CHECK: %[[VAR11:.*]] = torch.aten.div.Scalar %[[VAR9]], %[[VAR10]] : !torch.vtensor<[1,56,56,1],f32>, !torch.int -> !torch.vtensor<[1,56,56,1],f32>
// CHECK: %[[VAR12:.*]] = torch.aten.add.Scalar %[[VAR11]], %[[ARG4]], %[[INT1]] : !torch.vtensor<[1,56,56,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,56,56,1],f32>
// CHECK: %[[VAR13:.*]] = torch.aten.rsqrt %[[VAR12]] : !torch.vtensor<[1,56,56,1],f32> -> !torch.vtensor<[1,56,56,1],f32>
// CHECK: %[[VAR14:.*]] = torch.aten.to.dtype %[[VAR13]], %[[INT15]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[1,56,56,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,56,56,1],bf16>
// CHECK: %[[VAR15:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT56]], %[[INT56]], %[[INT96]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAR16:.*]] = torch.aten.broadcast_to %[[VAR14]], %[[VAR15]] : !torch.vtensor<[1,56,56,1],bf16>, !torch.list<int> -> !torch.vtensor<[1,56,56,96],bf16>
// CHECK: %[[VAR17:.*]] = torch.aten.mul.Tensor %[[VAR7]], %[[VAR16]] : !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,96],bf16> -> !torch.vtensor<[1,56,56,96],bf16>
// CHECK: %[[VAR18:.*]] = torch.aten.mul.Tensor %[[VAR17]], %[[ARG2]] : !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[96],bf16> -> !torch.vtensor<[1,56,56,96],bf16>
// CHECK: %[[VAR19:.*]] = torch.aten.add.Tensor %[[VAR18]], %[[ARG3]], %[[INT1]] : !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[96],bf16>, !torch.int -> !torch.vtensor<[1,56,56,96],bf16>
// CHECK: return %[[VAR19]], %[[VAR3]], %[[VAR13]] : !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>
// Mixed-dtype variant of the native_layer_norm decomposition: the mean and
// rstd intermediates stay in f32 and are cast to bf16 (dtype code 15) via
// aten.to.dtype before being applied to the bf16 input. All operands in the
// CHECK lines use FileCheck captures (no raw SSA names such as %0/%int1),
// so the test does not depend on printer numbering.
func.func @native_layer_norm_mixed_dtypes(%input: !torch.vtensor<[1,56,56,96],bf16>, %normalized_shape: !torch.list<int>, %weight: !torch.vtensor<[96],bf16>, %bias: !torch.vtensor<[96],bf16>, %eps: !torch.float) -> (!torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>) {
  %result, %mean, %rstd = torch.aten.native_layer_norm %input, %normalized_shape, %weight, %bias, %eps : !torch.vtensor<[1,56,56,96],bf16>, !torch.list<int>, !torch.vtensor<[96],bf16>, !torch.vtensor<[96],bf16>, !torch.float -> !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>
  return %result, %mean, %rstd : !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>
}
// -----
// CHECK-LABEL: func @pixel_unshuffle_static
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,8,4,4],f32>
// CHECK-DAG: %[[C2:.*]] = torch.constant.int 2
// CHECK-DAG: %[[C0:.*]] = torch.constant.int 0
// CHECK-DAG: %[[C1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[C3:.*]] = torch.constant.int 3
// CHECK-DAG: %[[C4:.*]] = torch.constant.int 4
// CHECK-DAG: %[[C5:.*]] = torch.constant.int 5
// CHECK: %[[PERMLIST:.*]] = torch.prim.ListConstruct %[[C0]], %[[C1]], %[[C3]], %[[C5]], %[[C2]], %[[C4]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[EXPAND1:.*]] = torch.prims.split_dim %[[ARG0]], %[[C2]], %[[C2]] : !torch.vtensor<[1,8,4,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,8,2,2,4],f32>
// CHECK: %[[EXPAND2:.*]] = torch.prims.split_dim %[[EXPAND1]], %[[C4]], %[[C2]] : !torch.vtensor<[1,8,2,2,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,8,2,2,2,2],f32>
// CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[EXPAND2]], %[[PERMLIST]] : !torch.vtensor<[1,8,2,2,2,2],f32>, !torch.list<int> -> !torch.vtensor<[1,8,2,2,2,2],f32>
// CHECK: %[[COLLAPSE:.*]] = torch.prims.collapse %[[PERMUTE]], %[[C1]], %[[C3]] : !torch.vtensor<[1,8,2,2,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,32,2,2],f32>
// CHECK: return %[[COLLAPSE]] : !torch.vtensor<[1,32,2,2],f32>
// Static-shape pixel_unshuffle with downscale factor 2: split the H and W
// dimensions by the factor, permute the factor dims next to the channel dim,
// and collapse dims 1..3 into the channel dimension. The CHECK-SAME line
// defines %[[ARG0]] within this CHECK-LABEL block; previously the later uses
// of %[[ARG0]] relied on a stale capture from an earlier test case.
func.func @pixel_unshuffle_static(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtensor<[1,32,2,2],f32> attributes {torch.assume_strict_symbolic_shapes} {
  %int2 = torch.constant.int 2
  %0 = torch.aten.pixel_unshuffle %arg0, %int2 : !torch.vtensor<[1,8,4,4],f32>, !torch.int -> !torch.vtensor<[1,32,2,2],f32>
  return %0 : !torch.vtensor<[1,32,2,2],f32>
}
// -----
// CHECK-LABEL: func @pixel_unshuffle_fulldynamic
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>
// CHECK-DAG: %[[C2:.*]] = torch.constant.int 2
// CHECK-DAG: %[[C0:.*]] = torch.constant.int 0
// CHECK-DAG: %[[C1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[C3:.*]] = torch.constant.int 3
// CHECK-DAG: %[[C4:.*]] = torch.constant.int 4
// CHECK-DAG: %[[C5:.*]] = torch.constant.int 5
// CHECK: %[[INC:.*]] = torch.aten.size.int %[[ARG0]], %[[C1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[INH:.*]] = torch.aten.size.int %[[ARG0]], %[[C2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[INW:.*]] = torch.aten.size.int %[[ARG0]], %[[C3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[OUTC:.*]] = torch.aten.mul.int %[[INC]], %[[C4]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[OUTH:.*]] = torch.aten.floordiv.int %[[INH]], %[[C2]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[OUTW:.*]] = torch.aten.floordiv.int %[[INW]], %[[C2]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[SIZE0:.*]] = torch.aten.size.int %[[ARG0]], %[[C0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[PERMLIST:.*]] = torch.prim.ListConstruct %[[C0]], %[[C1]], %[[C3]], %[[C5]], %[[C2]], %[[C4]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[EXPAND1:.*]] = torch.prims.split_dim %[[ARG0]], %[[C2]], %[[OUTH]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,2,?],f32>
// CHECK: %[[EXPAND2:.*]] = torch.prims.split_dim %[[EXPAND1]], %[[C4]], %[[OUTW]] : !torch.vtensor<[?,?,?,2,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,2,?,2],f32>
// CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[EXPAND2]], %[[PERMLIST]] : !torch.vtensor<[?,?,?,2,?,2],f32>, !torch.list<int> -> !torch.vtensor<[?,?,2,2,?,?],f32>
// CHECK: %[[COLLAPSE:.*]] = torch.prims.collapse %[[PERMUTE]], %[[C1]], %[[C3]] : !torch.vtensor<[?,?,2,2,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[COLLAPSE]] : !torch.vtensor<[?,?,?,?],f32>
// Fully dynamic pixel_unshuffle with downscale factor 2: the output C/H/W
// extents are computed at runtime with size.int, mul.int, and floordiv.int
// before the same split/permute/collapse sequence as the static case. The
// CHECK-SAME line defines %[[ARG0]] within this CHECK-LABEL block;
// previously the later uses relied on a stale capture from an earlier test.
func.func @pixel_unshuffle_fulldynamic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.assume_strict_symbolic_shapes} {
  %int2 = torch.constant.int 2
  %0 = torch.aten.pixel_unshuffle %arg0, %int2 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
  return %0 : !torch.vtensor<[?,?,?,?],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.broadcast_tensors
// CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[1,3],f32>
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[2,1],f32>
// CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2
// CHECK-DAG: %[[INT3:.*]] = torch.constant.int 3
// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: torch.runtime.assert %[[TRUE]], "tensors are not broadcast compatible"
// CHECK: torch.runtime.assert %[[TRUE]], "tensors are not broadcast compatible"
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[B0:.*]] = torch.aten.broadcast_to %[[ARG0]], %[[SHAPE]] : !torch.vtensor<[1,3],f32>, !torch.list<int> -> !torch.vtensor<[2,3],f32>
// CHECK: %[[B1:.*]] = torch.aten.broadcast_to %[[ARG1]], %[[SHAPE]] : !torch.vtensor<[2,1],f32>, !torch.list<int> -> !torch.vtensor<[2,3],f32>
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[B0]], %[[B1]] : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> !torch.list<vtensor<[2,3],f32>>
// CHECK: return %[[LIST]] : !torch.list<vtensor<[2,3],f32>>
// Decomposition of aten.broadcast_tensors on [1,3] and [2,1] inputs: emits a
// runtime compatibility assert per tensor, then broadcast_to ops onto the
// common shape [2,3], and repackages the results into a list.
func.func @torch.aten.broadcast_tensors(%arg0: !torch.vtensor<[1,3],f32>, %arg1: !torch.vtensor<[2,1],f32>) -> !torch.list<vtensor<[2,3], f32>> {
  %0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[1,3],f32>, !torch.vtensor<[2,1],f32>) -> !torch.list<vtensor>
  %1 = torch.aten.broadcast_tensors %0 : !torch.list<vtensor> -> !torch.list<vtensor<[2,3],f32>>
  return %1 : !torch.list<vtensor<[2,3],f32>>
}
// -----
// CHECK-LABEL: func @channel_shuffle