@@ -710,24 +710,38 @@ func.func @fold_extract_transpose(
710
710
711
711
// -----
712
712
713
- // CHECK-LABEL: fold_extract_broadcast
713
+ // CHECK-LABEL: fold_extract_broadcast_same_input_output_scalar
714
714
// CHECK-SAME: %[[A:.*]]: f32
715
715
// CHECK: return %[[A]] : f32
716
- func.func @fold_extract_broadcast (%a : f32 ) -> f32 {
716
+ func.func @fold_extract_broadcast_same_input_output_scalar (%a : f32 ,
717
+ %idx0 : index , %idx1 : index , %idx2 : index ) -> f32 {
717
718
%b = vector.broadcast %a : f32 to vector <1 x2 x4 xf32 >
718
- %r = vector.extract %b [0 , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
719
+ %r = vector.extract %b [%idx0 , %idx1 , %idx2 ] : f32 from vector <1 x2 x4 xf32 >
719
720
return %r : f32
720
721
}
721
722
722
723
// -----
723
724
724
- // CHECK-LABEL: fold_extract_broadcast_0dvec
725
+ // CHECK-LABEL: fold_extract_broadcast_same_input_output_vec
726
+ // CHECK-SAME: %[[A:.*]]: vector<4xf32>
727
+ // CHECK: return %[[A]] : vector<4xf32>
728
+ func.func @fold_extract_broadcast_same_input_output_vec (%a : vector <4 xf32 >,
729
+ %idx0 : index , %idx1 : index ) -> vector <4 xf32 > {
730
+ %b = vector.broadcast %a : vector <4 xf32 > to vector <1 x2 x4 xf32 >
731
+ %r = vector.extract %b [%idx0 , %idx1 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
732
+ return %r : vector <4 xf32 >
733
+ }
734
+
735
+ // -----
736
+
737
+ // CHECK-LABEL: fold_extract_broadcast_0dvec_input_scalar_output
725
738
// CHECK-SAME: %[[A:.*]]: vector<f32>
726
739
// CHECK: %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
727
740
// CHECK: return %[[B]] : f32
728
- func.func @fold_extract_broadcast_0dvec (%a : vector <f32 >) -> f32 {
741
+ func.func @fold_extract_broadcast_0dvec_input_scalar_output (%a : vector <f32 >,
742
+ %idx0 : index , %idx1 : index , %idx2: index ) -> f32 {
729
743
%b = vector.broadcast %a : vector <f32 > to vector <1 x2 x4 xf32 >
730
- %r = vector.extract %b [0 , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
744
+ %r = vector.extract %b [%idx0 , %idx1 , %idx2 ] : f32 from vector <1 x2 x4 xf32 >
731
745
return %r : f32
732
746
}
733
747
@@ -747,57 +761,68 @@ func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32
747
761
// CHECK-LABEL: fold_extract_splat
748
762
// CHECK-SAME: %[[A:.*]]: f32
749
763
// CHECK: return %[[A]] : f32
750
- func.func @fold_extract_splat (%a : f32 ) -> f32 {
764
+ func.func @fold_extract_splat (%a : f32 , %idx0 : index , %idx1 : index , %idx2 : index ) -> f32 {
751
765
%b = vector.splat %a : vector <1 x2 x4 xf32 >
752
- %r = vector.extract %b [0 , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
766
+ %r = vector.extract %b [%idx0 , %idx1 , %idx2 ] : f32 from vector <1 x2 x4 xf32 >
753
767
return %r : f32
754
768
}
755
769
756
770
// -----
757
771
758
- // CHECK-LABEL: fold_extract_broadcast_vector
759
- // CHECK-SAME: %[[A:.*]]: vector<4xf32>
760
- // CHECK: return %[[A]] : vector<4xf32>
761
- func.func @fold_extract_broadcast_vector (%a : vector <4 xf32 >) -> vector <4 xf32 > {
762
- %b = vector.broadcast %a : vector <4 xf32 > to vector <1 x2 x4 xf32 >
763
- %r = vector.extract %b [0 , 1 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
764
- return %r : vector <4 xf32 >
772
+ // CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
773
+ // CHECK-SAME: %[[A:.*]]: vector<2x1xf32>
774
+ // CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
775
+ // CHECK: %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32>
776
+ // CHECK: return %[[R]] : f32
777
+ func.func @fold_extract_broadcast_dim1_broadcasting (%a : vector <2 x1 xf32 >,
778
+ %idx : index , %idx1 : index , %idx2 : index ) -> f32 {
779
+ %b = vector.broadcast %a : vector <2 x1 xf32 > to vector <1 x2 x4 xf32 >
780
+ %r = vector.extract %b [%idx , %idx1 , %idx2 ] : f32 from vector <1 x2 x4 xf32 >
781
+ return %r : f32
765
782
}
766
783
767
784
// -----
768
785
769
- // CHECK-LABEL: fold_extract_broadcast
770
- // CHECK-SAME: %[[A:.*]]: vector<4xf32>
771
- // CHECK: %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
772
- // CHECK: return %[[R]] : f32
773
- func.func @fold_extract_broadcast (%a : vector <4 xf32 >) -> f32 {
774
- %b = vector.broadcast %a : vector <4 xf32 > to vector <1 x2 x4 xf32 >
775
- %r = vector.extract %b [0 , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
776
- return %r : f32
786
+ // CHECK-LABEL: fold_extract_broadcast_to_lower_rank
787
+ // CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
788
+ // CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
789
+ // CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
790
+ // CHECK: return %[[B]] : vector<4xf32>
791
+ // rank(extract_output) < rank(broadcast_input)
792
+ func.func @fold_extract_broadcast_to_lower_rank (%a : vector <2 x4 xf32 >,
793
+ %idx0 : index , %idx1 : index ) -> vector <4 xf32 > {
794
+ %b = vector.broadcast %a : vector <2 x4 xf32 > to vector <1 x2 x4 xf32 >
795
+ %r = vector.extract %b [%idx0 , %idx1 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
796
+ return %r : vector <4 xf32 >
777
797
}
778
798
779
799
// -----
780
800
781
- // CHECK-LABEL: fold_extract_broadcast
801
+ // CHECK-LABEL: fold_extract_broadcast_to_higher_rank
782
802
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
783
803
// CHECK: return %[[B]] : vector<4xf32>
784
- func.func @fold_extract_broadcast (%a : f32 ) -> vector <4 xf32 > {
804
+ // rank(extract_output) > rank(broadcast_input)
805
+ func.func @fold_extract_broadcast_to_higher_rank (%a : f32 , %idx0 : index , %idx1 : index )
806
+ -> vector <4 xf32 > {
785
807
%b = vector.broadcast %a : f32 to vector <1 x2 x4 xf32 >
786
- %r = vector.extract %b [0 , 1 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
808
+ %r = vector.extract %b [%idx0 , %idx1 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
787
809
return %r : vector <4 xf32 >
788
810
}
789
811
790
812
// -----
791
813
792
- // CHECK-LABEL: fold_extract_broadcast
814
+ // CHECK-LABEL: fold_extract_broadcast_to_equal_rank
793
815
// CHECK-SAME: %[[A:.*]]: vector<1xf32>
794
816
// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
795
817
// CHECK: return %[[R]] : vector<8xf32>
796
- func.func @fold_extract_broadcast (%a : vector <1 xf32 >) -> vector <8 xf32 > {
818
+ // rank(extract_output) == rank(broadcast_input)
819
+ func.func @fold_extract_broadcast_to_equal_rank (%a : vector <1 xf32 >, %idx0 : index )
820
+ -> vector <8 xf32 > {
797
821
%b = vector.broadcast %a : vector <1 xf32 > to vector <1 x8 xf32 >
798
- %r = vector.extract %b [0 ] : vector <8 xf32 > from vector <1 x8 xf32 >
822
+ %r = vector.extract %b [%idx0 ] : vector <8 xf32 > from vector <1 x8 xf32 >
799
823
return %r : vector <8 xf32 >
800
824
}
825
+
801
826
// -----
802
827
803
828
// CHECK-LABEL: @fold_extract_shuffle
0 commit comments