Skip to content

Conversation

Welement-zyf
Copy link

Pull Request

What problem does this PR solve?

A new HE SIMD batch MatMat multiplication protocol, which can accelerate online HE MatMat computation. This protocol is focus on MatMul between ciphertext input and plaintext weight.

Possible side effects?

Increase the computation of encoding plaintext weight. But this part can be done offline when there are many inputs.

cd src
bazelisk test //libspu/mpc/cheetah/arith:simd_batchmm_test --test_output=all

Copy link

github-actions bot commented Oct 16, 2025

All contributors have signed the CLA ✍️ ✅
Posted by the CLA Assistant Lite bot.

Copy link
Contributor

Summary of Changes

Hello @Welement-zyf, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a novel protocol for performing batch matrix-matrix multiplication within a Homomorphic Encryption (HE) context. The primary goal is to significantly speed up online computations where one input is a ciphertext and the other is a plaintext weight. The protocol intelligently adapts its strategy based on the input matrix dimensions relative to the SIMD lane size, employing techniques like Baby-Step Giant-Step for efficient rotations when applicable, and providing robust methods for data preparation, encryption, and result extraction.

Highlights

  • New Protocol Introduction: A new Homomorphic Encryption (HE) SIMD batch Matrix-Matrix multiplication protocol has been introduced, designed to accelerate online HE computations.
  • Targeted Acceleration: This protocol specifically focuses on optimizing matrix multiplication where the input is a ciphertext and the weight is a plaintext, a common scenario in HE applications.
  • Adaptive Strategy: The SIMDBatchMMProt class implements an adaptive multiplication strategy, utilizing Baby-Step Giant-Step (BSGS) rotations for efficient computation when the input shape is small relative to the SIMD lane, and a simpler approach otherwise.
  • Comprehensive Data Handling: The protocol includes robust methods for preparing input and weight vectors, batch encoding/decoding, symmetric encryption, and parsing the final results back into a matrix format.
  • Testing and Build System Integration: New build rules and a dedicated test suite have been added to integrate the protocol into the existing system and ensure its correctness and performance.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a new HE SIMD batch matrix-matrix multiplication protocol, which is a significant feature. The implementation is well-structured and handles different input sizes with specific strategies, including a BSGS-style algorithm for efficiency. My review focuses on improving code quality by removing leftover debugging artifacts like print statements, cleaning up commented-out code, and eliminating unused variables and methods. These changes will enhance the readability and maintainability of the new protocol and its tests.

Comment on lines +86 to +95
// yacl::parallel_for(0, polys.size(), [&](int64_t bgn, int64_t end) {
// for (int64_t i = bgn; i < end; ++i) {
// seal::util::encrypt_zero_symmetric(secret_key, context,
// context.first_parms_id(), false,
// save_seed, out[i]);
// seal::util::multiply_add_plain_with_scaling_variant(
// polys[i], *context.first_context_data(),
// seal::util::RNSIter{out[i].data(), out[i].poly_modulus_degree()});
// }
// });
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This function contains a large block of commented-out code. If this code is no longer needed, it should be removed to improve readability and maintainability.

uint64_t baby_step = absl::bit_ceil(
static_cast<uint64_t>(std::sqrt(block_size * meta.dims[2] / (double)meta.dims[1])));
baby_step = std::min(baby_step, block_size);
std::cout << "baby_step: " << baby_step << std::endl;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This std::cout statement seems to be for debugging purposes and should be removed from the final code.

absl::Span<const uint64_t> ans_poly,
absl::Span<uint64_t> res_mat) const;

// Shape2D GetInShape() const { return in_shape_; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This commented-out function declaration should be removed to keep the header file clean.



private:
void NoiseFloodInplace(RLWECt &ct, const seal::SEALContext &context);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The private method NoiseFloodInplace is declared here but is not used within the class. This appears to be dead code and should be removed along with its implementation.

std::shared_ptr<SIMDBatchMMProt> simd_batchmm_prot_;

void SetUp() override {
std::cout << "setup" << std::endl;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This std::cout statement appears to be for debugging and should be removed from the test's SetUp method.

Comment on lines +54 to +57
// modulus_bits = {60, 30, 52, plain_bits};
} else {
// modulus_bits = {60, 45, 45, 58, plain_bits};
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block of commented-out code in the SetUp method should be removed to improve test readability.

return fmt::format("{}", p.param ? "NoiseFlood" : "Approx");
});

