Skip to content

Stop cloning additional_info into each ServerHandshaker #2574

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

10 changes: 5 additions & 5 deletions grpc_streaming_attestation/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use anyhow::Context;
use futures::{Stream, StreamExt};
use oak_remote_attestation::handshaker::{AttestationBehavior, Encryptor, ServerHandshaker};
use oak_utils::LogError;
use std::pin::Pin;
use std::{pin::Pin, sync::Arc};
use tonic::{Request, Response, Status, Streaming};

/// Handler for subsequent encrypted requests from the stream after the handshake is completed.
Expand Down Expand Up @@ -93,7 +93,7 @@ pub struct AttestationServer<F, L: LogError> {
/// Processes data from client requests and creates responses.
request_handler: F,
/// Configuration information to provide to the client for the attestation step.
additional_info: Vec<u8>,
additional_info: Arc<Vec<u8>>,
/// Error logging function that is required for logging attestation protocol errors.
/// Errors are only logged on server side and are not sent to clients.
error_logger: L,
Expand All @@ -114,7 +114,7 @@ where
Ok(Self {
tee_certificate,
request_handler,
additional_info,
additional_info: Arc::new(additional_info),
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we prefer a let declaration over instantiating the Arc in the struct instantiation?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks fine to me.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we prefer a let declaration over instantiating the Arc in the struct instantiation?

error_logger,
})
}
Expand All @@ -136,8 +136,8 @@ where
) -> Result<Response<Self::StreamStream>, Status> {
let tee_certificate = self.tee_certificate.clone();
let request_handler = self.request_handler.clone();
let additional_info = self.additional_info.clone();
let error_logger = self.error_logger.clone();
let additional_info = self.additional_info.clone();

let response_stream = async_stream::try_stream! {
let mut request_stream = request_stream.into_inner();
Expand All @@ -148,7 +148,7 @@ where
error_logger.log_error(&format!("Couldn't create self attestation behavior: {:?}", error));
Status::internal("")
})?,
additional_info,
additional_info
);
while !handshaker.is_completed() {
let incoming_message = request_stream.next()
Expand Down
9 changes: 6 additions & 3 deletions grpc_unary_attestation/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ use crate::{
use lru::LruCache;
use oak_remote_attestation::handshaker::{AttestationBehavior, Encryptor, ServerHandshaker};
use oak_utils::LogError;
use std::{convert::TryInto, sync::Mutex};
use std::{
convert::TryInto,
sync::{Arc, Mutex},
};
use tonic;

enum SessionState {
Expand All @@ -38,7 +41,7 @@ struct SessionTracker {
/// PEM encoded X.509 certificate that signs TEE firmware key.
tee_certificate: Vec<u8>,
/// Configuration information to provide to the client for the attestation step.
additional_info: Vec<u8>,
additional_info: Arc<Vec<u8>>,
known_sessions: LruCache<SessionId, SessionState>,
}

Expand All @@ -50,7 +53,7 @@ impl SessionTracker {
let known_sessions = LruCache::new(SESSIONS_CACHE_SIZE);
Self {
tee_certificate,
additional_info,
additional_info: Arc::new(additional_info),
known_sessions,
}
}
Expand Down
3 changes: 2 additions & 1 deletion oak_functions/client/rust/src/attestation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ pub fn into_server_identity_verifier(
config_verifier: ConfigurationVerifier,
) -> ServerIdentityVerifier {
let server_verifier = move |server_identity: ServerIdentity| -> anyhow::Result<()> {
let config = ConfigurationInfo::decode(server_identity.additional_info.as_ref())?;
let config =
ConfigurationInfo::decode(server_identity.additional_info.as_ref().as_slice())?;
// TODO(#2347): Check that ConfigurationInfo does not have additional/unknown fields.
config_verifier(config)?;
// TODO(#2316): Verify proof of inclusion in Rekor.
Expand Down
13 changes: 6 additions & 7 deletions remote_attestation/rust/src/handshaker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use crate::{
},
proto::{AttestationInfo, AttestationReport},
};
use alloc::{boxed::Box, vec, vec::Vec};
use alloc::{boxed::Box, sync::Arc, vec, vec::Vec};
use anyhow::{anyhow, Context};
use prost::Message;

Expand Down Expand Up @@ -331,13 +331,13 @@ pub struct ServerHandshaker {
transcript: Transcript,
/// Additional info about the server, including configuration information and proof of
/// inclusion in a verifiable log.
additional_info: Vec<u8>,
additional_info: Arc<Vec<u8>>,
}

impl ServerHandshaker {
/// Creates [`ServerHandshaker`] with `ServerHandshakerState::ExpectingClientIdentity`
/// state.
pub fn new(behavior: AttestationBehavior, additional_info: Vec<u8>) -> Self {
pub fn new(behavior: AttestationBehavior, additional_info: Arc<Vec<u8>>) -> Self {
Self {
behavior,
state: ServerHandshakerState::ExpectingClientHello,
Expand Down Expand Up @@ -448,9 +448,8 @@ impl ServerHandshaker {
.as_ref()
.context("Couldn't get TEE certificate")?;

let additional_info = self.additional_info.clone();
let attestation_info =
create_attestation_info(signer, additional_info.as_ref(), tee_certificate)
create_attestation_info(signer, self.additional_info.as_ref(), tee_certificate)
.context("Couldn't get attestation info")?;

let mut server_identity = ServerIdentity::new(
Expand All @@ -460,7 +459,7 @@ impl ServerHandshaker {
.public_key()
.context("Couldn't get singing public key")?,
attestation_info,
additional_info,
self.additional_info.clone(),
);

// Update current transcript.
Expand All @@ -487,7 +486,7 @@ impl ServerHandshaker {
// Attestation info.
vec![],
// Additional info.
vec![],
Arc::new(vec![]),
)
};

Expand Down
8 changes: 4 additions & 4 deletions remote_attestation/rust/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::crypto::{
KEY_AGREEMENT_ALGORITHM_KEY_LENGTH, NONCE_LENGTH, SIGNATURE_LENGTH,
SIGNING_ALGORITHM_KEY_LENGTH,
};
use alloc::vec::Vec;
use alloc::{sync::Arc, vec::Vec};
use anyhow::{anyhow, bail, Context};
use bytes::{Buf, BufMut};

Expand Down Expand Up @@ -124,7 +124,7 @@ pub struct ServerIdentity {
///
/// The server and the client must be able to agree on a canonical representation of the
/// content to be able to deterministically compute the hash of this field.
pub additional_info: Vec<u8>,
pub additional_info: Arc<Vec<u8>>,
}

/// Client identity message containing remote attestation information and a public key for
Expand Down Expand Up @@ -222,7 +222,7 @@ impl ServerIdentity {
random: [u8; REPLAY_PROTECTION_ARRAY_LENGTH],
signing_public_key: [u8; SIGNING_ALGORITHM_KEY_LENGTH],
attestation_info: Vec<u8>,
additional_info: Vec<u8>,
additional_info: Arc<Vec<u8>>,
) -> Self {
Self {
version: PROTOCOL_VERSION,
Expand Down Expand Up @@ -302,7 +302,7 @@ impl Deserializable for ServerIdentity {
let mut signing_public_key = [0u8; SIGNING_ALGORITHM_KEY_LENGTH];
input.copy_to_slice(&mut signing_public_key);
let attestation_info = get_vec(&mut input)?;
let additional_info = get_vec(&mut input)?;
let additional_info = Arc::new(get_vec(&mut input)?);

if input.has_remaining() {
bail!(
Expand Down
5 changes: 3 additions & 2 deletions remote_attestation/rust/src/tests/handshaker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::{
},
tests::message::INVALID_MESSAGE_HEADER,
};
use alloc::{boxed::Box, vec};
use alloc::{boxed::Box, sync::Arc, vec};
use assert_matches::assert_matches;

const TEE_MEASUREMENT: &str = "Test TEE measurement";
Expand All @@ -48,7 +48,8 @@ fn create_handshakers() -> (ClientHandshaker, ServerHandshaker) {
.unwrap();

let additional_info = br"Additional Info".to_vec();
let server_handshaker = ServerHandshaker::new(bidirectional_attestation, additional_info);
let server_handshaker =
ServerHandshaker::new(bidirectional_attestation, Arc::new(additional_info));

(client_handshaker, server_handshaker)
}
Expand Down
10 changes: 6 additions & 4 deletions remote_attestation/rust/src/tests/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::{
MAXIMUM_MESSAGE_SIZE, REPLAY_PROTECTION_ARRAY_LENGTH, SERVER_IDENTITY_HEADER,
},
};
use alloc::{vec, vec::Vec};
use alloc::{sync::Arc, vec, vec::Vec};
use anyhow::{anyhow, Context};
use assert_matches::assert_matches;
use quickcheck::{quickcheck, TestResult};
Expand Down Expand Up @@ -96,7 +96,7 @@ fn test_serialize_server_identity() {
transcript_signature: Vec<u8>,
signing_public_key: Vec<u8>,
attestation_info: Vec<u8>,
additional_info: Vec<u8>,
additional_info: Arc<Vec<u8>>,
) -> TestResult {
if ephemeral_public_key.len() > KEY_AGREEMENT_ALGORITHM_KEY_LENGTH
|| random.len() > REPLAY_PROTECTION_ARRAY_LENGTH
Expand Down Expand Up @@ -131,7 +131,9 @@ fn test_serialize_server_identity() {
assert!(result.is_ok());
TestResult::from_bool(result.unwrap())
}
quickcheck(property as fn(Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>) -> TestResult);
quickcheck(
property as fn(Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>, Arc<Vec<u8>>) -> TestResult,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this test do?

);
}

#[test]
Expand Down Expand Up @@ -202,7 +204,7 @@ fn test_deserialize_message() {
default_array(),
default_array(),
vec![],
vec![],
Arc::new(vec![]),
);
let deserialized_server_identity = deserialize_message(&server_identity.serialize().unwrap());
assert_matches!(deserialized_server_identity, Ok(_));
Expand Down