diff --git a/lib/Dialect/Polynomial/Conversions/PolynomialToModArith/PolynomialToModArith.cpp b/lib/Dialect/Polynomial/Conversions/PolynomialToModArith/PolynomialToModArith.cpp index 25857c52b1..f934c03be7 100644 --- a/lib/Dialect/Polynomial/Conversions/PolynomialToModArith/PolynomialToModArith.cpp +++ b/lib/Dialect/Polynomial/Conversions/PolynomialToModArith/PolynomialToModArith.cpp @@ -219,6 +219,11 @@ struct ConvertFromTensor : public OpConversionPattern { } auto typeInfo = res.value(); + if (typeInfo.polynomialType.getForm() != Form::COEFF) { + return rewriter.notifyMatchFailure( + op, "FromTensor lowering requires COEFF form"); + } + auto resultShape = typeInfo.tensorType.getShape()[0]; auto resultEltTy = typeInfo.tensorType.getElementType(); auto inputTensorTy = op.getInput().getType(); @@ -276,6 +281,11 @@ struct ConvertConstant : public OpConversionPattern { auto typeInfo = res.value(); + if (typeInfo.polynomialType.getForm() != Form::COEFF) { + return rewriter.notifyMatchFailure( + op, "Constant lowering requires COEFF form"); + } + auto attr = dyn_cast(op.getValue()); if (!attr) return rewriter.notifyMatchFailure(op, @@ -362,6 +372,11 @@ struct ConvertMonomial : public OpConversionPattern { op, "failed to construct common conversion info"); auto typeInfo = res.value(); + if (typeInfo.polynomialType.getForm() != Form::COEFF) { + return rewriter.notifyMatchFailure( + op, "Monomial lowering requires COEFF form"); + } + SmallVector storageShape(typeInfo.tensorType.getShape().begin(), typeInfo.tensorType.getShape().end()); if (auto rnsType = dyn_cast(typeInfo.coefficientType)) { @@ -514,6 +529,11 @@ struct ConvertMonicMonomialMul op, "failed to construct common conversion info"); auto typeInfo = res.value(); + if (typeInfo.polynomialType.getForm() != Form::COEFF) { + return rewriter.notifyMatchFailure( + op, "MonicMonomialMul lowering requires COEFF form"); + } + ImplicitLocOpBuilder b(op.getLoc(), rewriter); // In general, a rotation would correspond to multiplication by x^n, // which requires a modular reduction step. But because the verifier @@ -600,6 +620,11 @@ struct ConvertLeadingTerm : public OpConversionPattern { op, "failed to construct common conversion info"); auto typeInfo = res.value(); + if (typeInfo.polynomialType.getForm() != Form::COEFF) { + return rewriter.notifyMatchFailure( + op, "LeadingTerm lowering requires COEFF form"); + } + auto c0 = arith::ConstantOp::create( b, b.getIntegerAttr(typeInfo.coefficientStorageType, 0)); auto c1 = arith::ConstantOp::create(b, b.getIndexAttr(1)); @@ -655,6 +680,11 @@ struct ConvertApplyCoefficientwise op, "failed to construct common conversion info"); auto typeInfo = res.value(); + if (typeInfo.polynomialType.getForm() != Form::COEFF) { + return rewriter.notifyMatchFailure( + op, "ApplyCoefficientwise lowering requires COEFF form"); + } + ImplicitLocOpBuilder b(op.getLoc(), rewriter); Value inputTensor = adaptor.getInput(); // Implicitly we're relying on the fact that the operand has already been @@ -1466,6 +1496,11 @@ struct ConvertNTT : public OpConversionPattern { op, "missing convert-elementwise-to-affine"); } + if (polyTy.getForm() != Form::COEFF) { + return rewriter.notifyMatchFailure(op, + "NTT lowering requires COEFF form"); + } + if (!op.getRoot()) { return rewriter.notifyMatchFailure(op, "missing root attribute"); } @@ -1517,6 +1552,11 @@ struct ConvertINTT : public OpConversionPattern { op, "failed to construct common conversion info"); auto typeInfo = res.value(); + if (typeInfo.polynomialType.getForm() != Form::EVAL) { + return rewriter.notifyMatchFailure(op, + "INTT lowering requires EVAL form"); + } + if (!op.getRoot()) { op.emitError("missing root attribute"); return rewriter.notifyMatchFailure(op, "missing root attribute");