Skip to content
5 changes: 2 additions & 3 deletions grpc_attestation/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,7 @@ where
.map_err(|error| {
error_logger.log_error(&format!("Couldn't create self attestation behavior: {:?}", error));
Status::internal("")
})?,
additional_info,
})?
);
while !handshaker.is_completed() {
let incoming_message = request_stream.next()
Expand All @@ -167,7 +166,7 @@ where
})?;

let outgoing_message = handshaker
.next_step(&incoming_message.body)
.next_step(&incoming_message.body, additional_info.clone())
.map_err(|error| {
error_logger.log_error(&format!("Couldn't process handshake message: {:?}", error));
Status::aborted("")
Expand Down
31 changes: 18 additions & 13 deletions remote_attestation/rust/src/handshaker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,40 +329,45 @@ pub struct ServerHandshaker {
/// Collection of previously sent and received messages.
/// Signed transcript is sent in messages to prevent replay attacks.
transcript: Transcript,
/// Additional info about the server, including configuration information and proof of
/// inclusion in a verifiable log.
additional_info: Vec<u8>,
}

impl ServerHandshaker {
/// Creates [`ServerHandshaker`] with `ServerHandshakerState::ExpectingClientIdentity`
/// state.
pub fn new(behavior: AttestationBehavior, additional_info: Vec<u8>) -> Self {
pub fn new(behavior: AttestationBehavior) -> Self {
Self {
behavior,
state: ServerHandshakerState::ExpectingClientHello,
transcript: Transcript::new(),
additional_info,
}
}

/// Processes incoming `message` and returns a serialized remote attestation message.
/// If [`None`] is returned, then no messages should be sent out to the client.
pub fn next_step(&mut self, message: &[u8]) -> anyhow::Result<Option<Vec<u8>>> {
self.next_step_util(message).map_err(|error| {
self.state = ServerHandshakerState::Aborted;
error
})
pub fn next_step(
&mut self,
message: &[u8],
additional_info: Vec<u8>,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I may be missing something, but what determines what to pass here, and at exactly which step? AFAICT the additional info is only used at one particular step, so we still need to pass something else for all the other steps? Looking at the tests below it seems that we do keep passing the same value, but that is only actually used in one.

Also the idea of this method is that we just pipe in messages from the underlying transport, and it should move along the various steps in the state machine (and eventually produce the encryptor), so adding this argument here seems to complicate the API.

I'll let @conradgrobler and @ipetr0v weigh in too.

Another thing to consider is that, in practice, in most protocols that have some space for additional info, only a fixed amount of additional data can be passed in, usually 256 bits, usually the result of a hashing function (rather than the actual data). So maybe another approach worth considering in order to reduce the memory usage would be to make this a fixed size array, but keep it in the constructor. Would that work?

Copy link
Author

@jul-sh jul-sh Mar 2, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, only used some for some steps, passing it always is still not optimal.

what determines what to pass here

additional_info maps to this config. Afaik the config stays unchanged after the runtime is initiated. Hence it'd always be a clone of the same vec.

Which leads me to the suggestion I raised in the chat, and which be the optimal approach for memory: We create a single instance of the additional_info vec in the loader, and only pass down references (with a static lifetime) [Ref http://chat/room/AAAAp6czer4/-EYB2OqgJQs].

Is that safe? Or do we expect the config (which is what additional_info contains atm) to ever be changed/dropped after the runtime is started?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's safe to assume it does not change, but I'm not sure that implies it can be 'static -- maybe it can be coerced to a static value at runtime by "leaking" it? but I'm not sure that's idiomatic.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also see my other comment below, maybe the fix you are looking for is just a one line change there?

Copy link
Collaborator

@conradgrobler conradgrobler Mar 2, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about using an Arc<Vec<u8>> rather than a static reference? Cloning it should be cheap, and it would remove the need to pass the additional info into every step.

Copy link
Author

@jul-sh jul-sh Mar 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If passing it to the method is not idiomatic, we can make it a shared reference (std::sync::Arc) indeed. That would probably be easier than declaring the lifetimes.

The only blocker is that as of #2556 the oak_remote_attestation crate is no_std compatible. This means we cannot use std::sync::Arc without breaking that.

@conradgrobler Do we plan to support shared references in no_std in the future?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use alloc::sync::Arc. We will always have an allocator, even in no_std environements.

Copy link
Author

@jul-sh jul-sh Mar 21, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the pointer, will use that. Also realized that for no_std we'll probs be moving the gRPC communication logic out of the trusted runtime anyways.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will we? I thought all this stuff will remain in the trusted binary, since it drives the remote attestation protocol.

Anyways, I still don't think we should change this method to accept the additional info param. Especially in the non-std, blocking model, this method should look like a call or invoke function from bytes to bytes (with some extra wrappers), so that we can compose it more nicely with the rest of the building blocks of the runtime. And at some point it should itself also accept another callback that it invokes with the decrypted data from remote attestation, which itself would follow the same paradigm.

Copy link
Author

@jul-sh jul-sh Mar 21, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will we? I thought all this stuff will remain in the trusted binary, since it drives the remote attestation protocol.

The gRPC protocol implementation only, remote attestation logic will ofc still be performed within the trusted binary. Most of the changes in this PR do concern the "remote attestation", but tonic's (multi-threaded) gRPC server implementation is what drives the original need for cloning here. The async trait service implementation is why the compiler cannot infer lifetimes for the remote attestation code used within.

Anyways, I still don't think we should change this method to accept the additional info param

yes, we're on the same page. As covered in the prev discussion, we'll count references for now. :)

) -> anyhow::Result<Option<Vec<u8>>> {
self.next_step_util(message, additional_info)
.map_err(|error| {
self.state = ServerHandshakerState::Aborted;
error
})
}

fn next_step_util(&mut self, message: &[u8]) -> anyhow::Result<Option<Vec<u8>>> {
fn next_step_util(
&mut self,
message: &[u8],
additional_info: Vec<u8>,
) -> anyhow::Result<Option<Vec<u8>>> {
let deserialized_message =
deserialize_message(message).context("Couldn't deserialize message")?;
match deserialized_message {
MessageWrapper::ClientHello(client_hello) => match &self.state {
ServerHandshakerState::ExpectingClientHello => {
let server_identity = self
.process_client_hello(client_hello)
.process_client_hello(client_hello, additional_info)
.context("Couldn't process client hello message")?;
let serialized_server_identity = server_identity
.serialize()
Expand Down Expand Up @@ -429,6 +434,7 @@ impl ServerHandshaker {
fn process_client_hello(
&mut self,
client_hello: ClientHello,
additional_info: Vec<u8>,
) -> anyhow::Result<ServerIdentity> {
// Create server identity message.
let key_negotiator = KeyNegotiator::create(KeyNegotiatorType::Server)
Expand All @@ -448,7 +454,6 @@ impl ServerHandshaker {
.as_ref()
.context("Couldn't get TEE certificate")?;

let additional_info = self.additional_info.clone();
let attestation_info =
create_attestation_info(signer, additional_info.as_ref(), tee_certificate)
.context("Couldn't get attestation info")?;
Expand Down
34 changes: 20 additions & 14 deletions remote_attestation/rust/src/tests/handshaker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use assert_matches::assert_matches;

const TEE_MEASUREMENT: &str = "Test TEE measurement";
const DATA: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
const ADDITIONAL_INFO: [u8; 15] = *br"Additional Info";

fn create_handshakers() -> (ClientHandshaker, ServerHandshaker) {
let bidirectional_attestation =
Expand All @@ -47,8 +48,7 @@ fn create_handshakers() -> (ClientHandshaker, ServerHandshaker) {
AttestationBehavior::create_bidirectional_attestation(&[], TEE_MEASUREMENT.as_bytes())
.unwrap();

let additional_info = br"Additional Info".to_vec();
let server_handshaker = ServerHandshaker::new(bidirectional_attestation, additional_info);
let server_handshaker = ServerHandshaker::new(bidirectional_attestation);

(client_handshaker, server_handshaker)
}
Expand All @@ -71,7 +71,7 @@ fn test_handshake() {
.expect("Couldn't create client hello message");

let server_identity = server_handshaker
.next_step(&client_hello)
.next_step(&client_hello, ADDITIONAL_INFO.into())
.expect("Couldn't process client hello message")
.expect("Empty server identity message");

Expand All @@ -82,7 +82,7 @@ fn test_handshake() {
assert!(client_handshaker.is_completed());

let result = server_handshaker
.next_step(&client_identity)
.next_step(&client_identity, ADDITIONAL_INFO.into())
.expect("Couldn't process client identity message");
assert_matches!(result, None);
assert!(server_handshaker.is_completed());
Expand Down Expand Up @@ -122,7 +122,7 @@ fn test_invalid_message_after_initialization() {
let result = client_handshaker.create_client_hello();
assert_matches!(result, Err(_));

let result = server_handshaker.next_step(&invalid_message);
let result = server_handshaker.next_step(&invalid_message, ADDITIONAL_INFO.into());
assert_matches!(result, Err(_));
assert!(server_handshaker.is_aborted());
}
Expand All @@ -137,8 +137,11 @@ fn test_invalid_message_after_hello() {
assert_matches!(result, Err(_));
assert!(client_handshaker.is_aborted());

let server_identity = server_handshaker.next_step(&client_hello).unwrap().unwrap();
let result = server_handshaker.next_step(&invalid_message);
let server_identity = server_handshaker
.next_step(&client_hello, ADDITIONAL_INFO.into())
.unwrap()
.unwrap();
let result = server_handshaker.next_step(&invalid_message, ADDITIONAL_INFO.into());
assert_matches!(result, Err(_));
assert!(server_handshaker.is_aborted());

Expand All @@ -152,7 +155,10 @@ fn test_invalid_message_after_identities() {
let invalid_message = vec![INVALID_MESSAGE_HEADER];

let client_hello = client_handshaker.create_client_hello().unwrap();
let server_identity = server_handshaker.next_step(&client_hello).unwrap().unwrap();
let server_identity = server_handshaker
.next_step(&client_hello, ADDITIONAL_INFO.into())
.unwrap()
.unwrap();
let client_identity = client_handshaker
.next_step(&server_identity)
.unwrap()
Expand All @@ -162,11 +168,11 @@ fn test_invalid_message_after_identities() {
assert_matches!(result, Err(_));
assert!(client_handshaker.is_aborted());

let result = server_handshaker.next_step(&invalid_message);
let result = server_handshaker.next_step(&invalid_message, ADDITIONAL_INFO.into());
assert_matches!(result, Err(_));
assert!(server_handshaker.is_aborted());

let result = server_handshaker.next_step(&client_identity);
let result = server_handshaker.next_step(&client_identity, ADDITIONAL_INFO.into());
assert_matches!(result, Err(_));
}

Expand All @@ -177,7 +183,7 @@ fn test_replay_server_identity() {

let first_client_hello = first_client_handshaker.create_client_hello().unwrap();
let first_server_identity = first_server_handshaker
.next_step(&first_client_hello)
.next_step(&first_client_hello, ADDITIONAL_INFO.into())
.unwrap()
.unwrap();

Expand All @@ -194,7 +200,7 @@ fn test_replay_client_identity() {

let first_client_hello = first_client_handshaker.create_client_hello().unwrap();
let first_server_identity = first_server_handshaker
.next_step(&first_client_hello)
.next_step(&first_client_hello, ADDITIONAL_INFO.into())
.unwrap()
.unwrap();
let first_client_identity = first_client_handshaker
Expand All @@ -204,10 +210,10 @@ fn test_replay_client_identity() {

let second_client_hello = second_client_handshaker.create_client_hello().unwrap();
let _ = second_server_handshaker
.next_step(&second_client_hello)
.next_step(&second_client_hello, ADDITIONAL_INFO.into())
.unwrap()
.unwrap();
let result = second_server_handshaker.next_step(&first_client_identity);
let result = second_server_handshaker.next_step(&first_client_identity, ADDITIONAL_INFO.into());
assert_matches!(result, Err(_));
}

Expand Down