-
Notifications
You must be signed in to change notification settings - Fork 136
HE SIMD batch MatMat multiplication protocol #1271
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
All contributors have signed the CLA ✍️ ✅ |
Summary of ChangesHello @Welement-zyf, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a novel protocol for performing batch matrix-matrix multiplication within a Homomorphic Encryption (HE) context. The primary goal is to significantly speed up online computations where one input is a ciphertext and the other is a plaintext weight. The protocol intelligently adapts its strategy based on the input matrix dimensions relative to the SIMD lane size, employing techniques like Baby-Step Giant-Step for efficient rotations when applicable, and providing robust methods for data preparation, encryption, and result extraction. Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a new HE SIMD batch matrix-matrix multiplication protocol, which is a significant feature. The implementation is well-structured and handles different input sizes with specific strategies, including a BSGS-style algorithm for efficiency. My review focuses on improving code quality by removing leftover debugging artifacts like print statements, cleaning up commented-out code, and eliminating unused variables and methods. These changes will enhance the readability and maintainability of the new protocol and its tests.
// yacl::parallel_for(0, polys.size(), [&](int64_t bgn, int64_t end) { | ||
// for (int64_t i = bgn; i < end; ++i) { | ||
// seal::util::encrypt_zero_symmetric(secret_key, context, | ||
// context.first_parms_id(), false, | ||
// save_seed, out[i]); | ||
// seal::util::multiply_add_plain_with_scaling_variant( | ||
// polys[i], *context.first_context_data(), | ||
// seal::util::RNSIter{out[i].data(), out[i].poly_modulus_degree()}); | ||
// } | ||
// }); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
uint64_t baby_step = absl::bit_ceil( | ||
static_cast<uint64_t>(std::sqrt(block_size * meta.dims[2] / (double)meta.dims[1]))); | ||
baby_step = std::min(baby_step, block_size); | ||
std::cout << "baby_step: " << baby_step << std::endl; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
absl::Span<const uint64_t> ans_poly, | ||
absl::Span<uint64_t> res_mat) const; | ||
|
||
// Shape2D GetInShape() const { return in_shape_; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
private: | ||
void NoiseFloodInplace(RLWECt &ct, const seal::SEALContext &context); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
std::shared_ptr<SIMDBatchMMProt> simd_batchmm_prot_; | ||
|
||
void SetUp() override { | ||
std::cout << "setup" << std::endl; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// modulus_bits = {60, 30, 52, plain_bits}; | ||
} else { | ||
// modulus_bits = {60, 45, 45, 58, plain_bits}; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return fmt::format("{}", p.param ? "NoiseFlood" : "Approx"); | ||
}); | ||
|
||
TEST_P(SIMDBatchMMTest, ) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
int64_t d0 = 24; | ||
int64_t d1 = 2048; | ||
int64_t d2 = 1408; | ||
std::cout << "Testing Batch MatMat " << batch << "x" << d0 << "x" << d1 << " * " << d1 << "x" << d2 << std::endl; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
size_t num_row_blocks = CeilDiv(static_cast<uint64_t>(d1), block_size); | ||
size_t num_col_blocks = CeilDiv(static_cast<uint64_t>(d2), block_size); | ||
size_t simd_lane = simd_batchmm_prot_->SIMDLane(); | ||
// size_t row_size = simd_batchmm_prot_->SIMDLane() / 2; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
|
||
namespace spu::mpc::cheetah { | ||
|
||
class SIMDBatchMMProt { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You should inherit EnableCPRNG
here, rather than in the unittest.
}; | ||
|
||
|
||
static constexpr int kNoiseFloodBits = 40; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems that the logic of noise flooding has not been utilized; it may need to be added.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems that the cheetah_dot.h
defines BatchDotOLE
. You might consider testing its performance against your implementation.
return fmt::format("{}", p.param ? "NoiseFlood" : "Approx"); | ||
}); | ||
|
||
TEST_P(SIMDBatchMMTest, ) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- Regarding how to elegantly generate test data and perform correctness testing, you can refer to
cheetah_dot_test.cc
. - For measuring elapsed time, you can use yacl::ElapsedTimer, and the usage can be referenced in
src/libspu/mpc/utils/lowmc_test.cc
. - If you need to track communication volume and round counts, you can directly retrieve link statistics from yacl; the usage can be referenced in
src/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc
.
int64_t d0 = 24; | ||
int64_t d1 = 2048; | ||
int64_t d2 = 1408; | ||
std::cout << "Testing Batch MatMat " << batch << "x" << d0 << "x" << d1 << " * " << d1 << "x" << d2 << std::endl; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We prefer SPDLOG_INFO
if you really need to print something, but you should make them as less as possible.
}); | ||
|
||
TEST_P(SIMDBatchMMTest, ) { | ||
size_t batch = 64; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can use the parameterized testing framework of gtest
to elegantly test the correctness and performance of the implementation with various configurations.
You can define something like:
class CheetahDotTest
: public ::testing::TestWithParam<std::tuple<FieldType, Shape3D>> {};
INSTANTIATE_TEST_SUITE_P(
Cheetah, CheetahDotTest,
testing::Combine(testing::Values(FieldType::FM64, FieldType::FM128),
testing::Values(Shape3D{8, 7, 5}, Shape3D{57, 30, 1},
Shape3D{30, 57, 1}, Shape3D{18, 8, 41},
Shape3D{500, 13, 25},
Shape3D{1, 2048, 768},
Shape3D{18, 768, 78})),
[](const testing::TestParamInfo<CheetahDotTest::ParamType>& p) {
return fmt::format("{}x{}x{}x{}", std::get<0>(std::get<1>(p.param)),
std::get<1>(std::get<1>(p.param)),
std::get<2>(std::get<1>(p.param)),
std::get<0>(p.param));
});
|
||
SIMDBatchMMProt(SIMDBatchMMProt&&) = delete; | ||
|
||
// Same as SIMDMulProt |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- Unless there is a necessity for external calls, the APIs should be made
private
as much as possible. - You should refer to the API design in
cheetah_dot.h
, where the input provided to the upper-level APIs needs to be of theNdArrayRef
type.
Please sign the CLA, btw. |
I feel that your batch matmul also needs to define a State, which you can refer to in |
I have read the CLA Document and I hereby sign the CLA |
Pull Request
What problem does this PR solve?
A new HE SIMD batch MatMat multiplication protocol, which can accelerate online HE MatMat computation. This protocol is focus on MatMul between ciphertext input and plaintext weight.
Possible side effects?
Increase the computation of encoding plaintext weight. But this part can be done offline when there are many inputs.