Skip to content

Commit 98d9af8

Browse files
tholopcopybara-github
authored andcommitted
Add C++ wrapper around the testing decryptor.
PiperOrigin-RevId: 850079342
1 parent 6b8a895 commit 98d9af8

File tree

8 files changed

+435
-5
lines changed

8 files changed

+435
-5
lines changed

willow/src/api/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ rust_library(
3333
"@protobuf//rust:protobuf",
3434
"//shell_wrapper:status",
3535
"//willow/proto/willow:aggregation_config_rust_proto",
36+
"//willow/src/shell:single_thread_hkdf",
3637
"//willow/src/traits:proto_serialization_traits",
3738
],
3839
)

willow/src/api/aggregation_config.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,19 @@ impl ToProto for AggregationConfig {
8282
}
8383
}
8484

85+
impl AggregationConfig {
86+
/// Computes context bytes by hashing the session ID in the config.
87+
pub fn compute_context_bytes(&self) -> Result<Vec<u8>, StatusError> {
88+
let context_seed = single_thread_hkdf::compute_hkdf(
89+
self.session_id.as_bytes(),
90+
b"",
91+
b"AggregationConfig.context_string",
92+
single_thread_hkdf::seed_length(),
93+
)?;
94+
Ok(context_seed.as_bytes().to_vec())
95+
}
96+
}
97+
8598
#[cfg(test)]
8699
mod tests {
87100
use crate::AggregationConfig;

willow/src/input_encoding/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ load("@rules_cc//cc:cc_test.bzl", "cc_test")
1818
package(
1919
default_applicable_licenses = [
2020
],
21+
default_visibility = ["//visibility:public"],
2122
)
2223

2324
cc_library(

willow/src/testing_utils/BUILD

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
load("@cxx.rs//tools/bazel:rust_cxx_bridge.bzl", "rust_cxx_bridge")
16+
load("@rules_cc//cc:cc_library.bzl", "cc_library")
17+
load("@rules_cc//cc:cc_test.bzl", "cc_test")
1518
load("@rules_rust//rust:defs.bzl", "rust_library", "rust_test")
1619

1720
package(
1821
default_applicable_licenses = [
1922
],
20-
default_visibility = ["//visibility:public"],
23+
default_visibility = ["//:internal"],
2124
)
2225

2326
# PRNG
@@ -71,7 +74,6 @@ rust_library(
7174
"//shell_wrapper:status",
7275
"//willow/src/api:aggregation_config",
7376
"//willow/src/shell:kahe_shell",
74-
"//willow/src/shell:single_thread_hkdf",
7577
"//willow/src/shell:vahe_shell",
7678
"//willow/src/traits:ahe_traits",
7779
"//willow/src/traits:kahe_traits",
@@ -90,16 +92,28 @@ rust_test(
9092
],
9193
)
9294

95+
rust_cxx_bridge(
96+
name = "shell_testing_decryptor_cxx",
97+
src = "shell_testing_decryptor.rs",
98+
deps = [
99+
":shell_testing_decryptor",
100+
],
101+
)
102+
93103
rust_library(
94104
name = "shell_testing_decryptor",
95-
testonly = 1,
96105
srcs = [
97106
"shell_testing_decryptor.rs",
98107
],
99108
deps = [
100-
":shell_testing_parameters",
109+
"@protobuf//rust:protobuf",
110+
"@cxx.rs//:cxx",
111+
"//shell_wrapper:shell_types_cc",
101112
"//shell_wrapper:status",
113+
"//willow/proto/willow:aggregation_config_rust_proto",
114+
"//willow/proto/willow:messages_rust_proto",
102115
"//willow/src/api:aggregation_config",
116+
"//willow/src/shell:ahe_shell",
103117
"//willow/src/shell:kahe_shell",
104118
"//willow/src/shell:parameters_shell",
105119
"//willow/src/shell:single_thread_hkdf",
@@ -108,6 +122,35 @@ rust_library(
108122
"//willow/src/traits:kahe_traits",
109123
"//willow/src/traits:messages",
110124
"//willow/src/traits:prng_traits",
125+
"//willow/src/traits:proto_serialization_traits",
111126
"//willow/src/traits:vahe_traits",
112127
],
113128
)
129+
130+
cc_library(
131+
name = "shell_testing_decryptor_cc",
132+
srcs = ["shell_testing_decryptor.cc"],
133+
hdrs = ["shell_testing_decryptor.h"],
134+
deps = [
135+
":shell_testing_decryptor_cxx",
136+
"@abseil-cpp//absl/memory",
137+
"@abseil-cpp//absl/status",
138+
"@abseil-cpp//absl/status:statusor",
139+
"//shell_wrapper:shell_types_cc",
140+
"//willow/proto/shell:shell_ciphertexts_cc_proto",
141+
"//willow/proto/willow:aggregation_config_cc_proto",
142+
"//willow/proto/willow:messages_cc_proto",
143+
"//willow/src/input_encoding:codec",
144+
],
145+
)
146+
147+
cc_test(
148+
name = "shell_testing_decryptor_test",
149+
srcs = ["shell_testing_decryptor_test.cc"],
150+
deps = [
151+
":shell_testing_decryptor_cc",
152+
"@googletest//:gtest_main",
153+
"//shell_wrapper:status_matchers",
154+
"//willow/proto/willow:aggregation_config_cc_proto",
155+
],
156+
)
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include "willow/src/testing_utils/shell_testing_decryptor.h"
18+
19+
#include <cstdint>
20+
#include <memory>
21+
#include <string>
22+
#include <utility>
23+
#include <vector>
24+
25+
#include "absl/memory/memory.h"
26+
#include "absl/status/status.h"
27+
#include "absl/status/statusor.h"
28+
#include "shell_wrapper/shell_types.h"
29+
#include "willow/src/input_encoding/codec.h"
30+
#include "willow/src/testing_utils/shell_testing_decryptor.rs.h"
31+
32+
namespace secure_aggregation {
33+
34+
ShellTestingDecryptor::ShellTestingDecryptor(
35+
rust::Box<ShellTestingDecryptorRust> decryptor)
36+
: decryptor_(std::move(decryptor)) {}
37+
38+
absl::StatusOr<std::unique_ptr<ShellTestingDecryptor>>
39+
ShellTestingDecryptor::Create(
40+
const willow::AggregationConfigProto& aggregation_config) {
41+
std::string aggregation_config_proto = aggregation_config.SerializeAsString();
42+
rust::Slice<const uint8_t> slice = ToRustSlice(aggregation_config_proto);
43+
44+
secure_aggregation::ShellTestingDecryptorRust* out;
45+
std::unique_ptr<std::string> status_message;
46+
int status_code =
47+
create_shell_testing_decryptor(slice, &out, &status_message);
48+
49+
if (status_code != 0) {
50+
return absl::Status(absl::StatusCode(status_code), *status_message);
51+
}
52+
// Use `into_box` to avoid linker issues arising from rust::Box::from_raw.
53+
return absl::WrapUnique(new ShellTestingDecryptor(decryptor_into_box(out)));
54+
}
55+
56+
absl::StatusOr<willow::ShellAhePublicKey>
57+
ShellTestingDecryptor::GeneratePublicKey() {
58+
rust::Vec<uint8_t> out;
59+
std::unique_ptr<std::string> status_message;
60+
int status_code = decryptor_->generate_public_key(&out, &status_message);
61+
62+
if (status_code != 0) {
63+
return absl::Status(absl::StatusCode(status_code), *status_message);
64+
}
65+
66+
willow::ShellAhePublicKey public_key;
67+
if (!public_key.ParseFromArray(out.data(), out.size())) {
68+
return absl::InternalError("Failed to parse ShellAhePublicKey");
69+
}
70+
return public_key;
71+
}
72+
73+
absl::StatusOr<willow::EncodedData> ShellTestingDecryptor::Decrypt(
74+
const willow::ClientMessage& message) {
75+
std::string contribution_proto = message.SerializeAsString();
76+
rust::Slice<const uint8_t> slice(
77+
reinterpret_cast<const uint8_t*>(contribution_proto.data()),
78+
contribution_proto.size());
79+
80+
rust::Vec<secure_aggregation::EncodedDataEntry> rust_flat_data;
81+
std::unique_ptr<std::string> status_message;
82+
int status_code =
83+
decryptor_->decrypt(slice, &rust_flat_data, &status_message);
84+
85+
if (status_code != 0) {
86+
return absl::Status(absl::StatusCode(status_code), *status_message);
87+
}
88+
89+
willow::EncodedData encoded_data;
90+
for (const auto& rust_entry : rust_flat_data) {
91+
std::string key(rust_entry.key);
92+
std::vector<int64_t> val;
93+
val.reserve(rust_entry.values.size());
94+
for (auto v : rust_entry.values) {
95+
val.push_back(static_cast<int64_t>(v));
96+
}
97+
encoded_data[std::move(key)] = std::move(val);
98+
}
99+
100+
return encoded_data;
101+
}
102+
103+
} // namespace secure_aggregation
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#ifndef SECURE_AGGREGATION_WILLOW_SRC_TESTING_UTILS_SHELL_TESTING_DECRYPTOR_H_
18+
#define SECURE_AGGREGATION_WILLOW_SRC_TESTING_UTILS_SHELL_TESTING_DECRYPTOR_H_
19+
20+
#include <memory>
21+
22+
#include "absl/status/statusor.h"
23+
#include "willow/proto/shell/ciphertexts.pb.h"
24+
#include "willow/proto/willow/aggregation_config.pb.h"
25+
#include "willow/proto/willow/messages.pb.h"
26+
#include "willow/src/input_encoding/codec.h"
27+
#include "willow/src/testing_utils/shell_testing_decryptor.rs.h"
28+
29+
namespace secure_aggregation {
30+
31+
// Basic implementation of a single decryptor that uses Shell operations
32+
// directly. Useful for testing Shell clients, by checking that encrypted
33+
// messages can be decrypted properly.
34+
class ShellTestingDecryptor {
35+
public:
36+
// Creates a new ShellTestingDecryptor from the given config, hashing the
37+
// session ID from the config to seed KAHE and AHE public parameters.
38+
static absl::StatusOr<std::unique_ptr<ShellTestingDecryptor>> Create(
39+
const willow::AggregationConfigProto& aggregation_config);
40+
41+
// Generates a new AHE public key, and stores the corresponding secret key.
42+
absl::StatusOr<willow::ShellAhePublicKey> GeneratePublicKey();
43+
44+
// Decrypts a client message using the stored AHE secret key, by recovering
45+
// the KAHE key from the AHE ciphertext and then decrypting the KAHE
46+
// ciphertext. Does not verify the client proof contained in the message.
47+
absl::StatusOr<willow::EncodedData> Decrypt(
48+
const willow::ClientMessage& message);
49+
50+
private:
51+
explicit ShellTestingDecryptor(
52+
rust::Box<ShellTestingDecryptorRust> decryptor);
53+
54+
rust::Box<ShellTestingDecryptorRust> decryptor_;
55+
};
56+
57+
} // namespace secure_aggregation
58+
59+
#endif // SECURE_AGGREGATION_WILLOW_SRC_TESTING_UTILS_SHELL_TESTING_DECRYPTOR_H_

0 commit comments

Comments
 (0)