@@ -27,6 +27,7 @@ use assert_matches::assert_matches;
27
27
28
28
const TEE_MEASUREMENT : & str = "Test TEE measurement" ;
29
29
const DATA : [ u8 ; 10 ] = [ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ] ;
30
+ const ADDITIONAL_INFO : [ u8 ; 15 ] = * br"Additional Info" ;
30
31
31
32
fn create_handshakers ( ) -> ( ClientHandshaker , ServerHandshaker ) {
32
33
let bidirectional_attestation =
@@ -47,8 +48,7 @@ fn create_handshakers() -> (ClientHandshaker, ServerHandshaker) {
47
48
AttestationBehavior :: create_bidirectional_attestation ( & [ ] , TEE_MEASUREMENT . as_bytes ( ) )
48
49
. unwrap ( ) ;
49
50
50
- let additional_info = br"Additional Info" . to_vec ( ) ;
51
- let server_handshaker = ServerHandshaker :: new ( bidirectional_attestation, additional_info) ;
51
+ let server_handshaker = ServerHandshaker :: new ( bidirectional_attestation) ;
52
52
53
53
( client_handshaker, server_handshaker)
54
54
}
@@ -71,7 +71,7 @@ fn test_handshake() {
71
71
. expect ( "Couldn't create client hello message" ) ;
72
72
73
73
let server_identity = server_handshaker
74
- . next_step ( & client_hello)
74
+ . next_step ( & client_hello, ADDITIONAL_INFO . into ( ) )
75
75
. expect ( "Couldn't process client hello message" )
76
76
. expect ( "Empty server identity message" ) ;
77
77
@@ -82,7 +82,7 @@ fn test_handshake() {
82
82
assert ! ( client_handshaker. is_completed( ) ) ;
83
83
84
84
let result = server_handshaker
85
- . next_step ( & client_identity)
85
+ . next_step ( & client_identity, ADDITIONAL_INFO . into ( ) )
86
86
. expect ( "Couldn't process client identity message" ) ;
87
87
assert_matches ! ( result, None ) ;
88
88
assert ! ( server_handshaker. is_completed( ) ) ;
@@ -122,7 +122,7 @@ fn test_invalid_message_after_initialization() {
122
122
let result = client_handshaker. create_client_hello ( ) ;
123
123
assert_matches ! ( result, Err ( _) ) ;
124
124
125
- let result = server_handshaker. next_step ( & invalid_message) ;
125
+ let result = server_handshaker. next_step ( & invalid_message, ADDITIONAL_INFO . into ( ) ) ;
126
126
assert_matches ! ( result, Err ( _) ) ;
127
127
assert ! ( server_handshaker. is_aborted( ) ) ;
128
128
}
@@ -137,8 +137,11 @@ fn test_invalid_message_after_hello() {
137
137
assert_matches ! ( result, Err ( _) ) ;
138
138
assert ! ( client_handshaker. is_aborted( ) ) ;
139
139
140
- let server_identity = server_handshaker. next_step ( & client_hello) . unwrap ( ) . unwrap ( ) ;
141
- let result = server_handshaker. next_step ( & invalid_message) ;
140
+ let server_identity = server_handshaker
141
+ . next_step ( & client_hello, ADDITIONAL_INFO . into ( ) )
142
+ . unwrap ( )
143
+ . unwrap ( ) ;
144
+ let result = server_handshaker. next_step ( & invalid_message, ADDITIONAL_INFO . into ( ) ) ;
142
145
assert_matches ! ( result, Err ( _) ) ;
143
146
assert ! ( server_handshaker. is_aborted( ) ) ;
144
147
@@ -152,7 +155,10 @@ fn test_invalid_message_after_identities() {
152
155
let invalid_message = vec ! [ INVALID_MESSAGE_HEADER ] ;
153
156
154
157
let client_hello = client_handshaker. create_client_hello ( ) . unwrap ( ) ;
155
- let server_identity = server_handshaker. next_step ( & client_hello) . unwrap ( ) . unwrap ( ) ;
158
+ let server_identity = server_handshaker
159
+ . next_step ( & client_hello, ADDITIONAL_INFO . into ( ) )
160
+ . unwrap ( )
161
+ . unwrap ( ) ;
156
162
let client_identity = client_handshaker
157
163
. next_step ( & server_identity)
158
164
. unwrap ( )
@@ -162,11 +168,11 @@ fn test_invalid_message_after_identities() {
162
168
assert_matches ! ( result, Err ( _) ) ;
163
169
assert ! ( client_handshaker. is_aborted( ) ) ;
164
170
165
- let result = server_handshaker. next_step ( & invalid_message) ;
171
+ let result = server_handshaker. next_step ( & invalid_message, ADDITIONAL_INFO . into ( ) ) ;
166
172
assert_matches ! ( result, Err ( _) ) ;
167
173
assert ! ( server_handshaker. is_aborted( ) ) ;
168
174
169
- let result = server_handshaker. next_step ( & client_identity) ;
175
+ let result = server_handshaker. next_step ( & client_identity, ADDITIONAL_INFO . into ( ) ) ;
170
176
assert_matches ! ( result, Err ( _) ) ;
171
177
}
172
178
@@ -177,7 +183,7 @@ fn test_replay_server_identity() {
177
183
178
184
let first_client_hello = first_client_handshaker. create_client_hello ( ) . unwrap ( ) ;
179
185
let first_server_identity = first_server_handshaker
180
- . next_step ( & first_client_hello)
186
+ . next_step ( & first_client_hello, ADDITIONAL_INFO . into ( ) )
181
187
. unwrap ( )
182
188
. unwrap ( ) ;
183
189
@@ -194,7 +200,7 @@ fn test_replay_client_identity() {
194
200
195
201
let first_client_hello = first_client_handshaker. create_client_hello ( ) . unwrap ( ) ;
196
202
let first_server_identity = first_server_handshaker
197
- . next_step ( & first_client_hello)
203
+ . next_step ( & first_client_hello, ADDITIONAL_INFO . into ( ) )
198
204
. unwrap ( )
199
205
. unwrap ( ) ;
200
206
let first_client_identity = first_client_handshaker
@@ -204,10 +210,10 @@ fn test_replay_client_identity() {
204
210
205
211
let second_client_hello = second_client_handshaker. create_client_hello ( ) . unwrap ( ) ;
206
212
let _ = second_server_handshaker
207
- . next_step ( & second_client_hello)
213
+ . next_step ( & second_client_hello, ADDITIONAL_INFO . into ( ) )
208
214
. unwrap ( )
209
215
. unwrap ( ) ;
210
- let result = second_server_handshaker. next_step ( & first_client_identity) ;
216
+ let result = second_server_handshaker. next_step ( & first_client_identity, ADDITIONAL_INFO . into ( ) ) ;
211
217
assert_matches ! ( result, Err ( _) ) ;
212
218
}
213
219
0 commit comments