Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,11 @@ struct ConvertFromTensor : public OpConversionPattern<FromTensorOp> {
}
auto typeInfo = res.value();

if (typeInfo.polynomialType.getForm() != Form::COEFF) {
return rewriter.notifyMatchFailure(
op, "FromTensor lowering requires COEFF form");
}

auto resultShape = typeInfo.tensorType.getShape()[0];
auto resultEltTy = typeInfo.tensorType.getElementType();
auto inputTensorTy = op.getInput().getType();
Expand Down Expand Up @@ -276,6 +281,11 @@ struct ConvertConstant : public OpConversionPattern<ConstantOp> {

auto typeInfo = res.value();

if (typeInfo.polynomialType.getForm() != Form::COEFF) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return rewriter.notifyMatchFailure(
op, "Constant lowering requires COEFF form");
}

auto attr = dyn_cast<TypedIntPolynomialAttr>(op.getValue());
if (!attr)
return rewriter.notifyMatchFailure(op,
Expand Down Expand Up @@ -362,6 +372,11 @@ struct ConvertMonomial : public OpConversionPattern<MonomialOp> {
op, "failed to construct common conversion info");
auto typeInfo = res.value();

if (typeInfo.polynomialType.getForm() != Form::COEFF) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto here; see https://github.com/google/heir/blob/main/lib/Dialect/Polynomial/Transforms/PolyMulToNTT.cpp#L59. As a CONST op, this output a polynomial in either form.

return rewriter.notifyMatchFailure(
op, "Monomial lowering requires COEFF form");
}

SmallVector<int64_t> storageShape(typeInfo.tensorType.getShape().begin(),
typeInfo.tensorType.getShape().end());
if (auto rnsType = dyn_cast<RNSType>(typeInfo.coefficientType)) {
Expand Down Expand Up @@ -514,6 +529,11 @@ struct ConvertMonicMonomialMul
op, "failed to construct common conversion info");
auto typeInfo = res.value();

if (typeInfo.polynomialType.getForm() != Form::COEFF) {
return rewriter.notifyMatchFailure(
op, "MonicMonomialMul lowering requires COEFF form");
}

ImplicitLocOpBuilder b(op.getLoc(), rewriter);
// In general, a rotation would correspond to multiplication by x^n,
// which requires a modular reduction step. But because the verifier
Expand Down Expand Up @@ -600,6 +620,11 @@ struct ConvertLeadingTerm : public OpConversionPattern<LeadingTermOp> {
op, "failed to construct common conversion info");
auto typeInfo = res.value();

if (typeInfo.polynomialType.getForm() != Form::COEFF) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LeadingTerm op doesn't output a poly, does it?

return rewriter.notifyMatchFailure(
op, "LeadingTerm lowering requires COEFF form");
}

auto c0 = arith::ConstantOp::create(
b, b.getIntegerAttr(typeInfo.coefficientStorageType, 0));
auto c1 = arith::ConstantOp::create(b, b.getIndexAttr(1));
Expand Down Expand Up @@ -655,6 +680,11 @@ struct ConvertApplyCoefficientwise
op, "failed to construct common conversion info");
auto typeInfo = res.value();

if (typeInfo.polynomialType.getForm() != Form::COEFF) {
return rewriter.notifyMatchFailure(
op, "ApplyCoefficientwise lowering requires COEFF form");
}

ImplicitLocOpBuilder b(op.getLoc(), rewriter);
Value inputTensor = adaptor.getInput();
// Implicitly we're relying on the fact that the operand has already been
Expand Down Expand Up @@ -1466,6 +1496,11 @@ struct ConvertNTT : public OpConversionPattern<NTTOp> {
op, "missing convert-elementwise-to-affine");
}

if (polyTy.getForm() != Form::COEFF) {
return rewriter.notifyMatchFailure(op,
"NTT lowering requires COEFF form");
}

if (!op.getRoot()) {
return rewriter.notifyMatchFailure(op, "missing root attribute");
}
Expand Down Expand Up @@ -1517,6 +1552,11 @@ struct ConvertINTT : public OpConversionPattern<INTTOp> {
op, "failed to construct common conversion info");
auto typeInfo = res.value();

if (typeInfo.polynomialType.getForm() != Form::EVAL) {
return rewriter.notifyMatchFailure(op,
"INTT lowering requires EVAL form");
}

if (!op.getRoot()) {
op.emitError("missing root attribute");
return rewriter.notifyMatchFailure(op, "missing root attribute");
Expand Down
Loading