Skip to content

Commit 61fb954

Browse files
authored
[mlir][Vector] Improve support for vector.extract(broadcast) (#116234)
This patch improves support for dynamic positions in the vector.extract(broadcast) folder. This is mostly a matter of removing a conservative bail-out on dynamic positions. The broadcast folder for vector.extract now covers the cases that the vector.extractelement + broadcast folder does. This patch also improves test coverage for the vector.extract + broadcast folders/canonicalizers: the tests now enumerate every supported / unsupported case.
1 parent 5a2bee0 commit 61fb954

File tree

2 files changed

+61
-38
lines changed

2 files changed

+61
-38
lines changed

Diff for: mlir/lib/Dialect/Vector/IR/VectorOps.cpp

+7-9
Original file line numberDiff line numberDiff line change
@@ -1660,10 +1660,6 @@ static bool hasZeroDimVectors(Operation *op) {
16601660

16611661
/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
16621662
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1663-
// TODO: Canonicalization for dynamic position not implemented yet.
1664-
if (extractOp.hasDynamicPosition())
1665-
return Value();
1666-
16671663
Operation *defOp = extractOp.getVector().getDefiningOp();
16681664
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
16691665
return Value();
@@ -1700,20 +1696,22 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
17001696
// extract position to `0` when extracting from the source operand.
17011697
llvm::SetVector<int64_t> broadcastedUnitDims =
17021698
broadcastOp.computeBroadcastedUnitDims();
1703-
SmallVector<int64_t> extractPos(extractOp.getStaticPosition());
1699+
SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition());
1700+
OpBuilder b(extractOp.getContext());
17041701
int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
17051702
for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
17061703
if (broadcastedUnitDims.contains(i))
1707-
extractPos[i] = 0;
1704+
extractPos[i] = b.getIndexAttr(0);
17081705
// `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
17091706
// matching extract position when extracting from the source operand.
17101707
int64_t rankDiff = broadcastSrcRank - extractResultRank;
17111708
extractPos.erase(extractPos.begin(),
17121709
std::next(extractPos.begin(), extractPos.size() - rankDiff));
17131710
// OpBuilder is only used as a helper to build an I64ArrayAttr.
1714-
OpBuilder b(extractOp.getContext());
1715-
extractOp.setOperand(0, source);
1716-
extractOp.setStaticPosition(extractPos);
1711+
auto [staticPos, dynPos] = decomposeMixedValues(extractPos);
1712+
extractOp->setOperands(
1713+
llvm::to_vector(llvm::concat<Value>(ValueRange(source), dynPos)));
1714+
extractOp.setStaticPosition(staticPos);
17171715
return extractOp.getResult();
17181716
}
17191717

Diff for: mlir/test/Dialect/Vector/canonicalize.mlir

+54-29
Original file line numberDiff line numberDiff line change
@@ -710,24 +710,38 @@ func.func @fold_extract_transpose(
710710

711711
// -----
712712

713-
// CHECK-LABEL: fold_extract_broadcast
713+
// CHECK-LABEL: fold_extract_broadcast_same_input_output_scalar
714714
// CHECK-SAME: %[[A:.*]]: f32
715715
// CHECK: return %[[A]] : f32
716-
func.func @fold_extract_broadcast(%a : f32) -> f32 {
716+
func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
717+
%idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
717718
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
718-
%r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
719+
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
719720
return %r : f32
720721
}
721722

722723
// -----
723724

724-
// CHECK-LABEL: fold_extract_broadcast_0dvec
725+
// CHECK-LABEL: fold_extract_broadcast_same_input_output_vec
726+
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
727+
// CHECK: return %[[A]] : vector<4xf32>
728+
func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
729+
%idx0 : index, %idx1 : index) -> vector<4xf32> {
730+
%b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
731+
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
732+
return %r : vector<4xf32>
733+
}
734+
735+
// -----
736+
737+
// CHECK-LABEL: fold_extract_broadcast_0dvec_input_scalar_output
725738
// CHECK-SAME: %[[A:.*]]: vector<f32>
726739
// CHECK: %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
727740
// CHECK: return %[[B]] : f32
728-
func.func @fold_extract_broadcast_0dvec(%a : vector<f32>) -> f32 {
741+
func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
742+
%idx0 : index, %idx1 : index, %idx2: index) -> f32 {
729743
%b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
730-
%r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
744+
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
731745
return %r : f32
732746
}
733747

