1818#include < memory>
1919#include < string>
2020#include < utility>
21+ #include < vector>
2122
2223#include " absl/memory/memory.h"
2324#include " absl/status/status.h"
2627#include " willow/proto/willow/aggregation_config.pb.h"
2728#include " willow/proto/willow/server_accumulator.pb.h"
2829#include " willow/src/api/server_accumulator.rs.h"
30+ #include " willow/src/input_encoding/codec.h"
2931
3032namespace secure_aggregation {
33+ namespace willow {
3134
32- absl::StatusOr<std::unique_ptr<WillowShellServerAccumulator>>
33- WillowShellServerAccumulator::Create (
34- const willow::AggregationConfigProto& aggregation_config) {
35+ absl::StatusOr<std::unique_ptr<ServerAccumulator>> ServerAccumulator::Create (
36+ const AggregationConfigProto& aggregation_config) {
3537 secure_aggregation::ServerAccumulator* out;
3638 std::unique_ptr<std::string> status_message;
3739 int status_code =
@@ -41,12 +43,11 @@ WillowShellServerAccumulator::Create(
4143 if (status_code != 0 ) {
4244 return absl::Status (absl::StatusCode (status_code), *status_message);
4345 }
44- return absl::WrapUnique (new WillowShellServerAccumulator (IntoBox (out)));
46+ return absl::WrapUnique (new ServerAccumulator (IntoBox (out)));
4547}
4648
47- absl::StatusOr<std::unique_ptr<WillowShellServerAccumulator>>
48- WillowShellServerAccumulator::CreateFromSerializedState (
49- std::string serialized_state) {
49+ absl::StatusOr<std::unique_ptr<ServerAccumulator>>
50+ ServerAccumulator::CreateFromSerializedState (std::string serialized_state) {
5051 secure_aggregation::ServerAccumulator* out;
5152 std::unique_ptr<std::string> status_message;
5253 int status_code = secure_aggregation::NewServerAccumulatorFromSerializedState (
@@ -55,17 +56,17 @@ WillowShellServerAccumulator::CreateFromSerializedState(
5556 if (status_code != 0 ) {
5657 return absl::Status (absl::StatusCode (status_code), *status_message);
5758 }
58- return absl::WrapUnique (new WillowShellServerAccumulator (IntoBox (out)));
59+ return absl::WrapUnique (new ServerAccumulator (IntoBox (out)));
5960}
6061
61- absl::Status WillowShellServerAccumulator ::ProcessClientMessages (
62- willow:: ClientMessageRange client_messages) {
62+ absl::Status ServerAccumulator ::ProcessClientMessages (
63+ ClientMessageRange client_messages) {
6364 auto serialized_client_messages = client_messages.SerializeAsString ();
6465 client_messages.Clear ();
6566 return ProcessClientMessages (std::move (serialized_client_messages));
6667}
6768
68- absl::Status WillowShellServerAccumulator ::ProcessClientMessages (
69+ absl::Status ServerAccumulator ::ProcessClientMessages (
6970 std::string serialized_client_messages) {
7071 std::unique_ptr<std::string> status_message;
7172 int status_code = accumulator_->ProcessClientMessages (
@@ -77,8 +78,8 @@ absl::Status WillowShellServerAccumulator::ProcessClientMessages(
7778 return absl::OkStatus ();
7879}
7980
80- absl::Status WillowShellServerAccumulator ::Merge (
81- std::unique_ptr<WillowShellServerAccumulator > other) {
81+ absl::Status ServerAccumulator ::Merge (
82+ std::unique_ptr<ServerAccumulator > other) {
8283 std::unique_ptr<std::string> status_message;
8384 int status_code =
8485 accumulator_->Merge (std::move (other->accumulator_ ), &status_message);
@@ -88,7 +89,7 @@ absl::Status WillowShellServerAccumulator::Merge(
8889 return absl::OkStatus ();
8990}
9091
91- absl::StatusOr<std::string> WillowShellServerAccumulator ::ToSerializedState () {
92+ absl::StatusOr<std::string> ServerAccumulator ::ToSerializedState () {
9293 rust::Vec<uint8_t > serialized_state;
9394 std::unique_ptr<std::string> status_message;
9495 int status_code =
@@ -100,4 +101,70 @@ absl::StatusOr<std::string> WillowShellServerAccumulator::ToSerializedState() {
100101 serialized_state.size ());
101102}
102103
104+ absl::StatusOr<FinalizedAccumulatorResult> ServerAccumulator::Finalize () && {
105+ // Finalize accumulator in Rust and store the serialized results.
106+ rust::Vec<uint8_t > decryption_request;
107+ rust::Vec<uint8_t > final_result_decryptor_state;
108+ std::unique_ptr<std::string> status_message;
109+ int status_code = secure_aggregation::FinalizeServerAccumulator (
110+ std::move (accumulator_), &decryption_request,
111+ &final_result_decryptor_state, &status_message);
112+ if (status_code != 0 ) {
113+ return absl::Status (absl::StatusCode (status_code), *status_message);
114+ }
115+
116+ // Pack the two serialized results into a single proto.
117+ FinalizedAccumulatorResult result_proto;
118+ result_proto.set_decryption_request (
119+ std::string (reinterpret_cast <const char *>(decryption_request.data ()),
120+ decryption_request.size ()));
121+ result_proto.set_final_result_decryptor_state (std::string (
122+ reinterpret_cast <const char *>(final_result_decryptor_state.data ()),
123+ final_result_decryptor_state.size ()));
124+
125+ return result_proto;
126+ }
127+
128+ absl::StatusOr<std::unique_ptr<FinalResultDecryptor>>
129+ FinalResultDecryptor::CreateFromSerialized (
130+ std::string final_result_decryptor_state) {
131+ secure_aggregation::FinalResultDecryptor* out;
132+ std::unique_ptr<std::string> status_message;
133+ int status_code =
134+ secure_aggregation::CreateFinalResultDecryptorFromSerialized (
135+ std::make_unique<std::string>(
136+ std::move (final_result_decryptor_state)),
137+ &out, &status_message);
138+ if (status_code != 0 ) {
139+ return absl::Status (absl::StatusCode (status_code), *status_message);
140+ }
141+ return absl::WrapUnique (new FinalResultDecryptor (
142+ secure_aggregation::FinalResultDecryptorIntoBox (out)));
143+ }
144+
145+ absl::StatusOr<EncodedData> FinalResultDecryptor::Decrypt (
146+ std::string serialized_partial_decryption_response) {
147+ rust::Vec<EncodedDataEntry> out;
148+ std::unique_ptr<std::string> status_message;
149+ int status_code = aggregated_ciphertexts_->Decrypt (
150+ std::make_unique<std::string>(
151+ std::move (serialized_partial_decryption_response)),
152+ &out, &status_message);
153+ if (status_code != 0 ) {
154+ return absl::Status (absl::StatusCode (status_code), *status_message);
155+ }
156+ EncodedData encoded_data;
157+ for (const auto & rust_entry : out) {
158+ std::string key (rust_entry.key );
159+ std::vector<int64_t > val;
160+ val.reserve (rust_entry.values .size ());
161+ for (auto v : rust_entry.values ) {
162+ val.push_back (static_cast <int64_t >(v));
163+ }
164+ encoded_data[std::move (key)] = std::move (val);
165+ }
166+ return encoded_data;
167+ }
168+
169+ } // namespace willow
103170} // namespace secure_aggregation
0 commit comments