|
| 1 | +From 139c779d447d6163c51dbe9d8735b2062025f032 Mon Sep 17 00:00:00 2001 |
| 2 | +From: Christopher Bate < [email protected]> |
| 3 | +Date: Fri, 21 Mar 2025 03:28:26 +0000 |
| 4 | +Subject: [PATCH 5/7] Fix crash on ComplexType in PointwiseToLinalgMapConverter |
| 5 | + |
| 6 | +--- |
| 7 | + .../conversions/linalg/tests/pointwise.mlir | 23 ++++++++++++++ |
| 8 | + .../transforms/StablehloToLinalgPointwise.cpp | 30 +++++++++++++++---- |
| 9 | + 2 files changed, 48 insertions(+), 5 deletions(-) |
| 10 | + |
| 11 | +diff --git a/stablehlo/conversions/linalg/tests/pointwise.mlir b/stablehlo/conversions/linalg/tests/pointwise.mlir |
| 12 | +index 6dc76f24..7a9f71aa 100644 |
| 13 | +--- a/stablehlo/conversions/linalg/tests/pointwise.mlir |
| 14 | ++++ b/stablehlo/conversions/linalg/tests/pointwise.mlir |
| 15 | +@@ -23,6 +23,29 @@ func.func @float_add(%lhs: tensor<2x2xf32>, |
| 16 | + |
| 17 | + // ----- |
| 18 | + |
| 19 | ++// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> |
| 20 | ++// CHECK-LABEL: func @complex_add_const |
| 21 | ++// CHECK-PRIMITIVE-LABEL: func @complex_add_const |
| 22 | ++func.func @complex_add_const(%lhs: tensor<2x2xcomplex<f32>>, |
| 23 | ++ %rhs: tensor<2x2xcomplex<f32>>) |
| 24 | ++ -> tensor<2x2xcomplex<f32>> { |
| 25 | ++ |
| 26 | ++ // CHECK: %[[CST:.+]] = complex.constant [1.000000e-01 : f32, 2.000000e-01 : f32] : complex<f32> |
| 27 | ++ // CHECK: linalg.generic |
| 28 | ++ // CHECK: ^bb0(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>) |
| 29 | ++ // CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = complex.add %[[IN]], %[[CST]] |
| 30 | ++ // CHECK: linalg.yield %[[RESULT]] |
| 31 | ++ |
| 32 | ++ // CHECK-PRIMITIVE: linalg.map |
| 33 | ++ // CHECK-PRIMITIVE: complex.add |
| 34 | ++ %cst = stablehlo.constant dense<(0.1, 0.2)> : tensor<2x2xcomplex<f32>> |
| 35 | ++ %0 = "stablehlo.add"(%lhs, %cst) {someattr} |
| 36 | ++ : (tensor<2x2xcomplex<f32>>, tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> |
| 37 | ++ func.return %0 : tensor<2x2xcomplex<f32>> |
| 38 | ++} |
| 39 | ++ |
| 40 | ++// ----- |
| 41 | ++ |
| 42 | + // CHECK-LABEL: func @float_add_dynamic_encoding |
| 43 | + // CHECK-PRIMITIVE-LABEL: func @float_add_dynamic_encoding |
| 44 | + func.func @float_add_dynamic_encoding( |
| 45 | +diff --git a/stablehlo/conversions/linalg/transforms/StablehloToLinalgPointwise.cpp b/stablehlo/conversions/linalg/transforms/StablehloToLinalgPointwise.cpp |
| 46 | +index 707db6a7..301dfdc2 100644 |
| 47 | +--- a/stablehlo/conversions/linalg/transforms/StablehloToLinalgPointwise.cpp |
| 48 | ++++ b/stablehlo/conversions/linalg/transforms/StablehloToLinalgPointwise.cpp |
| 49 | +@@ -114,6 +114,28 @@ FailureOr<PointwiseConversionInfo> checkOperandsAndResults( |
| 50 | + return PointwiseConversionInfo{maxRank, resultTy}; |
| 51 | + } |
| 52 | + |
| 53 | ++/// If `input` is a splat constant value, materialize the scalar splat |
| 54 | ++/// value. Otherwise, return nullopt. |
| 55 | ++std::optional<Value> materializeSplatScalarConstant(RewriterBase &rewriter, |
| 56 | ++ Location loc, Value input) { |
| 57 | ++ SplatElementsAttr attr; |
| 58 | ++ Type elementType = mlir::getElementTypeOrSelf(input.getType()); |
| 59 | ++ if (!matchPattern(input, m_Constant(&attr))) return {}; |
| 60 | ++ if (isa<IntegerType, FloatType, IndexType>(elementType)) { |
| 61 | ++ return rewriter |
| 62 | ++ .create<arith::ConstantOp>(loc, elementType, |
| 63 | ++ attr.getSplatValue<TypedAttr>()) |
| 64 | ++ .getResult(); |
| 65 | ++ } |
| 66 | ++ if (isa<ComplexType>(elementType)) { |
| 67 | ++ return rewriter |
| 68 | ++ .create<complex::ConstantOp>(loc, elementType, |
| 69 | ++ attr.getSplatValue<ArrayAttr>()) |
| 70 | ++ .getResult(); |
| 71 | ++ } |
| 72 | ++ return {}; |
| 73 | ++} |
| 74 | ++ |
| 75 | + /// Converts a HLO operation to a linalg.map op that contains the corresponding |
| 76 | + /// scalar operations. |
| 77 | + template <typename OpTy> |
| 78 | +@@ -160,11 +182,9 @@ struct PointwiseToLinalgMapConverter : OpConversionPattern<OpTy> { |
| 79 | + SmallVector<Value> mappedInputs; |
| 80 | + SmallVector<Value> scalarInputs; |
| 81 | + for (Value input : adaptor.getOperands()) { |
| 82 | +- DenseElementsAttr attr; |
| 83 | +- if (matchPattern(input, m_Constant(&attr)) && attr.isSplat()) { |
| 84 | +- scalarInputs.push_back(rewriter.create<arith::ConstantOp>( |
| 85 | +- loc, cast<ShapedType>(input.getType()).getElementType(), |
| 86 | +- attr.getSplatValue<TypedAttr>())); |
| 87 | ++ if (std::optional<Value> splatVal = |
| 88 | ++ materializeSplatScalarConstant(rewriter, loc, input)) { |
| 89 | ++ scalarInputs.push_back(*splatVal); |
| 90 | + } else if (getRank(input) == maxRank) { |
| 91 | + mappedInputs.push_back(coerceTensorShape( |
| 92 | + rewriter, loc, cast<TypedValue<ShapedType>>(input), |
| 93 | +-- |
| 94 | +2.46.0 |
| 95 | + |
0 commit comments