@@ -747,57 +761,68 @@ func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32
747761
// CHECK-LABEL: fold_extract_splat
748762
// CHECK-SAME: %[[A:.*]]: f32
749763
// CHECK: return %[[A]] : f32
750-
func.func @fold_extract_splat(%a : f32) -> f32 {
764+
func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
751765
%b = vector.splat %a : vector<1x2x4xf32>
752-
%r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
766+
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
753767
return %r : f32
754768
}
755769

756770
// -----
757771

758-
// CHECK-LABEL: fold_extract_broadcast_vector
759-
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
760-
// CHECK: return %[[A]] : vector<4xf32>
761-
func.func @fold_extract_broadcast_vector(%a : vector<4xf32>) -> vector<4xf32> {
762-
%b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
763-
%r = vector.extract %b[0, 1] : vector<4xf32> from vector<1x2x4xf32>
764-
return %r : vector<4xf32>
772+
// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
773+
// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>
774+
// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
775+
// CHECK: %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32>
776+
// CHECK: return %[[R]] : f32
777+
func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
778+
%idx : index, %idx1 : index, %idx2 : index) -> f32 {
779+
%b = vector.broadcast %a : vector<2x1xf32> to vector<1x2x4xf32>
780+
%r = vector.extract %b[%idx, %idx1, %idx2] : f32 from vector<1x2x4xf32>
781+
return %r : f32
765782
}
766783

767784
// -----
768785

769-
// CHECK-LABEL: fold_extract_broadcast
770-
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
771-
// CHECK: %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
772-
// CHECK: return %[[R]] : f32
773-
func.func @fold_extract_broadcast(%a : vector<4xf32>) -> f32 {
774-
%b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
775-
%r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
776-
return %r : f32
786+
// CHECK-LABEL: fold_extract_broadcast_to_lower_rank
787+
// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
788+
// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
789+
// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
790+
// CHECK: return %[[B]] : vector<4xf32>
791+
// rank(extract_output) < rank(broadcast_input)
792+
func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
793+
%idx0 : index, %idx1 : index) -> vector<4xf32> {
794+
%b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
795+
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
796+
return %r : vector<4xf32>
777797
}
778798

779799
// -----
780800

781-
// CHECK-LABEL: fold_extract_broadcast
801+
// CHECK-LABEL: fold_extract_broadcast_to_higher_rank
782802
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
783803
// CHECK: return %[[B]] : vector<4xf32>
784-
func.func @fold_extract_broadcast(%a : f32) -> vector<4xf32> {
804+
// rank(extract_output) > rank(broadcast_input)
805+
func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1 : index)
806+
-> vector<4xf32> {
785807
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
786-
%r = vector.extract %b[0, 1] : vector<4xf32> from vector<1x2x4xf32>
808+
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
787809
return %r : vector<4xf32>
788810
}
789811

790812
// -----
791813

792-
// CHECK-LABEL: fold_extract_broadcast
814+
// CHECK-LABEL: fold_extract_broadcast_to_equal_rank
793815
// CHECK-SAME: %[[A:.*]]: vector<1xf32>
794816
// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
795817
// CHECK: return %[[R]] : vector<8xf32>
796-
func.func @fold_extract_broadcast(%a : vector<1xf32>) -> vector<8xf32> {
818+
// rank(extract_output) == rank(broadcast_input)
819+
func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : index)
820+
-> vector<8xf32> {
797821
%b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
798-
%r = vector.extract %b[0] : vector<8xf32> from vector<1x8xf32>
822+
%r = vector.extract %b[%idx0] : vector<8xf32> from vector<1x8xf32>
799823
return %r : vector<8xf32>
800824
}
825+
801826
// -----
802827

803828
// CHECK-LABEL: @fold_extract_shuffle

0 commit comments

Comments
 (0)