Implement basis extension #2893
(sorry for the delay, I've been at a summit in Zurich; I should have time to review this before Monday)
j2kun left a comment:
I think the architectural approach is good, though you still seem unsure about relying on ApplyCoefficientwise for lowering polynomial to a cornami dialect.
My main question is: do you think we should try to get some intermediate layer of correctness checking for the algorithm before trying to lower all the way to LLVM and running it? We have some facilities for doing this (cf. ArithmeticDag; for a good example, see its use for Horner's method [1]), but it may need some extra work to make it suitable for use with RNS-typed values. In short, doing this would allow the same code path to be used both to run the computation in a unit test and to generate the MLIR during the lowering.
If you think it's worth it, then I'd say we can submit this PR without an end-to-end test, and follow up to replace the core computational pieces with an ArithmeticDag.
[1]: See https://github.com/google/heir/blob/main/lib/Utils/Polynomial/Horner.h for the implementation, https://github.com/google/heir/blob/main/lib/Utils/Polynomial/HornerTest.cpp for the unit tests, and this snippet from heir/lib/Transforms/LowerPolynomialEval/Patterns.cpp (lines 110 to 116 at 7e114c3):

```cpp
  return mrcs;
}

FailureOr<ArrayAttr> buildQInvProds(mlir::MLIRContext* ctx,
```
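To illustrate the pattern being suggested, here is a minimal sketch of the idea (just the shape of it, not the actual ArithmeticDag interface): write the computation once over an abstract value type, then instantiate it with concrete integers in a unit test and with IR value handles during lowering.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Horner's rule written once over an abstract value type V: instantiate
// with plain integers to unit-test the computation, or with IR value
// handles plus builder-backed ops to emit MLIR during lowering.
template <typename V, typename Ops>
V hornerEval(const std::vector<V>& coeffs, V x, const Ops& ops) {
  // coeffs are ordered from highest to lowest degree; assumed nonempty.
  V acc = coeffs.front();
  for (size_t i = 1; i < coeffs.size(); ++i)
    acc = ops.add(ops.mul(acc, x), coeffs[i]);
  return acc;
}

// Test-time instantiation: concrete arithmetic mod q.
struct ModOps {
  uint64_t q;
  uint64_t add(uint64_t a, uint64_t b) const { return (a + b) % q; }
  uint64_t mul(uint64_t a, uint64_t b) const { return (a * b) % q; }
};
```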
nit: since you end up materializing these qInv values as constants, I suspect you can avoid constructing the ArrayAttr and just return a SmallVector<APInt> of the qInv values. I think that should simplify computeMixedRadixCoeffs a little. You can even take it a step further: extract the core computation to operate entirely on APInts (and vectors thereof) and unit test the computation in isolation.
You can probably also add a custom mod_arith::ConstantOp::create that accepts the type and an APInt as input, and handles the boilerplate of constructing the ModArithAttr and underlying IntegerAttr.
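Something along these lines (a rough sketch; the ModArithAttr/ModArithType accessors used here are assumptions, not necessarily the real dialect API):

```cpp
// Hypothetical convenience builder: take the mod_arith type and an APInt,
// and hide the boilerplate of building the IntegerAttr + ModArithAttr.
static Value createModArithConstant(OpBuilder &b, Location loc,
                                    mod_arith::ModArithType type,
                                    const llvm::APInt &value) {
  // Assumed accessor: the storage integer type backing the modulus.
  Type storageType = type.getModulus().getType();
  auto intAttr = IntegerAttr::get(storageType, value);
  auto attr = mod_arith::ModArithAttr::get(type, intAttr);
  return b.create<mod_arith::ConstantOp>(loc, type, attr);
}
```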
Creating lots of unused attributes also has a performance impact (at heir-opt runtime) because constructing MLIR attributes involves extra hashing and uniquifying steps.
Agree that the ArrayAttr is unnecessary; I've fixed that.
On using ModArithAttr vs. APInt, I erred on the side of type safety: in computeMixedRadixCoeffs, I check whether the type of the ModArithAttr is what I expect (roughly as in the sketch below). It's unfortunate that this could have a performance impact, but I would like to keep the type safety if possible (that's one of MLIR's strong points).
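A sketch of that check, with assumed names (the real accessor spellings may differ):

```cpp
// Sketch (assumed names): accept a coefficient attribute only if it is a
// ModArithAttr whose type matches the modulus expected at this step, then
// unwrap it to the underlying APInt for the mixed-radix computation.
FailureOr<llvm::APInt> extractCoeff(Attribute rawAttr, Type expectedType) {
  auto coeff = dyn_cast<mod_arith::ModArithAttr>(rawAttr);
  if (!coeff || coeff.getType() != expectedType)
    return failure();  // not a ModArithAttr, or the wrong modulus
  return coeff.getValue().getValue();  // IntegerAttr -> APInt
}
```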
> Creating lots of unused attributes
I was thinking about this when I wrote the code and generated the constants once per keyswitch (so outside the ApplyCoefficientwise loop) rather than inside. We could go one level higher and make these part of the CKKSAttributes, which would further reduce the number of times they are computed. I think this would make a nice follow-up PR.
On a partially-related note: do we have some kind of benchmarking on individual passes to help identify when a particular pass is slow?
(On that follow-up PR: computing roots of unity for NTTs is another place we should compute the roots once for a parameter set instead of at every NTT)
> do we have some kind of benchmarking on individual passes to help identify when a particular pass is slow?
So far just --mlir-timing, which prints a table of per-pass runtimes after a run finishes. Useful, but we don't have anything like performance regression testing to ensure changes don't make passes/pipelines slower.
I think constructing ArrayAttrs is significantly more expensive than constructing the underlying scalar attributes, in particular since the ArrayAttr hashes as a combination of all its member hashes (unless we're using a dense form of the array attr which isn't implemented for mod_arith member attrs yet). So I'm happy to just omit the ArrayAttr for now.
```cpp
      << xi.getType();
  return failure();
}
// Using Horner's method, we compute (c_{i-1}*q_{i-2} + c_{i-2})*q_{i-3} +
```
Offhand, I wonder to what extent folks have looked at improving upon this using something like Estrin's scheme (https://en.wikipedia.org/wiki/Estrin%27s_scheme).
At least, if basis extension is a computational bottleneck, Estrin's scheme should provide a bit more opportunity for parallelism.
Not particularly relevant for this lowering, which is meant for functional testing, but I thought I'd mention it, since the use of MacOp reminded me.
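For concreteness, a toy example of the instruction-level parallelism Estrin exposes (purely illustrative; plain uint64_t arithmetic standing in for the modular ops used in basis extension):

```cpp
#include <cstdint>

// Degree-3 polynomial c0 + c1*x + c2*x^2 + c3*x^3 via Estrin's scheme.
// `lo` and `hi` have no data dependence on each other, so they can be
// computed in parallel, unlike the strictly serial chain of Horner's rule.
uint64_t estrinDegree3(uint64_t x, uint64_t c0, uint64_t c1, uint64_t c2,
                       uint64_t c3) {
  uint64_t x2 = x * x;        // shared power of x
  uint64_t lo = c0 + c1 * x;  // low half
  uint64_t hi = c2 + c3 * x;  // high half
  return lo + hi * x2;        // (c0 + c1*x) + (c2 + c3*x) * x^2
}
```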
There is already an enormous amount of parallelism available in the context of [hybrid] keyswitching: we split the input into multiple "states" (each of which can be basis-extended in parallel), and within each state, we can basis-extend each coefficient in parallel. Concretely, for a ring dimension of 2^16 with 40 moduli in the input basis and 2 key-switch primes, there are (2^16)*20 parallel basis conversions happening. Since Horner's method "is optimal in the sense that it minimizes the number of multiplications and additions required to evaluate an arbitrary polynomial", I suspect minimizing the operation count is more important here than exposing additional parallelism.
That said, this is the kind of interesting tradeoff that (especially hardware) implementations might want to explore. In my head, I have this idea for a benchmarking suite built into HEIR that explore a high-dimensional space of algorithm combinations to look for the best combination on a particular HW backend.
Addressed comments. Remaining open items are discussed below.
Sounds good to me.
I think it's less infrastructure than it seems, so putting it in a new location is fine. Most of the hard work is already done; https://github.com/google/heir/blob/main/tests/Dialect/Polynomial/Conversions/heir_polynomial_to_llvm/runner/lower_add_test.cc is a good, simple reference here.
I'll work on adding a test tomorrow. In the meantime, could you help me figure out why CI is failing while the tests run fine on my local machine? See, e.g., https://github.com/google/heir/actions/runs/25077913451/job/73475389405?pr=2893
Apparently the issue is that CI uses the -c flag, and I introduced a bug that is only triggered with -c. Working on a fix.
Let me know if there's something I can do to help debug.
Force-pushed from 71a90c0 to 8654d31.
This PR adds an implementation of basis extension. I made several choices which may not be right in the end, so input on architecture is welcome.
Miscellaneous changes:
The biggest piece missing from this PR is that there are no LLVM tests for basis conversion. This should be added, but I figured I should get an architecture review first, and also get input on where to put the tests. My feeling is that we should have some rns-to-llvm tests for basis conversion.
PR written with AI assistance.