Skip to content

Commit 36c698c

Browse files
author
jul-sh
authored
Stop cloning additional_info into each ServerHandshaker (#2574)
* Do not clone additional info into each ServerHandshaker. This is important, as with unary request the runtime will keep a cache of ServerHandshaker instances * Pass a reference to handshaker steps, only cloning in the steps that req it * obey cargo clippy: use slices over Vecs * Revert previous approach of passing reference per method call * Use reference counting for cheaper clones * Update unary attestation * Update server_verifier * explicitly get reference for clarity * Store additional_info vec in SessionTracker to remove excess cloning
1 parent cc7a6b3 commit 36c698c

File tree

7 files changed

+32
-26
lines changed

7 files changed

+32
-26
lines changed

grpc_streaming_attestation/src/server.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use anyhow::Context;
2424
use futures::{Stream, StreamExt};
2525
use oak_remote_attestation::handshaker::{AttestationBehavior, Encryptor, ServerHandshaker};
2626
use oak_utils::LogError;
27-
use std::pin::Pin;
27+
use std::{pin::Pin, sync::Arc};
2828
use tonic::{Request, Response, Status, Streaming};
2929

3030
/// Handler for subsequent encrypted requests from the stream after the handshake is completed.
@@ -93,7 +93,7 @@ pub struct AttestationServer<F, L: LogError> {
9393
/// Processes data from client requests and creates responses.
9494
request_handler: F,
9595
/// Configuration information to provide to the client for the attestation step.
96-
additional_info: Vec<u8>,
96+
additional_info: Arc<Vec<u8>>,
9797
/// Error logging function that is required for logging attestation protocol errors.
9898
/// Errors are only logged on server side and are not sent to clients.
9999
error_logger: L,
@@ -114,7 +114,7 @@ where
114114
Ok(Self {
115115
tee_certificate,
116116
request_handler,
117-
additional_info,
117+
additional_info: Arc::new(additional_info),
118118
error_logger,
119119
})
120120
}
@@ -136,8 +136,8 @@ where
136136
) -> Result<Response<Self::StreamStream>, Status> {
137137
let tee_certificate = self.tee_certificate.clone();
138138
let request_handler = self.request_handler.clone();
139-
let additional_info = self.additional_info.clone();
140139
let error_logger = self.error_logger.clone();
140+
let additional_info = self.additional_info.clone();
141141

142142
let response_stream = async_stream::try_stream! {
143143
let mut request_stream = request_stream.into_inner();
@@ -148,7 +148,7 @@ where
148148
error_logger.log_error(&format!("Couldn't create self attestation behavior: {:?}", error));
149149
Status::internal("")
150150
})?,
151-
additional_info,
151+
additional_info
152152
);
153153
while !handshaker.is_completed() {
154154
let incoming_message = request_stream.next()

grpc_unary_attestation/src/server.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ use crate::{
2424
use lru::LruCache;
2525
use oak_remote_attestation::handshaker::{AttestationBehavior, Encryptor, ServerHandshaker};
2626
use oak_utils::LogError;
27-
use std::{convert::TryInto, sync::Mutex};
27+
use std::{
28+
convert::TryInto,
29+
sync::{Arc, Mutex},
30+
};
2831
use tonic;
2932

3033
enum SessionState {
@@ -38,7 +41,7 @@ struct SessionTracker {
3841
/// PEM encoded X.509 certificate that signs TEE firmware key.
3942
tee_certificate: Vec<u8>,
4043
/// Configuration information to provide to the client for the attestation step.
41-
additional_info: Vec<u8>,
44+
additional_info: Arc<Vec<u8>>,
4245
known_sessions: LruCache<SessionId, SessionState>,
4346
}
4447

@@ -50,7 +53,7 @@ impl SessionTracker {
5053
let known_sessions = LruCache::new(SESSIONS_CACHE_SIZE);
5154
Self {
5255
tee_certificate,
53-
additional_info,
56+
additional_info: Arc::new(additional_info),
5457
known_sessions,
5558
}
5659
}

oak_functions/client/rust/src/attestation.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ pub fn into_server_identity_verifier(
3030
config_verifier: ConfigurationVerifier,
3131
) -> ServerIdentityVerifier {
3232
let server_verifier = move |server_identity: ServerIdentity| -> anyhow::Result<()> {
33-
let config = ConfigurationInfo::decode(server_identity.additional_info.as_ref())?;
33+
let config =
34+
ConfigurationInfo::decode(server_identity.additional_info.as_ref().as_slice())?;
3435
// TODO(#2347): Check that ConfigurationInfo does not have additional/unknown fields.
3536
config_verifier(config)?;
3637
// TODO(#2316): Verify proof of inclusion in Rekor.

remote_attestation/rust/src/handshaker.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ use crate::{
3434
},
3535
proto::{AttestationInfo, AttestationReport},
3636
};
37-
use alloc::{boxed::Box, vec, vec::Vec};
37+
use alloc::{boxed::Box, sync::Arc, vec, vec::Vec};
3838
use anyhow::{anyhow, Context};
3939
use prost::Message;
4040

@@ -331,13 +331,13 @@ pub struct ServerHandshaker {
331331
transcript: Transcript,
332332
/// Additional info about the server, including configuration information and proof of
333333
/// inclusion in a verifiable log.
334-
additional_info: Vec<u8>,
334+
additional_info: Arc<Vec<u8>>,
335335
}
336336

337337
impl ServerHandshaker {
338338
/// Creates [`ServerHandshaker`] with `ServerHandshakerState::ExpectingClientIdentity`
339339
/// state.
340-
pub fn new(behavior: AttestationBehavior, additional_info: Vec<u8>) -> Self {
340+
pub fn new(behavior: AttestationBehavior, additional_info: Arc<Vec<u8>>) -> Self {
341341
Self {
342342
behavior,
343343
state: ServerHandshakerState::ExpectingClientHello,
@@ -448,9 +448,8 @@ impl ServerHandshaker {
448448
.as_ref()
449449
.context("Couldn't get TEE certificate")?;
450450

451-
let additional_info = self.additional_info.clone();
452451
let attestation_info =
453-
create_attestation_info(signer, additional_info.as_ref(), tee_certificate)
452+
create_attestation_info(signer, self.additional_info.as_ref(), tee_certificate)
454453
.context("Couldn't get attestation info")?;
455454

456455
let mut server_identity = ServerIdentity::new(
@@ -460,7 +459,7 @@ impl ServerHandshaker {
460459
.public_key()
461460
.context("Couldn't get singing public key")?,
462461
attestation_info,
463-
additional_info,
462+
self.additional_info.clone(),
464463
);
465464

466465
// Update current transcript.
@@ -487,7 +486,7 @@ impl ServerHandshaker {
487486
// Attestation info.
488487
vec![],
489488
// Additional info.
490-
vec![],
489+
Arc::new(vec![]),
491490
)
492491
};
493492

remote_attestation/rust/src/message.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use crate::crypto::{
2525
KEY_AGREEMENT_ALGORITHM_KEY_LENGTH, NONCE_LENGTH, SIGNATURE_LENGTH,
2626
SIGNING_ALGORITHM_KEY_LENGTH,
2727
};
28-
use alloc::vec::Vec;
28+
use alloc::{sync::Arc, vec::Vec};
2929
use anyhow::{anyhow, bail, Context};
3030
use bytes::{Buf, BufMut};
3131

@@ -124,7 +124,7 @@ pub struct ServerIdentity {
124124
///
125125
/// The server and the client must be able to agree on a canonical representation of the
126126
/// content to be able to deterministically compute the hash of this field.
127-
pub additional_info: Vec<u8>,
127+
pub additional_info: Arc<Vec<u8>>,
128128
}
129129

130130
/// Client identity message containing remote attestation information and a public key for
@@ -222,7 +222,7 @@ impl ServerIdentity {
222222
random: [u8; REPLAY_PROTECTION_ARRAY_LENGTH],
223223
signing_public_key: [u8; SIGNING_ALGORITHM_KEY_LENGTH],
224224
attestation_info: Vec<u8>,
225-
additional_info: Vec<u8>,
225+
additional_info: Arc<Vec<u8>>,
226226
) -> Self {
227227
Self {
228228
version: PROTOCOL_VERSION,
@@ -302,7 +302,7 @@ impl Deserializable for ServerIdentity {
302302
let mut signing_public_key = [0u8; SIGNING_ALGORITHM_KEY_LENGTH];
303303
input.copy_to_slice(&mut signing_public_key);
304304
let attestation_info = get_vec(&mut input)?;
305-
let additional_info = get_vec(&mut input)?;
305+
let additional_info = Arc::new(get_vec(&mut input)?);
306306

307307
if input.has_remaining() {
308308
bail!(

remote_attestation/rust/src/tests/handshaker.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use crate::{
2222
},
2323
tests::message::INVALID_MESSAGE_HEADER,
2424
};
25-
use alloc::{boxed::Box, vec};
25+
use alloc::{boxed::Box, sync::Arc, vec};
2626
use assert_matches::assert_matches;
2727

2828
const TEE_MEASUREMENT: &str = "Test TEE measurement";
@@ -48,7 +48,8 @@ fn create_handshakers() -> (ClientHandshaker, ServerHandshaker) {
4848
.unwrap();
4949

5050
let additional_info = br"Additional Info".to_vec();
51-
let server_handshaker = ServerHandshaker::new(bidirectional_attestation, additional_info);
51+
let server_handshaker =
52+
ServerHandshaker::new(bidirectional_attestation, Arc::new(additional_info));
5253

5354
(client_handshaker, server_handshaker)
5455
}

remote_attestation/rust/src/tests/message.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use crate::{
2525
MAXIMUM_MESSAGE_SIZE, REPLAY_PROTECTION_ARRAY_LENGTH, SERVER_IDENTITY_HEADER,
2626
},
2727
};
28-
use alloc::{vec, vec::Vec};
28+
use alloc::{sync::Arc, vec, vec::Vec};
2929
use anyhow::{anyhow, Context};
3030
use assert_matches::assert_matches;
3131
use quickcheck::{quickcheck, TestResult};
@@ -96,7 +96,7 @@ fn test_serialize_server_identity() {
9696
transcript_signature: Vec<u8>,
9797
signing_public_key: Vec<u8>,
9898
attestation_info: Vec<u8>,
99-
additional_info: Vec<u8>,
99+
additional_info: Arc<Vec<u8>>,
100100
) -> TestResult {
101101
if ephemeral_public_key.len() > KEY_AGREEMENT_ALGORITHM_KEY_LENGTH
102102
|| random.len() > REPLAY_PROTECTION_ARRAY_LENGTH
@@ -131,7 +131,9 @@ fn test_serialize_server_identity() {
131131
assert!(result.is_ok());
132132
TestResult::from_bool(result.unwrap())
133133
}
134-
quickcheck(property as fn(Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>) -> TestResult);
134+
quickcheck(
135+
property as fn(Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>, Arc<Vec<u8>>) -> TestResult,
136+
);
135137
}
136138

137139
#[test]
@@ -202,7 +204,7 @@ fn test_deserialize_message() {
202204
default_array(),
203205
default_array(),
204206
vec![],
205-
vec![],
207+
Arc::new(vec![]),
206208
);
207209
let deserialized_server_identity = deserialize_message(&server_identity.serialize().unwrap());
208210
assert_matches!(deserialized_server_identity, Ok(_));

0 commit comments

Comments
 (0)