Skip to content

Commit 17e282f

Browse files
tholopcopybara-github
authored andcommitted
Implement C++ client.
PiperOrigin-RevId: 842367438
1 parent 7af55b5 commit 17e282f

File tree

8 files changed

+571
-8
lines changed

8 files changed

+571
-8
lines changed

willow/src/api/BUILD

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ cc_library(
5959
"@abseil-cpp//absl/status",
6060
"@abseil-cpp//absl/status:statusor",
6161
"@abseil-cpp//absl/strings",
62-
"@cxx.rs//:cxx",
6362
"@cxx.rs//:core",
6463
"//willow/proto/willow:aggregation_config_cc_proto",
6564
"//willow/proto/willow:server_accumulator_cc_proto",
@@ -112,3 +111,74 @@ rust_library(
112111
"//willow/src/willow_v1:willow_v1_verifier",
113112
],
114113
)
114+
115+
rust_library(
116+
name = "client",
117+
srcs = ["client.rs"],
118+
deps = [
119+
":aggregation_config",
120+
"@protobuf//rust:protobuf",
121+
"@cxx.rs//:cxx",
122+
"//shell_wrapper:status",
123+
"//willow/proto/shell:shell_ciphertexts_rust_proto",
124+
"//willow/proto/willow:aggregation_config_rust_proto",
125+
"//willow/proto/willow:messages_rust_proto",
126+
"//willow/src/shell:ahe_shell",
127+
"//willow/src/shell:kahe_shell",
128+
"//willow/src/shell:parameters_shell",
129+
"//willow/src/shell:single_thread_hkdf",
130+
"//willow/src/shell:vahe_shell",
131+
"//willow/src/traits:ahe_traits",
132+
"//willow/src/traits:client_traits",
133+
"//willow/src/traits:kahe_traits",
134+
"//willow/src/traits:messages",
135+
"//willow/src/traits:prng_traits",
136+
"//willow/src/traits:proto_serialization_traits",
137+
"//willow/src/traits:vahe_traits",
138+
"//willow/src/willow_v1:willow_v1_client",
139+
"//willow/src/willow_v1:willow_v1_server",
140+
],
141+
)
142+
143+
rust_cxx_bridge(
144+
name = "client_cxx",
145+
src = "client.rs",
146+
deps = [
147+
":client",
148+
],
149+
)
150+
151+
cc_library(
152+
name = "client_cc",
153+
srcs = ["client.cc"],
154+
hdrs = ["client.h"],
155+
deps = [
156+
":client_cxx",
157+
"@abseil-cpp//absl/status",
158+
"@abseil-cpp//absl/status:statusor",
159+
"@cxx.rs//:core",
160+
"//shell_wrapper:shell_types_cc",
161+
"//willow/proto/shell:shell_ciphertexts_cc_proto",
162+
"//willow/proto/willow:aggregation_config_cc_proto",
163+
"//willow/proto/willow:messages_cc_proto",
164+
"//willow/proto/willow:server_accumulator_cc_proto",
165+
"//willow/src/input_encoding:codec",
166+
],
167+
)
168+
169+
cc_test(
170+
name = "client_test_cc",
171+
srcs = ["client_test.cc"],
172+
deps = [
173+
":client_cc",
174+
":client_cxx",
175+
"@googletest//:gtest_main",
176+
"@abseil-cpp//absl/status",
177+
"@cxx.rs//:core",
178+
"//shell_wrapper:status_matchers",
179+
"//willow/proto/willow:aggregation_config_cc_proto",
180+
"//willow/proto/willow:input_spec_cc_proto",
181+
"//willow/src/input_encoding:codec",
182+
"//willow/src/testing_utils:shell_testing_decryptor_cc",
183+
],
184+
)

willow/src/api/client.cc

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "willow/src/api/client.h"
16+
17+
#include <cstdint>
18+
#include <memory>
19+
#include <string>
20+
#include <utility>
21+
#include <vector>
22+
23+
#include "absl/status/status.h"
24+
#include "absl/status/statusor.h"
25+
#include "include/cxx.h"
26+
#include "shell_wrapper/shell_types.h"
27+
#include "willow/proto/shell/ciphertexts.pb.h"
28+
#include "willow/proto/willow/aggregation_config.pb.h"
29+
#include "willow/proto/willow/server_accumulator.pb.h"
30+
#include "willow/src/api/client.rs.h"
31+
#include "willow/src/input_encoding/codec.h"
32+
33+
namespace secure_aggregation {
34+
35+
absl::StatusOr<willow::ClientMessage> GenerateClientContribution(
36+
const willow::AggregationConfigProto& aggregation_config,
37+
const willow::EncodedData& encoded_data,
38+
const willow::ShellAhePublicKey& public_key, const std::string& nonce) {
39+
// Initialize client.
40+
std::string config_str = aggregation_config.SerializeAsString();
41+
auto config_ptr = std::make_unique<std::string>(std::move(config_str));
42+
secure_aggregation::WillowShellClient* client_ptr = nullptr;
43+
std::unique_ptr<std::string> status_message;
44+
int status_code =
45+
initialize_client(std::move(config_ptr), &client_ptr, &status_message);
46+
if (status_code != 0) {
47+
return absl::Status(absl::StatusCode(status_code), *status_message);
48+
}
49+
// Use `into_box` to avoid linker issues arising from rust::Box::from_raw.
50+
auto client = client_into_box(client_ptr);
51+
52+
// Prepare arguments.
53+
std::vector<DataEntryView> entries;
54+
entries.reserve(encoded_data.size());
55+
for (const auto& [key, values] : encoded_data) {
56+
rust::Slice<const uint8_t> key_slice = ToRustSlice(key);
57+
// values.data() is currently a pointer to an int64_t array and not
58+
// uint64_t, so this performs an implicit cast (wrapping around if
59+
// necessary). Not using a ToRustSlice variant because this is a temporary
60+
// solution until the codec is updated to use uint64_t.
61+
rust::Slice<const uint64_t> values_slice(
62+
reinterpret_cast<const uint64_t*>(values.data()), values.size());
63+
entries.push_back(DataEntryView{key_slice, values_slice});
64+
}
65+
rust::Slice<const DataEntryView> entries_slice(entries.data(),
66+
entries.size());
67+
68+
std::string key_str = public_key.SerializeAsString();
69+
auto key_ptr = std::make_unique<std::string>(std::move(key_str));
70+
rust::Slice<const uint8_t> nonce_slice = ToRustSlice(nonce);
71+
rust::Vec<uint8_t> result_bytes;
72+
std::unique_ptr<std::string> status_message_gen;
73+
74+
// Encrypt data.
75+
int status_code_gen =
76+
generate_contribution(client, entries_slice, std::move(key_ptr),
77+
nonce_slice, &result_bytes, &status_message_gen);
78+
if (status_code_gen != 0) {
79+
return absl::Status(absl::StatusCode(status_code_gen), *status_message_gen);
80+
}
81+
82+
// Parse string to ClientMessage.
83+
willow::ClientMessage client_message;
84+
std::string result_str(reinterpret_cast<const char*>(result_bytes.data()),
85+
result_bytes.size());
86+
if (!client_message.ParseFromString(result_str)) {
87+
return absl::InternalError(
88+
"Failed to parse ClientMessage from Rust output.");
89+
}
90+
91+
return client_message;
92+
}
93+
94+
} // namespace secure_aggregation

