@@ -285,6 +285,10 @@ pub fn serialize<
         tensors,
     ) = prepare(data, data_info)?;

+    if n > MAX_HEADER_SIZE as u64 {
+        return Err(SafeTensorError::HeaderTooLarge);
+    }
+
     let expected_size = N_LEN + header_bytes.len() + offset;
     let mut buffer: Vec<u8> = Vec::with_capacity(expected_size);
     buffer.extend(n.to_le_bytes());
@@ -318,6 +322,10 @@ where
         tensors,
     ) = prepare(data, data_info)?;

+    if n > MAX_HEADER_SIZE as u64 {
+        return Err(SafeTensorError::HeaderTooLarge);
+    }
+
     let mut f = std::io::BufWriter::new(std::fs::File::create(filename)?);
     f.write_all(n.to_le_bytes().as_ref())?;
     f.write_all(&header_bytes)?;
@@ -1494,4 +1502,20 @@ mod tests {
             _ => panic!("This should not be able to be deserialized"),
         }
     }
+
+    #[test]
+    fn test_invalid_header_size_serialization() {
+        let mut data_info = HashMap::<String, String>::new();
+        let tensors: HashMap<String, TensorView> = HashMap::new();
+
+        // 'a' is 1 byte in UTF-8, so we can just repeat it to get large metadata
+        let very_large_metadata = "a".repeat(MAX_HEADER_SIZE);
+        data_info.insert("very_large_metadata".to_string(), very_large_metadata);
+        match serialize(&tensors, Some(data_info)) {
+            Err(SafeTensorError::HeaderTooLarge) => {
+                // Yes, we got the expected error
+            }
+            _ => panic!("This should not be able to be serialized"),
+        }
+    }
 }