Skip to content

Commit 0789621

Browse files
author
Juliette Pretot
committed
Do not clone additional info into each ServerHandshaker.
This is important, as with unary request the runtime will keep a cache of ServerHandshaker instances
1 parent ccad804 commit 0789621

File tree

3 files changed

+40
-30
lines changed

3 files changed

+40
-30
lines changed

grpc_attestation/src/server.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,7 @@ where
151151
.map_err(|error| {
152152
error_logger.log_error(&format!("Couldn't create self attestation behavior: {:?}", error));
153153
Status::internal("")
154-
})?,
155-
additional_info,
154+
})?
156155
);
157156
while !handshaker.is_completed() {
158157
let incoming_message = request_stream.next()
@@ -167,7 +166,7 @@ where
167166
})?;
168167

169168
let outgoing_message = handshaker
170-
.next_step(&incoming_message.body)
169+
.next_step(&incoming_message.body, additional_info.clone())
171170
.map_err(|error| {
172171
error_logger.log_error(&format!("Couldn't process handshake message: {:?}", error));
173172
Status::aborted("")

remote_attestation/rust/src/handshaker.rs

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -329,40 +329,45 @@ pub struct ServerHandshaker {
329329
/// Collection of previously sent and received messages.
330330
/// Signed transcript is sent in messages to prevent replay attacks.
331331
transcript: Transcript,
332-
/// Additional info about the server, including configuration information and proof of
333-
/// inclusion in a verifiable log.
334-
additional_info: Vec<u8>,
335332
}
336333

337334
impl ServerHandshaker {
338335
/// Creates [`ServerHandshaker`] with `ServerHandshakerState::ExpectingClientIdentity`
339336
/// state.
340-
pub fn new(behavior: AttestationBehavior, additional_info: Vec<u8>) -> Self {
337+
pub fn new(behavior: AttestationBehavior) -> Self {
341338
Self {
342339
behavior,
343340
state: ServerHandshakerState::ExpectingClientHello,
344341
transcript: Transcript::new(),
345-
additional_info,
346342
}
347343
}
348344

349345
/// Processes incoming `message` and returns a serialized remote attestation message.
350346
/// If [`None`] is returned, then no messages should be sent out to the client.
351-
pub fn next_step(&mut self, message: &[u8]) -> anyhow::Result<Option<Vec<u8>>> {
352-
self.next_step_util(message).map_err(|error| {
353-
self.state = ServerHandshakerState::Aborted;
354-
error
355-
})
347+
pub fn next_step(
348+
&mut self,
349+
message: &[u8],
350+
additional_info: Vec<u8>,
351+
) -> anyhow::Result<Option<Vec<u8>>> {
352+
self.next_step_util(message, additional_info)
353+
.map_err(|error| {
354+
self.state = ServerHandshakerState::Aborted;
355+
error
356+
})
356357
}
357358

358-
fn next_step_util(&mut self, message: &[u8]) -> anyhow::Result<Option<Vec<u8>>> {
359+
fn next_step_util(
360+
&mut self,
361+
message: &[u8],
362+
additional_info: Vec<u8>,
363+
) -> anyhow::Result<Option<Vec<u8>>> {
359364
let deserialized_message =
360365
deserialize_message(message).context("Couldn't deserialize message")?;
361366
match deserialized_message {
362367
MessageWrapper::ClientHello(client_hello) => match &self.state {
363368
ServerHandshakerState::ExpectingClientHello => {
364369
let server_identity = self
365-
.process_client_hello(client_hello)
370+
.process_client_hello(client_hello, additional_info)
366371
.context("Couldn't process client hello message")?;
367372
let serialized_server_identity = server_identity
368373
.serialize()
@@ -429,6 +434,7 @@ impl ServerHandshaker {
429434
fn process_client_hello(
430435
&mut self,
431436
client_hello: ClientHello,
437+
additional_info: Vec<u8>,
432438
) -> anyhow::Result<ServerIdentity> {
433439
// Create server identity message.
434440
let key_negotiator = KeyNegotiator::create(KeyNegotiatorType::Server)
@@ -448,7 +454,6 @@ impl ServerHandshaker {
448454
.as_ref()
449455
.context("Couldn't get TEE certificate")?;
450456

451-
let additional_info = self.additional_info.clone();
452457
let attestation_info =
453458
create_attestation_info(signer, additional_info.as_ref(), tee_certificate)
454459
.context("Couldn't get attestation info")?;

remote_attestation/rust/src/tests/handshaker.rs

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use assert_matches::assert_matches;
2727

2828
const TEE_MEASUREMENT: &str = "Test TEE measurement";
2929
const DATA: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
30+
const ADDITIONAL_INFO: [u8; 15] = *br"Additional Info";
3031

3132
fn create_handshakers() -> (ClientHandshaker, ServerHandshaker) {
3233
let bidirectional_attestation =
@@ -47,8 +48,7 @@ fn create_handshakers() -> (ClientHandshaker, ServerHandshaker) {
4748
AttestationBehavior::create_bidirectional_attestation(&[], TEE_MEASUREMENT.as_bytes())
4849
.unwrap();
4950

50-
let additional_info = br"Additional Info".to_vec();
51-
let server_handshaker = ServerHandshaker::new(bidirectional_attestation, additional_info);
51+
let server_handshaker = ServerHandshaker::new(bidirectional_attestation);
5252

5353
(client_handshaker, server_handshaker)
5454
}
@@ -71,7 +71,7 @@ fn test_handshake() {
7171
.expect("Couldn't create client hello message");
7272

7373
let server_identity = server_handshaker
74-
.next_step(&client_hello)
74+
.next_step(&client_hello, ADDITIONAL_INFO.into())
7575
.expect("Couldn't process client hello message")
7676
.expect("Empty server identity message");
7777

@@ -82,7 +82,7 @@ fn test_handshake() {
8282
assert!(client_handshaker.is_completed());
8383

8484
let result = server_handshaker
85-
.next_step(&client_identity)
85+
.next_step(&client_identity, ADDITIONAL_INFO.into())
8686
.expect("Couldn't process client identity message");
8787
assert_matches!(result, None);
8888
assert!(server_handshaker.is_completed());
@@ -122,7 +122,7 @@ fn test_invalid_message_after_initialization() {
122122
let result = client_handshaker.create_client_hello();
123123
assert_matches!(result, Err(_));
124124

125-
let result = server_handshaker.next_step(&invalid_message);
125+
let result = server_handshaker.next_step(&invalid_message, ADDITIONAL_INFO.into());
126126
assert_matches!(result, Err(_));
127127
assert!(server_handshaker.is_aborted());
128128
}
@@ -137,8 +137,11 @@ fn test_invalid_message_after_hello() {
137137
assert_matches!(result, Err(_));
138138
assert!(client_handshaker.is_aborted());
139139

140-
let server_identity = server_handshaker.next_step(&client_hello).unwrap().unwrap();
141-
let result = server_handshaker.next_step(&invalid_message);
140+
let server_identity = server_handshaker
141+
.next_step(&client_hello, ADDITIONAL_INFO.into())
142+
.unwrap()
143+
.unwrap();
144+
let result = server_handshaker.next_step(&invalid_message, ADDITIONAL_INFO.into());
142145
assert_matches!(result, Err(_));
143146
assert!(server_handshaker.is_aborted());
144147

@@ -152,7 +155,10 @@ fn test_invalid_message_after_identities() {
152155
let invalid_message = vec![INVALID_MESSAGE_HEADER];
153156

154157
let client_hello = client_handshaker.create_client_hello().unwrap();
155-
let server_identity = server_handshaker.next_step(&client_hello).unwrap().unwrap();
158+
let server_identity = server_handshaker
159+
.next_step(&client_hello, ADDITIONAL_INFO.into())
160+
.unwrap()
161+
.unwrap();
156162
let client_identity = client_handshaker
157163
.next_step(&server_identity)
158164
.unwrap()
@@ -162,11 +168,11 @@ fn test_invalid_message_after_identities() {
162168
assert_matches!(result, Err(_));
163169
assert!(client_handshaker.is_aborted());
164170

165-
let result = server_handshaker.next_step(&invalid_message);
171+
let result = server_handshaker.next_step(&invalid_message, ADDITIONAL_INFO.into());
166172
assert_matches!(result, Err(_));
167173
assert!(server_handshaker.is_aborted());
168174

169-
let result = server_handshaker.next_step(&client_identity);
175+
let result = server_handshaker.next_step(&client_identity, ADDITIONAL_INFO.into());
170176
assert_matches!(result, Err(_));
171177
}
172178

@@ -177,7 +183,7 @@ fn test_replay_server_identity() {
177183

178184
let first_client_hello = first_client_handshaker.create_client_hello().unwrap();
179185
let first_server_identity = first_server_handshaker
180-
.next_step(&first_client_hello)
186+
.next_step(&first_client_hello, ADDITIONAL_INFO.into())
181187
.unwrap()
182188
.unwrap();
183189

@@ -194,7 +200,7 @@ fn test_replay_client_identity() {
194200

195201
let first_client_hello = first_client_handshaker.create_client_hello().unwrap();
196202
let first_server_identity = first_server_handshaker
197-
.next_step(&first_client_hello)
203+
.next_step(&first_client_hello, ADDITIONAL_INFO.into())
198204
.unwrap()
199205
.unwrap();
200206
let first_client_identity = first_client_handshaker
@@ -204,10 +210,10 @@ fn test_replay_client_identity() {
204210

205211
let second_client_hello = second_client_handshaker.create_client_hello().unwrap();
206212
let _ = second_server_handshaker
207-
.next_step(&second_client_hello)
213+
.next_step(&second_client_hello, ADDITIONAL_INFO.into())
208214
.unwrap()
209215
.unwrap();
210-
let result = second_server_handshaker.next_step(&first_client_identity);
216+
let result = second_server_handshaker.next_step(&first_client_identity, ADDITIONAL_INFO.into());
211217
assert_matches!(result, Err(_));
212218
}
213219

0 commit comments

Comments
 (0)