@@ -27,6 +27,7 @@ use assert_matches::assert_matches;
2727
2828const TEE_MEASUREMENT : & str = "Test TEE measurement" ;
2929const DATA : [ u8 ; 10 ] = [ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ] ;
30+ const ADDITIONAL_INFO : [ u8 ; 15 ] = * br"Additional Info" ;
3031
3132fn create_handshakers ( ) -> ( ClientHandshaker , ServerHandshaker ) {
3233 let bidirectional_attestation =
@@ -47,8 +48,7 @@ fn create_handshakers() -> (ClientHandshaker, ServerHandshaker) {
4748 AttestationBehavior :: create_bidirectional_attestation ( & [ ] , TEE_MEASUREMENT . as_bytes ( ) )
4849 . unwrap ( ) ;
4950
50- let additional_info = br"Additional Info" . to_vec ( ) ;
51- let server_handshaker = ServerHandshaker :: new ( bidirectional_attestation, additional_info) ;
51+ let server_handshaker = ServerHandshaker :: new ( bidirectional_attestation) ;
5252
5353 ( client_handshaker, server_handshaker)
5454}
@@ -71,7 +71,7 @@ fn test_handshake() {
7171 . expect ( "Couldn't create client hello message" ) ;
7272
7373 let server_identity = server_handshaker
74- . next_step ( & client_hello)
74+ . next_step ( & client_hello, ADDITIONAL_INFO . into ( ) )
7575 . expect ( "Couldn't process client hello message" )
7676 . expect ( "Empty server identity message" ) ;
7777
@@ -82,7 +82,7 @@ fn test_handshake() {
8282 assert ! ( client_handshaker. is_completed( ) ) ;
8383
8484 let result = server_handshaker
85- . next_step ( & client_identity)
85+ . next_step ( & client_identity, ADDITIONAL_INFO . into ( ) )
8686 . expect ( "Couldn't process client identity message" ) ;
8787 assert_matches ! ( result, None ) ;
8888 assert ! ( server_handshaker. is_completed( ) ) ;
@@ -122,7 +122,7 @@ fn test_invalid_message_after_initialization() {
122122 let result = client_handshaker. create_client_hello ( ) ;
123123 assert_matches ! ( result, Err ( _) ) ;
124124
125- let result = server_handshaker. next_step ( & invalid_message) ;
125+ let result = server_handshaker. next_step ( & invalid_message, ADDITIONAL_INFO . into ( ) ) ;
126126 assert_matches ! ( result, Err ( _) ) ;
127127 assert ! ( server_handshaker. is_aborted( ) ) ;
128128}
@@ -137,8 +137,11 @@ fn test_invalid_message_after_hello() {
137137 assert_matches ! ( result, Err ( _) ) ;
138138 assert ! ( client_handshaker. is_aborted( ) ) ;
139139
140- let server_identity = server_handshaker. next_step ( & client_hello) . unwrap ( ) . unwrap ( ) ;
141- let result = server_handshaker. next_step ( & invalid_message) ;
140+ let server_identity = server_handshaker
141+ . next_step ( & client_hello, ADDITIONAL_INFO . into ( ) )
142+ . unwrap ( )
143+ . unwrap ( ) ;
144+ let result = server_handshaker. next_step ( & invalid_message, ADDITIONAL_INFO . into ( ) ) ;
142145 assert_matches ! ( result, Err ( _) ) ;
143146 assert ! ( server_handshaker. is_aborted( ) ) ;
144147
@@ -152,7 +155,10 @@ fn test_invalid_message_after_identities() {
152155 let invalid_message = vec ! [ INVALID_MESSAGE_HEADER ] ;
153156
154157 let client_hello = client_handshaker. create_client_hello ( ) . unwrap ( ) ;
155- let server_identity = server_handshaker. next_step ( & client_hello) . unwrap ( ) . unwrap ( ) ;
158+ let server_identity = server_handshaker
159+ . next_step ( & client_hello, ADDITIONAL_INFO . into ( ) )
160+ . unwrap ( )
161+ . unwrap ( ) ;
156162 let client_identity = client_handshaker
157163 . next_step ( & server_identity)
158164 . unwrap ( )
@@ -162,11 +168,11 @@ fn test_invalid_message_after_identities() {
162168 assert_matches ! ( result, Err ( _) ) ;
163169 assert ! ( client_handshaker. is_aborted( ) ) ;
164170
165- let result = server_handshaker. next_step ( & invalid_message) ;
171+ let result = server_handshaker. next_step ( & invalid_message, ADDITIONAL_INFO . into ( ) ) ;
166172 assert_matches ! ( result, Err ( _) ) ;
167173 assert ! ( server_handshaker. is_aborted( ) ) ;
168174
169- let result = server_handshaker. next_step ( & client_identity) ;
175+ let result = server_handshaker. next_step ( & client_identity, ADDITIONAL_INFO . into ( ) ) ;
170176 assert_matches ! ( result, Err ( _) ) ;
171177}
172178
@@ -177,7 +183,7 @@ fn test_replay_server_identity() {
177183
178184 let first_client_hello = first_client_handshaker. create_client_hello ( ) . unwrap ( ) ;
179185 let first_server_identity = first_server_handshaker
180- . next_step ( & first_client_hello)
186+ . next_step ( & first_client_hello, ADDITIONAL_INFO . into ( ) )
181187 . unwrap ( )
182188 . unwrap ( ) ;
183189
@@ -194,7 +200,7 @@ fn test_replay_client_identity() {
194200
195201 let first_client_hello = first_client_handshaker. create_client_hello ( ) . unwrap ( ) ;
196202 let first_server_identity = first_server_handshaker
197- . next_step ( & first_client_hello)
203+ . next_step ( & first_client_hello, ADDITIONAL_INFO . into ( ) )
198204 . unwrap ( )
199205 . unwrap ( ) ;
200206 let first_client_identity = first_client_handshaker
@@ -204,10 +210,10 @@ fn test_replay_client_identity() {
204210
205211 let second_client_hello = second_client_handshaker. create_client_hello ( ) . unwrap ( ) ;
206212 let _ = second_server_handshaker
207- . next_step ( & second_client_hello)
213+ . next_step ( & second_client_hello, ADDITIONAL_INFO . into ( ) )
208214 . unwrap ( )
209215 . unwrap ( ) ;
210- let result = second_server_handshaker. next_step ( & first_client_identity) ;
216+ let result = second_server_handshaker. next_step ( & first_client_identity, ADDITIONAL_INFO . into ( ) ) ;
211217 assert_matches ! ( result, Err ( _) ) ;
212218}
213219
0 commit comments