TEST_P(SIMDBatchMMTest, ) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The test case has an empty name. Please provide a descriptive name that indicates what is being tested, for example, Correctness.

Suggested change
TEST_P(SIMDBatchMMTest, ) {
TEST_P(SIMDBatchMMTest, Correctness) {

int64_t d0 = 24;
int64_t d1 = 2048;
int64_t d2 = 1408;
std::cout << "Testing Batch MatMat " << batch << "x" << d0 << "x" << d1 << " * " << d1 << "x" << d2 << std::endl;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This test contains several std::cout statements for debugging. Tests should be silent on success. Please remove this and other print statements throughout the test.

size_t num_row_blocks = CeilDiv(static_cast<uint64_t>(d1), block_size);
size_t num_col_blocks = CeilDiv(static_cast<uint64_t>(d2), block_size);
size_t simd_lane = simd_batchmm_prot_->SIMDLane();
// size_t row_size = simd_batchmm_prot_->SIMDLane() / 2;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This commented-out line should be removed to improve test code clarity.


namespace spu::mpc::cheetah {

class SIMDBatchMMProt {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should inherit EnableCPRNG here, rather than in the unittest.

};


static constexpr int kNoiseFloodBits = 40;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that the logic of noise flooding has not been utilized; it may need to be added.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that the cheetah_dot.h defines BatchDotOLE. You might consider testing its performance against your implementation.

return fmt::format("{}", p.param ? "NoiseFlood" : "Approx");
});

TEST_P(SIMDBatchMMTest, ) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Regarding how to elegantly generate test data and perform correctness testing, you can refer to cheetah_dot_test.cc.
  • For measuring elapsed time, you can use yacl::ElapsedTimer, and the usage can be referenced in src/libspu/mpc/utils/lowmc_test.cc.
  • If you need to track communication volume and round counts, you can directly retrieve link statistics from yacl; the usage can be referenced in src/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc.

int64_t d0 = 24;
int64_t d1 = 2048;
int64_t d2 = 1408;
std::cout << "Testing Batch MatMat " << batch << "x" << d0 << "x" << d1 << " * " << d1 << "x" << d2 << std::endl;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We prefer SPDLOG_INFO if you really need to print something, but you should make them as less as possible.

});

TEST_P(SIMDBatchMMTest, ) {
size_t batch = 64;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use the parameterized testing framework of gtest to elegantly test the correctness and performance of the implementation with various configurations.

You can define something like:

class CheetahDotTest
    : public ::testing::TestWithParam<std::tuple<FieldType, Shape3D>> {};

INSTANTIATE_TEST_SUITE_P(
    Cheetah, CheetahDotTest,
    testing::Combine(testing::Values(FieldType::FM64, FieldType::FM128),
                     testing::Values(Shape3D{8, 7, 5}, Shape3D{57, 30, 1},
                                     Shape3D{30, 57, 1}, Shape3D{18, 8, 41},
                                     Shape3D{500, 13, 25},
                                     Shape3D{1, 2048, 768},
                                     Shape3D{18, 768, 78})),
    [](const testing::TestParamInfo<CheetahDotTest::ParamType>& p) {
      return fmt::format("{}x{}x{}x{}", std::get<0>(std::get<1>(p.param)),
                         std::get<1>(std::get<1>(p.param)),
                         std::get<2>(std::get<1>(p.param)),
                         std::get<0>(p.param));
    });


SIMDBatchMMProt(SIMDBatchMMProt&&) = delete;

// Same as SIMDMulProt
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Unless there is a necessity for external calls, the APIs should be made private as much as possible.
  • You should refer to the API design in cheetah_dot.h, where the input provided to the upper-level APIs needs to be of the NdArrayRef type.

@deadlywing
Copy link
Contributor

Please sign the CLA, btw.

@deadlywing
Copy link
Contributor

I feel that your batch matmul also needs to define a State, which you can refer to in src/libspu/mpc/cheetah/state.h, where CheetahDotState is defined. In src/libspu/mpc/cheetah/arithmetic.cc, you can see that the matmul series of interfaces are basically implemented using CheetahDotState to achieve specific functionalities.

@Welement-zyf
Copy link
Author

I have read the CLA Document and I hereby sign the CLA

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants