Skip to content

Commit d83bdcb

Browse files
tholopcopybara-github
authored andcommitted
Add Finalize and Decrypt functions to ServerAccumulator, expand ShellTestingDecryptor to handle partial decryption requests.
PiperOrigin-RevId: 858564313
1 parent 48e28b1 commit d83bdcb

13 files changed

+557
-81
lines changed

willow/proto/willow/messages.proto

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,15 @@ message VerifierStateProto {
6565
bytes nonce_lower_bound = 2;
6666
bytes nonce_upper_bound = 3;
6767
}
68+
69+
// Result of finalizing an accumulator.
70+
message FinalizedAccumulatorResult {
71+
// Serialized decryption request to include in a DecryptRequest to be sent to
72+
// the decryptor service.
73+
bytes decryption_request = 1;
74+
75+
// Serialized state for creating a final result decryptor, which will handle
76+
// the response from the decryptor service and produce a plaintext aggregation
77+
// result.
78+
bytes final_result_decryptor_state = 2;
79+
}

willow/proto/willow/server_accumulator.proto

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ message ServerAccumulatorState {
2626
ServerStateProto server_state = 1;
2727
AggregationConfigProto aggregation_config = 3;
2828
// We have one verifier state per range of nonces processed by this
29-
// accumulator. States gat merged when adjacent ranges are processed.
29+
// accumulator. States get merged when adjacent ranges are processed.
3030
repeated VerifierStateProto verifier_states = 2;
3131
// The ranges of nonces processed by this accumulator. In the same order as
3232
// the corresponding verifier states.
@@ -43,3 +43,9 @@ message NonceRange {
4343
bytes start = 1; // Inclusive.
4444
bytes end = 2; // Exclusive.
4545
}
46+
47+
// State for creating a FinalResultDecryptor.
48+
message FinalResultDecryptorState {
49+
ServerStateProto server_state = 1;
50+
AggregationConfigProto aggregation_config = 2;
51+
}

willow/src/api/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ cc_library(
6161
"@abseil-cpp//absl/strings",
6262
"@cxx.rs//:core",
6363
"//willow/proto/willow:aggregation_config_cc_proto",
64+
"//willow/proto/willow:messages_cc_proto",
6465
"//willow/proto/willow:server_accumulator_cc_proto",
66+
"//willow/src/input_encoding:codec",
6567
],
6668
)
6769

@@ -77,6 +79,7 @@ cc_test(
7779
"@abseil-cpp//absl/status:statusor",
7880
"//shell_wrapper:status_matchers",
7981
"//willow/proto/willow:aggregation_config_cc_proto",
82+
"//willow/proto/willow:messages_cc_proto",
8083
"//willow/proto/willow:server_accumulator_cc_proto",
8184
"//willow/src/input_encoding:codec",
8285
"//willow/src/testing_utils:shell_testing_decryptor_cc",
@@ -102,6 +105,7 @@ rust_library(
102105
"//shell_wrapper:status",
103106
"//willow/proto/willow:aggregation_config_rust_proto",
104107
"//willow/proto/willow:server_accumulator_rust_proto",
108+
"//willow/proto/willow:messages_rust_proto",
105109
"//willow/src/shell:kahe_shell",
106110
"//willow/src/shell:parameters_shell",
107111
"//willow/src/shell:vahe_shell",

willow/src/api/client_test.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ namespace willow {
3535
namespace {
3636

3737
using secure_aggregation::secagg_internal::StatusIs;
38+
using secure_aggregation::testing::ShellTestingDecryptor;
3839
using ::testing::ElementsAre;
40+
using ::testing::ElementsAreArray;
3941
using ::testing::Pair;
4042
using ::testing::UnorderedElementsAre;
4143

@@ -108,7 +110,7 @@ TEST(WillowShellClientTest, InitializeAndGenerateContribution) {
108110
for (const auto& [name, values] : encoded_data) {
109111
EXPECT_TRUE(decrypted_encoded_data.contains(name));
110112
const auto& decrypted_values = decrypted_encoded_data[name];
111-
EXPECT_THAT(decrypted_values, testing::ElementsAreArray(values));
113+
EXPECT_THAT(decrypted_values, ElementsAreArray(values));
112114
}
113115

114116
// Decode decrypted data.

willow/src/api/server_accumulator.cc

Lines changed: 81 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <memory>
1919
#include <string>
2020
#include <utility>
21+
#include <vector>
2122

2223
#include "absl/memory/memory.h"
2324
#include "absl/status/status.h"
@@ -26,12 +27,13 @@
2627
#include "willow/proto/willow/aggregation_config.pb.h"
2728
#include "willow/proto/willow/server_accumulator.pb.h"
2829
#include "willow/src/api/server_accumulator.rs.h"
30+
#include "willow/src/input_encoding/codec.h"
2931

3032
namespace secure_aggregation {
33+
namespace willow {
3134

32-
absl::StatusOr<std::unique_ptr<WillowShellServerAccumulator>>
33-
WillowShellServerAccumulator::Create(
34-
const willow::AggregationConfigProto& aggregation_config) {
35+
absl::StatusOr<std::unique_ptr<ServerAccumulator>> ServerAccumulator::Create(
36+
const AggregationConfigProto& aggregation_config) {
3537
secure_aggregation::ServerAccumulator* out;
3638
std::unique_ptr<std::string> status_message;
3739
int status_code =
@@ -41,12 +43,11 @@ WillowShellServerAccumulator::Create(
4143
if (status_code != 0) {
4244
return absl::Status(absl::StatusCode(status_code), *status_message);
4345
}
44-
return absl::WrapUnique(new WillowShellServerAccumulator(IntoBox(out)));
46+
return absl::WrapUnique(new ServerAccumulator(IntoBox(out)));
4547
}
4648

47-
absl::StatusOr<std::unique_ptr<WillowShellServerAccumulator>>
48-
WillowShellServerAccumulator::CreateFromSerializedState(
49-
std::string serialized_state) {
49+
absl::StatusOr<std::unique_ptr<ServerAccumulator>>
50+
ServerAccumulator::CreateFromSerializedState(std::string serialized_state) {
5051
secure_aggregation::ServerAccumulator* out;
5152
std::unique_ptr<std::string> status_message;
5253
int status_code = secure_aggregation::NewServerAccumulatorFromSerializedState(
@@ -55,17 +56,17 @@ WillowShellServerAccumulator::CreateFromSerializedState(
5556
if (status_code != 0) {
5657
return absl::Status(absl::StatusCode(status_code), *status_message);
5758
}
58-
return absl::WrapUnique(new WillowShellServerAccumulator(IntoBox(out)));
59+
return absl::WrapUnique(new ServerAccumulator(IntoBox(out)));
5960
}
6061

61-
absl::Status WillowShellServerAccumulator::ProcessClientMessages(
62-
willow::ClientMessageRange client_messages) {
62+
absl::Status ServerAccumulator::ProcessClientMessages(
63+
ClientMessageRange client_messages) {
6364
auto serialized_client_messages = client_messages.SerializeAsString();
6465
client_messages.Clear();
6566
return ProcessClientMessages(std::move(serialized_client_messages));
6667
}
6768

68-
absl::Status WillowShellServerAccumulator::ProcessClientMessages(
69+
absl::Status ServerAccumulator::ProcessClientMessages(
6970
std::string serialized_client_messages) {
7071
std::unique_ptr<std::string> status_message;
7172
int status_code = accumulator_->ProcessClientMessages(
@@ -77,8 +78,8 @@ absl::Status WillowShellServerAccumulator::ProcessClientMessages(
7778
return absl::OkStatus();
7879
}
7980

80-
absl::Status WillowShellServerAccumulator::Merge(
81-
std::unique_ptr<WillowShellServerAccumulator> other) {
81+
absl::Status ServerAccumulator::Merge(
82+
std::unique_ptr<ServerAccumulator> other) {
8283
std::unique_ptr<std::string> status_message;
8384
int status_code =
8485
accumulator_->Merge(std::move(other->accumulator_), &status_message);
@@ -88,7 +89,7 @@ absl::Status WillowShellServerAccumulator::Merge(
8889
return absl::OkStatus();
8990
}
9091

91-
absl::StatusOr<std::string> WillowShellServerAccumulator::ToSerializedState() {
92+
absl::StatusOr<std::string> ServerAccumulator::ToSerializedState() {
9293
rust::Vec<uint8_t> serialized_state;
9394
std::unique_ptr<std::string> status_message;
9495
int status_code =
@@ -100,4 +101,70 @@ absl::StatusOr<std::string> WillowShellServerAccumulator::ToSerializedState() {
100101
serialized_state.size());
101102
}
102103

104+
absl::StatusOr<FinalizedAccumulatorResult> ServerAccumulator::Finalize() && {
105+
// Finalize accumulator in Rust and store the serialized results.
106+
rust::Vec<uint8_t> decryption_request;
107+
rust::Vec<uint8_t> final_result_decryptor_state;
108+
std::unique_ptr<std::string> status_message;
109+
int status_code = secure_aggregation::FinalizeServerAccumulator(
110+
std::move(accumulator_), &decryption_request,
111+
&final_result_decryptor_state, &status_message);
112+
if (status_code != 0) {
113+
return absl::Status(absl::StatusCode(status_code), *status_message);
114+
}
115+
116+
// Pack the two serialized results into a single proto.
117+
FinalizedAccumulatorResult result_proto;
118+
result_proto.set_decryption_request(
119+
std::string(reinterpret_cast<const char*>(decryption_request.data()),
120+
decryption_request.size()));
121+
result_proto.set_final_result_decryptor_state(std::string(
122+
reinterpret_cast<const char*>(final_result_decryptor_state.data()),
123+
final_result_decryptor_state.size()));
124+
125+
return result_proto;
126+
}
127+
128+
absl::StatusOr<std::unique_ptr<FinalResultDecryptor>>
129+
FinalResultDecryptor::CreateFromSerialized(
130+
std::string final_result_decryptor_state) {
131+
secure_aggregation::FinalResultDecryptor* out;
132+
std::unique_ptr<std::string> status_message;
133+
int status_code =
134+
secure_aggregation::CreateFinalResultDecryptorFromSerialized(
135+
std::make_unique<std::string>(
136+
std::move(final_result_decryptor_state)),
137+
&out, &status_message);
138+
if (status_code != 0) {
139+
return absl::Status(absl::StatusCode(status_code), *status_message);
140+
}
141+
return absl::WrapUnique(new FinalResultDecryptor(
142+
secure_aggregation::FinalResultDecryptorIntoBox(out)));
143+
}
144+
145+
absl::StatusOr<EncodedData> FinalResultDecryptor::Decrypt(
146+
std::string serialized_partial_decryption_response) {
147+
rust::Vec<EncodedDataEntry> out;
148+
std::unique_ptr<std::string> status_message;
149+
int status_code = aggregated_ciphertexts_->Decrypt(
150+
std::make_unique<std::string>(
151+
std::move(serialized_partial_decryption_response)),
152+
&out, &status_message);
153+
if (status_code != 0) {
154+
return absl::Status(absl::StatusCode(status_code), *status_message);
155+
}
156+
EncodedData encoded_data;
157+
for (const auto& rust_entry : out) {
158+
std::string key(rust_entry.key);
159+
std::vector<int64_t> val;
160+
val.reserve(rust_entry.values.size());
161+
for (auto v : rust_entry.values) {
162+
val.push_back(static_cast<int64_t>(v));
163+
}
164+
encoded_data[std::move(key)] = std::move(val);
165+
}
166+
return encoded_data;
167+
}
168+
169+
} // namespace willow
103170
} // namespace secure_aggregation

willow/src/api/server_accumulator.h

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,53 +23,87 @@
2323

2424
#include "absl/status/status.h"
2525
#include "absl/status/statusor.h"
26-
#include "absl/strings/string_view.h"
2726
#include "include/cxx.h"
2827
#include "willow/proto/willow/aggregation_config.pb.h"
28+
#include "willow/proto/willow/messages.pb.h"
2929
#include "willow/proto/willow/server_accumulator.pb.h"
3030
#include "willow/src/api/server_accumulator.rs.h"
31+
#include "willow/src/input_encoding/codec.h"
3132

3233
namespace secure_aggregation {
34+
namespace willow {
35+
36+
// Holds the relevant state from a finalized accumulation, and can decrypt the
37+
// final result using the response from the decryptor service. Only works with
38+
// a single decryptor, or if the response is already the sum of all partial
39+
// decryptions, since it attempts to decrypt after receiving only a single
40+
// partial decryption response.
41+
class FinalResultDecryptor {
42+
public:
43+
// Creates a new final result decryptor from the given serialized
44+
// state, likely coming from a FinalizedAccumulatorResult.
45+
static absl::StatusOr<std::unique_ptr<FinalResultDecryptor>>
46+
CreateFromSerialized(std::string final_result_decryptor_state);
47+
48+
// Decrypts final result using the given partial decryption
49+
// response.
50+
absl::StatusOr<EncodedData> Decrypt(
51+
std::string serialized_partial_decryption_response);
52+
53+
private:
54+
explicit FinalResultDecryptor(
55+
rust::Box<secure_aggregation::FinalResultDecryptor>
56+
aggregated_ciphertexts)
57+
: aggregated_ciphertexts_(std::move(aggregated_ciphertexts)) {}
58+
59+
rust::Box<secure_aggregation::FinalResultDecryptor> aggregated_ciphertexts_;
60+
};
3361

3462
// Implements an accumulator class intended to be used by a batch processing
3563
// system. Combines both the server and the verifier functionality of willow_v1,
3664
// using SHELL for the underlying cryptography.
37-
class WillowShellServerAccumulator {
65+
class ServerAccumulator {
3866
public:
3967
// Creates a new accumulator with the given aggregation_config and empty
4068
// state.
41-
static absl::StatusOr<std::unique_ptr<WillowShellServerAccumulator>> Create(
42-
const willow::AggregationConfigProto& aggregation_config);
69+
static absl::StatusOr<std::unique_ptr<ServerAccumulator>> Create(
70+
const AggregationConfigProto& aggregation_config);
4371

4472
// Creates a new accumulator from the given serialized state, which must
4573
// correspond to a serialized ServerAccumulatorState proto.
46-
static absl::StatusOr<std::unique_ptr<WillowShellServerAccumulator>>
74+
static absl::StatusOr<std::unique_ptr<ServerAccumulator>>
4775
CreateFromSerializedState(std::string serialized_state);
4876

4977
// Processes a list of client messages. If an invalid message is encountered,
5078
// an error is logged and processing continues.
51-
absl::Status ProcessClientMessages(
52-
willow::ClientMessageRange client_messages);
79+
absl::Status ProcessClientMessages(ClientMessageRange client_messages);
5380

5481
// Processes a list of client messages, given as a serialized
5582
// ClientMessageList proto.
5683
absl::Status ProcessClientMessages(std::string serialized_client_messages);
5784

5885
// Merges the state of `other` into the current accumulator.
59-
absl::Status Merge(std::unique_ptr<WillowShellServerAccumulator> other);
86+
absl::Status Merge(std::unique_ptr<ServerAccumulator> other);
6087

6188
// Converts the current state of the accumulator to a serialized
6289
// ServerAccumulatorState proto.
6390
absl::StatusOr<std::string> ToSerializedState();
6491

92+
// Finalizes the accumulator and returns a proto that holds the serialized
93+
// decryption request (to be sent to the decryptor service) and the
94+
// serialized decryptor state (to create a FinalResultDecryptor). This
95+
// consumes the accumulator.
96+
absl::StatusOr<FinalizedAccumulatorResult> Finalize() &&;
97+
6598
private:
66-
explicit WillowShellServerAccumulator(
99+
explicit ServerAccumulator(
67100
rust::Box<secure_aggregation::ServerAccumulator> accumulator)
68101
: accumulator_(std::move(accumulator)) {}
69102

70103
rust::Box<secure_aggregation::ServerAccumulator> accumulator_;
71104
};
72105

106+
} // namespace willow
73107
} // namespace secure_aggregation
74108

75109
#endif // SECURE_AGGREGATION_WILLOW_SRC_API_SERVER_ACCUMULATOR_H_

0 commit comments

Comments
 (0)