willow/src/api/client.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#ifndef SECURE_AGGREGATION_WILLOW_SRC_API_CLIENT_H_
18+
#define SECURE_AGGREGATION_WILLOW_SRC_API_CLIENT_H_
19+
20+
#include <string>
21+
22+
#include "absl/status/statusor.h"
23+
#include "willow/proto/shell/ciphertexts.pb.h"
24+
#include "willow/proto/willow/aggregation_config.pb.h"
25+
#include "willow/proto/willow/messages.pb.h"
26+
#include "willow/proto/willow/server_accumulator.pb.h"
27+
#include "willow/src/input_encoding/codec.h"
28+
29+
namespace secure_aggregation {
30+
31+
// Generates a client contribution by encrypting the encoded data with the
32+
// provided AHE public key.
33+
absl::StatusOr<willow::ClientMessage> GenerateClientContribution(
34+
const willow::AggregationConfigProto& aggregation_config,
35+
const willow::EncodedData& encoded_data,
36+
const willow::ShellAhePublicKey& public_key, const std::string& nonce);
37+
38+
} // namespace secure_aggregation
39+
40+
#endif // SECURE_AGGREGATION_WILLOW_SRC_API_CLIENT_H_

willow/src/api/client.rs

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
use aggregation_config::AggregationConfig;
18+
use aggregation_config_rust_proto::AggregationConfigProto;
19+
use ahe_shell::PublicKey;
20+
use ahe_traits::AheBase;
21+
use client_traits::SecureAggregationClient;
22+
use kahe_shell::ShellKahe;
23+
use kahe_traits::KaheBase;
24+
use parameters_shell::create_shell_configs;
25+
use prng_traits::SecurePrng;
26+
use proto_serialization_traits::{FromProto, ToProto};
27+
use protobuf::prelude::*;
28+
use shell_ciphertexts_rust_proto::ShellAhePublicKey;
29+
use single_thread_hkdf::SingleThreadHkdfPrng;
30+
use status::ffi::FfiStatus;
31+
use status::StatusError;
32+
use std::collections::HashMap;
33+
use vahe_shell::ShellVahe;
34+
use willow_v1_client::WillowV1Client;
35+
36+
/// CXX bridge to call Rust client code from C++.
37+
///
38+
/// SAFETY: all functions in this module are only called from the wrapping C++ library,
39+
/// ensuring that output pointers are correctly wrapped by a rust::Box, and that pointer arguments
40+
/// are not null.
41+
#[cxx::bridge(namespace = "secure_aggregation")]
42+
pub mod ffi {
43+
/// One entry in the plaintext map. Obtained by taking slices of the metric names and values
44+
/// owned by the C++ EncodedData object (hence the lifetime).
45+
struct DataEntryView<'a> {
46+
key: &'a [u8],
47+
values: &'a [u64],
48+
}
49+
50+
extern "Rust" {
51+
// cxx: types used as extern Rust types are required to be defined by the same crate that
52+
// contains the bridge using them
53+
type WillowShellClient;
54+
55+
pub unsafe fn initialize_client(
56+
config: UniquePtr<CxxString>,
57+
out: *mut *mut WillowShellClient,
58+
out_status_message: *mut UniquePtr<CxxString>,
59+
) -> i32;
60+
61+
unsafe fn client_into_box(ptr: *mut WillowShellClient) -> Box<WillowShellClient>;
62+
63+
unsafe fn generate_contribution(
64+
client: &mut Box<WillowShellClient>,
65+
data: &[DataEntryView],
66+
key: UniquePtr<CxxString>,
67+
nonce: &[u8],
68+
out: *mut Vec<u8>,
69+
out_status_message: *mut UniquePtr<CxxString>,
70+
) -> i32;
71+
}
72+
}
73+
74+
pub struct WillowShellClient(WillowV1Client<ShellKahe, ShellVahe>);
75+
76+
impl WillowShellClient {
77+
fn new_from_serialized_config(
78+
config: cxx::UniquePtr<cxx::CxxString>,
79+
) -> Result<Self, StatusError> {
80+
let aggregation_config_proto =
81+
AggregationConfigProto::parse(config.as_bytes()).map_err(|e| {
82+
status::internal(format!("Failed to parse AggregationConfigProto: {}", e))
83+
})?;
84+
let aggregation_config = AggregationConfig::from_proto(aggregation_config_proto, ())?;
85+
let (kahe_config, ahe_config) = create_shell_configs(&aggregation_config)?;
86+
let context_bytes = aggregation_config.compute_context_bytes()?;
87+
let kahe = ShellKahe::new(kahe_config, &context_bytes)?;
88+
let vahe = ShellVahe::new(ahe_config, &context_bytes)?;
89+
let client_seed = SingleThreadHkdfPrng::generate_seed()?;
90+
let prng = SingleThreadHkdfPrng::create(&client_seed)?;
91+
let client = WillowV1Client { kahe, vahe, prng };
92+
Ok(WillowShellClient(client))
93+
}
94+
95+
fn generate_contribution(
96+
&mut self,
97+
data: &[ffi::DataEntryView],
98+
public_key: cxx::UniquePtr<cxx::CxxString>,
99+
nonce: &[u8],
100+
) -> Result<Vec<u8>, StatusError> {
101+
let mut plaintext_slice: HashMap<&str, &[u64]> = HashMap::new();
102+
for entry in data {
103+
let key = str::from_utf8(entry.key)
104+
.map_err(|e| status::internal(format!("Failed to parse key as UTF-8: {}", e)))?;
105+
plaintext_slice.insert(key, entry.values);
106+
}
107+
let public_key_proto = ShellAhePublicKey::parse(public_key.as_bytes())
108+
.map_err(|e| status::internal(format!("Failed to parse ShellAhePublicKey: {}", e)))?;
109+
let public_key_rust = PublicKey::from_proto(public_key_proto, &self.0.vahe)?;
110+
let message = self.0.create_client_message(&plaintext_slice, &public_key_rust, nonce)?;
111+
Ok(message
112+
.to_proto(&self.0)?
113+
.serialize()
114+
.map_err(|e| status::internal(format!("Failed to serialize ClientMessage: {}", e)))?)
115+
}
116+
}
117+
118+
/// SAFETY: `out` and `out_status_message` must not be null.
119+
unsafe fn initialize_client(
120+
config: cxx::UniquePtr<cxx::CxxString>,
121+
out: *mut *mut WillowShellClient,
122+
out_status_message: *mut cxx::UniquePtr<cxx::CxxString>,
123+
) -> i32 {
124+
match WillowShellClient::new_from_serialized_config(config) {
125+
Ok(client) => {
126+
*out = Box::into_raw(Box::new(client));
127+
0
128+
}
129+
Err(status_error) => {
130+
let ffi_status: FfiStatus = status_error.into();
131+
*out_status_message = ffi_status.message;
132+
ffi_status.code
133+
}
134+
}
135+
}
136+
137+
/// Converts a raw pointer to a Box. Ideally we would use `rust::Box::from_raw`
138+
/// (https://cxx.rs/binding/box.html) directly from C++, but that causes linker errors.
139+
///
140+
/// SAFETY: `ptr` must have been created by `Box::into_raw`, as in `initialize_client`.
141+
unsafe fn client_into_box(ptr: *mut WillowShellClient) -> Box<WillowShellClient> {
142+
Box::from_raw(ptr)
143+
}
144+
145+
/// SAFETY: `out` and `out_status_message` must not be null.
146+
unsafe fn generate_contribution(
147+
client: &mut Box<WillowShellClient>,
148+
data: &[ffi::DataEntryView],
149+
public_key: cxx::UniquePtr<cxx::CxxString>,
150+
nonce: &[u8],
151+
out: *mut Vec<u8>,
152+
out_status_message: *mut cxx::UniquePtr<cxx::CxxString>,
153+
) -> i32 {
154+
match client.generate_contribution(data, public_key, nonce) {
155+
Ok(contribution) => {
156+
*out = contribution;
157+
0
158+
}
159+
Err(status_error) => {
160+
let ffi_status: FfiStatus = status_error.into();
161+
*out_status_message = ffi_status.message;
162+
ffi_status.code
163+
}
164+
}
165+
}

0 commit comments

Comments
 (0)