
Implement basis extension #2893

Open

crockeea wants to merge 8 commits into google:main from crockeea:basisext

Conversation


@crockeea crockeea commented Apr 21, 2026

This PR adds an implementation of basis extension. I made several choices which may not be right in the end, so input on architecture is welcome.

  • I implemented basis conversion inside the LWEToPolynomial conversion. An alternate construction would be to lower LWE.convert_basis to polynomial.convert_basis and then replace polynomial.convert_basis with a polynomial.apply_coefficientwise, but that seemed unnecessary to me.
  • This means there is no longer a need for Polynomial's ConvertBasisOp, so it has been removed. All of the NTT tests that used ConvertBasis now use ApplyCoefficientwise instead, since that op previously had no tests and, like ConvertBasis, it is coeff-only.
  • I added a helper function in PolynomialOps to properly apply polynomial.apply_coefficientwise, which involves a lot of boilerplate.
  • I chose not to add an rns.convert_basis op. My motivation for this, and for implementing the lowering in LWEToPolynomial, is that my goal is to lower the implementation to the polynomial layer without going through ModArith. For that reason I'm still not 100% sold on using polynomial.apply_coefficientwise, since it involves ModArith, but I'm trying it for now. [The alternative would be to add some new ops to the Polynomial dialect, which is also not appealing.] Instead of an rns.convert_basis op, I implemented basis conversion as top-level C++ functions in RNSOps.cpp, which are used by the LWEToPolynomial lowering for LWE.ConvertBasis (a toy sketch of the algorithm follows this list).
  • There are new RNS ops for extracting a single slice (rather than an RNS subset of slices) and for packing individual limbs together into an RNS.
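For reviewers who want the algorithm in one place, here is a self-contained toy sketch of mixed-radix (Garner) basis extension over plain uint64_t with small illustrative primes. It only demonstrates the math; the actual implementation emits mod_arith/RNS-typed IR rather than computing on raw integers, and all names here are local to the sketch.

#include <cstdint>
#include <iostream>
#include <vector>

// Modular exponentiation; used for inverses mod a prime via Fermat's little
// theorem. Safe for the tiny primes below (no 64-bit overflow).
uint64_t powmod(uint64_t b, uint64_t e, uint64_t m) {
  uint64_t r = 1 % m;
  for (b %= m; e > 0; e >>= 1, b = b * b % m)
    if (e & 1) r = r * b % m;
  return r;
}
uint64_t invmod(uint64_t a, uint64_t p) { return powmod(a % p, p - 2, p); }

int main() {
  std::vector<uint64_t> q = {97, 101, 103};  // input RNS basis
  uint64_t x = 123456;                       // value, with x < q0*q1*q2
  std::vector<uint64_t> xs;
  for (uint64_t qi : q) xs.push_back(x % qi);  // residues x mod q_i

  // Garner: mixed-radix digits with x = c0 + c1*q0 + c2*q0*q1.
  std::vector<uint64_t> c(q.size());
  for (size_t i = 0; i < q.size(); ++i) {
    uint64_t t = xs[i];
    // Peel off earlier digits: t <- (t - c_j) * q_j^{-1} mod q_i.
    for (size_t j = 0; j < i; ++j)
      t = (t + q[i] - c[j] % q[i]) % q[i] * invmod(q[j], q[i]) % q[i];
    c[i] = t;
  }

  // Extend to a new prime p via Horner: ((c2*q1 + c1)*q0 + c0) mod p.
  uint64_t p = 211;
  uint64_t acc = c.back() % p;
  for (size_t i = c.size() - 1; i-- > 0;)
    acc = (acc * (q[i] % p) + c[i]) % p;

  std::cout << acc << " == " << x % p << "\n";  // both print 21
}

The q^{-1}-type factors used in the digit computation are, roughly, the constants the lowering materializes up front (cf. the qInvProds discussion below).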

Miscellaneous changes:

  1. When programming with GPT in VSCode, I get a .codex file, which should be ignored.
  2. I renamed the TypeInterface build targets in RNS and ModArith to match the style of other build targets.

The biggest piece missing from this PR is that there are no LLVM tests for basis conversion. This should be added, but I figured I should get an architecture review first, and also get input on where to put the tests. My feeling is that we should have some rns-to-llvm tests for basis conversion.

PR written with AI assistance.

@crockeea crockeea requested a review from j2kun April 21, 2026 22:03

j2kun commented Apr 23, 2026

(sorry for the delay, I've been at a summit in Zurich; I should have time to review this before Monday)


@j2kun j2kun left a comment


I think the architectural approach is good, though you still seem unsure about relying on ApplyCoefficientwise for the purpose of lowering polynomial to a cornami dialect.

My main question is: do you think we should try to get some intermediate layer of correctness checking for the algorithm before trying to lower all the way to LLVM and running it? We have some facilities for doing this (cf. ArithmeticDag; for a good example, see its use in Horner's method [1]), but it may need some extra work to make it suitable for use with RNS-typed values. In short, doing this would allow the same code path to be used both to run the computation in a unit test and to generate the MLIR during the lowering.

If you think it's worth it, then I'd say we can submit this PR without an end-to-end test, and follow up to replace the core computational pieces with an ArithmeticDag.

[1]: See https://github.com/google/heir/blob/main/lib/Utils/Polynomial/Horner.h for the implementation, https://github.com/google/heir/blob/main/lib/Utils/Polynomial/HornerTest.cpp for the unit tests, and the following snippet for its use in the MLIR lowering:

auto resultNode = polynomial::hornerMonomialPolynomialEvaluation(
    xNode, coefficients, dagType);
// Use IRMaterializingVisitor to convert to MLIR
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
kernel::IRMaterializingVisitor visitor(op.getValue().getType());
Value finalOutput = visitor.process(resultNode, b)[0];

Comment thread lib/Dialect/RNS/IR/RNSOps.h Outdated
Comment thread lib/Dialect/RNS/IR/RNSOps.td Outdated
Comment thread lib/Dialect/RNS/IR/RNSOps.td Outdated
Comment thread lib/Dialect/Polynomial/IR/PolynomialOps.cpp Outdated
Comment thread lib/Dialect/Polynomial/IR/PolynomialOps.h
Comment thread lib/Dialect/RNS/IR/RNSOps.cpp Outdated
return mrcs;
}

FailureOr<ArrayAttr> buildQInvProds(mlir::MLIRContext* ctx,
Collaborator


nit: since you end up materializing these qInv values as constants, I suspect you can avoid constructing the ArrayAttr and just return a SmallVector<APInt> of the qInv values. I think that should simplify computeMixedRadixCoeffs a little bit. You can even take it a step further and extract the core computation to operate entirely on APInts (and vectors thereof) and unit test the computation in isolation.

You can probably also add a custom mod_arith::ConstantOp::create that accepts the type and an APInt as input, and handles the boilerplate of constructing the ModArithAttr and underlying IntegerAttr.

Creating lots of unused attributes also has a performance impact (at heir-opt runtime) because constructing MLIR attributes involves extra hashing and uniquifying steps.
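As a hypothetical sketch of the suggested helper (the ModArithType/ModArithAttr accessors here are assumptions and may not match HEIR's actual API):

// Hypothetical convenience constructor; accessor names are assumed.
static ConstantOp create(OpBuilder &builder, Location loc, ModArithType type,
                         const APInt &value) {
  // Wrap the raw APInt in an IntegerAttr of the underlying storage type,
  // then in the ModArithAttr that ConstantOp expects.
  auto storageType = cast<IntegerType>(type.getModulus().getType());
  auto intAttr = IntegerAttr::get(storageType, value);
  auto modAttr = ModArithAttr::get(type, intAttr);
  return builder.create<ConstantOp>(loc, type, modAttr);
}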


@crockeea crockeea Apr 28, 2026


Agree that the ArrayAttr is unnecessary; I've fixed that.

For using ModArithAttr vs APInt, I erred on the side of type safety: in computeMixedRadixCoeffs, I check whether the type of the ModArithAttr is what I expect. It's unfortunate that this could have a performance impact, but I would like to keep the type safety if possible (that's one of MLIR's strong points).

Creating lots of unused attributes

I was thinking about this when I wrote the code and generated the constants once per keyswitch (so outside the ApplyCoefficientwise loop) rather than inside. We could go one level higher and make these part of the CKKSAttributes, which would further reduce the number of times they are computed. I think this would make a nice follow-up PR.

On a partially-related note: do we have some kind of benchmarking on individual passes to help identify when a particular pass is slow?

Collaborator Author


(On that follow-up PR: computing roots of unity for NTTs is another place we should compute the roots once for a parameter set instead of at every NTT)

Collaborator


do we have some kind of benchmarking on individual passes to help identify when a particular pass is slow?

So far just --mlir-timing, which prints a table of per-pass runtimes after the pipeline finishes. Useful, but we don't have anything like performance regression testing to ensure changes don't make passes/pipelines slower.


@j2kun j2kun Apr 28, 2026


I think constructing ArrayAttrs is significantly more expensive than constructing the underlying scalar attributes, in particular since the ArrayAttr hashes as a combination of all its member hashes (unless we're using a dense form of the array attr which isn't implemented for mod_arith member attrs yet). So I'm happy to just omit the ArrayAttr for now.

<< xi.getType();
return failure();
}
// Using Horner's method, we compute (c_{i-1}*q_{i-2} + c_{i-2})*q_{i-3} +
Collaborator


Offhand, I wonder to what extent folks have looked at improving upon this using something like Estrin's scheme https://en.wikipedia.org/wiki/Estrin%27s_scheme

At least, if basis extension is a computational bottleneck, that should provide a little bit more parallelism opportunities.

Not particularly relevant for this lowering, which is meant for functional testing, but I thought I'd mention it since the use of MacOp reminded me.
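To illustrate the shape of the tradeoff on a toy degree-3 polynomial in plain doubles (not this PR's mod_arith setting):

double horner(const double a[4], double x) {
  // Three multiplies, strictly sequential: each step depends on the last.
  return ((a[3] * x + a[2]) * x + a[1]) * x + a[0];
}

double estrin(const double a[4], double x) {
  // The two pairs and x^2 are independent, so they can run in parallel;
  // the critical path drops from 3 to 2 dependent multiplies at the cost
  // of one extra multiply overall.
  double lo = a[0] + a[1] * x;
  double hi = a[2] + a[3] * x;
  double x2 = x * x;
  return lo + hi * x2;
}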


@crockeea crockeea Apr 28, 2026


There is already an enormous amount of parallelism available in the context of [hybrid] keyswitching: we split the input into multiple "states" (each of which can be basis-extended in parallel), and within each state, we can basis-extend each coefficient in parallel. Concretely, for a ring dimension of 2^16 with 40 moduli in the input basis and 2 key-switch primes, there are (2^16)*20 parallel basis conversions happening. Since Horner "is optimal in the sense that it minimizes the number of multiplications and additions required to evaluate an arbitrary polynomial", I suspect this is more important than adding additional parallelism.

That said, this is the kind of interesting tradeoff that implementations (especially hardware ones) might want to explore. In my head, I have an idea for a benchmarking suite built into HEIR that explores a high-dimensional space of algorithm combinations to find the best combination for a particular HW backend.

Comment thread lib/Dialect/LWE/Conversions/LWEToPolynomial/LWEToPolynomial.cpp Outdated
Comment thread lib/Dialect/LWE/Conversions/LWEToPolynomial/LWEToPolynomial.cpp Outdated
Comment thread lib/Dialect/RNS/IR/RNSOps.cpp Outdated

crockeea commented Apr 28, 2026

Addressed comments. Remaining open items:

  • qInvProds thread

  • though you still seem unsure about relying on ApplyCoefficientwise for the purpose of lowering polynomial to a cornami dialect.

    I am, but I think it's a fine direction to start with. Given that we haven't done any work on moving from HEIR to HW yet, I don't know exactly what the challenges will or won't be.

  • do you think we should try to get some intermediate layer of correctness checking for the algorithm before trying to lower all the way to LLVM and running it?

    I don't see much advantage to doing that; I was happy with the structure of the LLVM tests for NTT.

  • In short, doing this would allow us to have the same code path be used to run the computation in a unit test, as well as to generate the MLIR during the lowering.

    I think your point here is that in the LLVM lowering, we're using the basis conversion code path plus the ModArith->LLVM code path, and there could be bugs in the latter that affect correctness of the former. That's true, but I think using LLVM is the "easy button" for writing simple tests. If there's a bug in the LLVM path now, that could be annoying to find. But if someone introduces a bug later, this test would help catch it.

  • If you think it's worth it, then I'd say we can submit this PR without an end-to-end test, and follow up to replace the core computational pieces with an ArithmeticDag.

    My preferred path would be to add an LLVM test to this PR. Given how subtle this algorithm is (having implemented it multiple times and having made mistakes each time), having a test is important :) The open item is where to put the LLVM tests: cram it into the existing Polynomial/Conversions/heir_polynomial_to_llvm tests, or copy that infrastructure to RNS tests?


j2kun commented Apr 28, 2026

My preferred path would be to add an LLVM test to this PR.

Sounds good to me.

where to put the LLVM tests

I think it's less infrastructure than it seems, so putting it in a new location is fine. Most of the hard work is done in tests/llvm_runner:llvm_runner.bzl, so you can load that macro anywhere and use it as

load("//tests/llvm_runner:llvm_runner.bzl", "llvm_runner_test")

llvm_runner_test(
    name = "target_name",
    heir_opt_flags = [
        "--emit-c-interface",
        # any other passes you need
        "--heir-polynomial-to-llvm",
    ],
    main_c_src = "test_harness.cc",
    mlir_src = "mlir_input.mlir",
)

https://github.com/google/heir/blob/main/tests/Dialect/Polynomial/Conversions/heir_polynomial_to_llvm/runner/lower_add_test.cc is a good, simple reference here.

@crockeea

I'll work on adding a test tomorrow. In the meantime, could you help me figure out why CI is failing, but the tests run fine on my local machine? See, e.g., https://github.com/google/heir/actions/runs/25077913451/job/73475389405?pr=2893


crockeea commented Apr 29, 2026

In the meantime, could you help me figure out why CI is failing, but the tests run fine on my local machine?

Apparently the issue is that CI uses the -c flag, and I introduced a bug that is only triggered with -c. Working on a fix.


j2kun commented Apr 29, 2026

In the meantime, could you help me figure out why CI is failing, but the tests run fine on my local machine?

Apparently the issue is that CI uses the -c flag, and I introduced a bug that is only triggered with -c. Working on a fix.

Let me know if there's something I can do to help debug.

@crockeea crockeea force-pushed the basisext branch 2 times, most recently from 71a90c0 to 8654d31 on April 30, 2026 at 23:05