Skip to content
5 changes: 2 additions & 3 deletions grpc_attestation/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,7 @@ where
.map_err(|error| {
error_logger.log_error(&format!("Couldn't create self attestation behavior: {:?}", error));
Status::internal("")
})?,
additional_info,
})?
);
while !handshaker.is_completed() {
let incoming_message = request_stream.next()
Expand All @@ -167,7 +166,7 @@ where
})?;

let outgoing_message = handshaker
.next_step(&incoming_message.body)
.next_step(&incoming_message.body, &additional_info)
.map_err(|error| {
error_logger.log_error(&format!("Couldn't process handshake message: {:?}", error));
Status::aborted("")
Expand Down
35 changes: 20 additions & 15 deletions remote_attestation/rust/src/handshaker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,40 +329,45 @@ pub struct ServerHandshaker {
/// Collection of previously sent and received messages.
/// Signed transcript is sent in messages to prevent replay attacks.
transcript: Transcript,
/// Additional info about the server, including configuration information and proof of
/// inclusion in a verifiable log.
additional_info: Vec<u8>,
}

impl ServerHandshaker {
/// Creates [`ServerHandshaker`] with `ServerHandshakerState::ExpectingClientIdentity`
/// state.
pub fn new(behavior: AttestationBehavior, additional_info: Vec<u8>) -> Self {
pub fn new(behavior: AttestationBehavior) -> Self {
Self {
behavior,
state: ServerHandshakerState::ExpectingClientHello,
transcript: Transcript::new(),
additional_info,
}
}

/// Processes incoming `message` and returns a serialized remote attestation message.
/// If [`None`] is returned, then no messages should be sent out to the client.
pub fn next_step(&mut self, message: &[u8]) -> anyhow::Result<Option<Vec<u8>>> {
self.next_step_util(message).map_err(|error| {
self.state = ServerHandshakerState::Aborted;
error
})
pub fn next_step(
&mut self,
message: &[u8],
additional_info: &[u8],
) -> anyhow::Result<Option<Vec<u8>>> {
self.next_step_util(message, additional_info)
.map_err(|error| {
self.state = ServerHandshakerState::Aborted;
error
})
}

fn next_step_util(&mut self, message: &[u8]) -> anyhow::Result<Option<Vec<u8>>> {
fn next_step_util(
&mut self,
message: &[u8],
additional_info: &[u8],
) -> anyhow::Result<Option<Vec<u8>>> {
let deserialized_message =
deserialize_message(message).context("Couldn't deserialize message")?;
match deserialized_message {
MessageWrapper::ClientHello(client_hello) => match &self.state {
ServerHandshakerState::ExpectingClientHello => {
let server_identity = self
.process_client_hello(client_hello)
.process_client_hello(client_hello, additional_info)
.context("Couldn't process client hello message")?;
let serialized_server_identity = server_identity
.serialize()
Expand Down Expand Up @@ -429,6 +434,7 @@ impl ServerHandshaker {
fn process_client_hello(
&mut self,
client_hello: ClientHello,
additional_info: &[u8],
) -> anyhow::Result<ServerIdentity> {
// Create server identity message.
let key_negotiator = KeyNegotiator::create(KeyNegotiatorType::Server)
Expand All @@ -448,9 +454,8 @@ impl ServerHandshaker {
.as_ref()
.context("Couldn't get TEE certificate")?;

let additional_info = self.additional_info.clone();
let attestation_info =
create_attestation_info(signer, additional_info.as_ref(), tee_certificate)
create_attestation_info(signer, additional_info, tee_certificate)
.context("Couldn't get attestation info")?;

let mut server_identity = ServerIdentity::new(
Expand All @@ -460,7 +465,7 @@ impl ServerHandshaker {
.public_key()
.context("Couldn't get singing public key")?,
attestation_info,
additional_info,
additional_info.to_vec(),
);

// Update current transcript.
Expand Down
34 changes: 20 additions & 14 deletions remote_attestation/rust/src/tests/handshaker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use assert_matches::assert_matches;

const TEE_MEASUREMENT: &str = "Test TEE measurement";
const DATA: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
const ADDITIONAL_INFO: &[u8; 15] = br"Additional Info";

fn create_handshakers() -> (ClientHandshaker, ServerHandshaker) {
let bidirectional_attestation =
Expand All @@ -47,8 +48,7 @@ fn create_handshakers() -> (ClientHandshaker, ServerHandshaker) {
AttestationBehavior::create_bidirectional_attestation(&[], TEE_MEASUREMENT.as_bytes())
.unwrap();

let additional_info = br"Additional Info".to_vec();
let server_handshaker = ServerHandshaker::new(bidirectional_attestation, additional_info);
let server_handshaker = ServerHandshaker::new(bidirectional_attestation);

(client_handshaker, server_handshaker)
}
Expand All @@ -71,7 +71,7 @@ fn test_handshake() {
.expect("Couldn't create client hello message");

let server_identity = server_handshaker
.next_step(&client_hello)
.next_step(&client_hello, ADDITIONAL_INFO)
.expect("Couldn't process client hello message")
.expect("Empty server identity message");

Expand All @@ -82,7 +82,7 @@ fn test_handshake() {
assert!(client_handshaker.is_completed());

let result = server_handshaker
.next_step(&client_identity)
.next_step(&client_identity, ADDITIONAL_INFO)
.expect("Couldn't process client identity message");
assert_matches!(result, None);
assert!(server_handshaker.is_completed());
Expand Down Expand Up @@ -122,7 +122,7 @@ fn test_invalid_message_after_initialization() {
let result = client_handshaker.create_client_hello();
assert_matches!(result, Err(_));

let result = server_handshaker.next_step(&invalid_message);
let result = server_handshaker.next_step(&invalid_message, ADDITIONAL_INFO);
assert_matches!(result, Err(_));
assert!(server_handshaker.is_aborted());
}
Expand All @@ -137,8 +137,11 @@ fn test_invalid_message_after_hello() {
assert_matches!(result, Err(_));
assert!(client_handshaker.is_aborted());

let server_identity = server_handshaker.next_step(&client_hello).unwrap().unwrap();
let result = server_handshaker.next_step(&invalid_message);
let server_identity = server_handshaker
.next_step(&client_hello, ADDITIONAL_INFO)
.unwrap()
.unwrap();
let result = server_handshaker.next_step(&invalid_message, ADDITIONAL_INFO);
assert_matches!(result, Err(_));
assert!(server_handshaker.is_aborted());

Expand All @@ -152,7 +155,10 @@ fn test_invalid_message_after_identities() {
let invalid_message = vec![INVALID_MESSAGE_HEADER];

let client_hello = client_handshaker.create_client_hello().unwrap();
let server_identity = server_handshaker.next_step(&client_hello).unwrap().unwrap();
let server_identity = server_handshaker
.next_step(&client_hello, ADDITIONAL_INFO)
.unwrap()
.unwrap();
let client_identity = client_handshaker
.next_step(&server_identity)
.unwrap()
Expand All @@ -162,11 +168,11 @@ fn test_invalid_message_after_identities() {
assert_matches!(result, Err(_));
assert!(client_handshaker.is_aborted());

let result = server_handshaker.next_step(&invalid_message);
let result = server_handshaker.next_step(&invalid_message, ADDITIONAL_INFO);
assert_matches!(result, Err(_));
assert!(server_handshaker.is_aborted());

let result = server_handshaker.next_step(&client_identity);
let result = server_handshaker.next_step(&client_identity, ADDITIONAL_INFO);
assert_matches!(result, Err(_));
}

Expand All @@ -177,7 +183,7 @@ fn test_replay_server_identity() {

let first_client_hello = first_client_handshaker.create_client_hello().unwrap();
let first_server_identity = first_server_handshaker
.next_step(&first_client_hello)
.next_step(&first_client_hello, ADDITIONAL_INFO)
.unwrap()
.unwrap();

Expand All @@ -194,7 +200,7 @@ fn test_replay_client_identity() {

let first_client_hello = first_client_handshaker.create_client_hello().unwrap();
let first_server_identity = first_server_handshaker
.next_step(&first_client_hello)
.next_step(&first_client_hello, ADDITIONAL_INFO)
.unwrap()
.unwrap();
let first_client_identity = first_client_handshaker
Expand All @@ -204,10 +210,10 @@ fn test_replay_client_identity() {

let second_client_hello = second_client_handshaker.create_client_hello().unwrap();
let _ = second_server_handshaker
.next_step(&second_client_hello)
.next_step(&second_client_hello, ADDITIONAL_INFO)
.unwrap()
.unwrap();
let result = second_server_handshaker.next_step(&first_client_identity);
let result = second_server_handshaker.next_step(&first_client_identity, ADDITIONAL_INFO);
assert_matches!(result, Err(_));
}

Expand Down