Commit 280141d

feat: add header size check at serialization (#669)

1 parent 547d0a7
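
For context: a safetensors file starts with an 8-byte little-endian u64 giving the length `n` of the JSON header that follows. Deserialization already rejects headers larger than `MAX_HEADER_SIZE`, so without this check `serialize` could produce a file that can never be read back. Below is a minimal sketch of the guarded length-prefix logic, not the crate's exact code; the concrete value of `MAX_HEADER_SIZE` here is an assumption (the real constant lives in `tensor.rs`):

// Minimal sketch of the guarded length-prefix write. Assumption: the
// constant mirrors the bound already enforced at deserialization; its
// value below is illustrative only.
const MAX_HEADER_SIZE: usize = 100_000_000;

#[derive(Debug)]
enum SketchError {
    HeaderTooLarge,
}

fn write_prefixed_header(out: &mut Vec<u8>, header_bytes: &[u8]) -> Result<(), SketchError> {
    let n = header_bytes.len() as u64;
    // The new guard: refuse to emit a header the reader would reject.
    if n > MAX_HEADER_SIZE as u64 {
        return Err(SketchError::HeaderTooLarge);
    }
    out.extend(n.to_le_bytes());         // 8-byte little-endian length prefix
    out.extend_from_slice(header_bytes); // JSON header; tensor bytes follow
    Ok(())
}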

File tree

1 file changed: +24, -0 lines changed

safetensors/src/tensor.rs

Lines changed: 24 additions & 0 deletions
@@ -285,6 +285,10 @@ pub fn serialize<
         tensors,
     ) = prepare(data, data_info)?;
 
+    if n > MAX_HEADER_SIZE as u64 {
+        return Err(SafeTensorError::HeaderTooLarge);
+    }
+
     let expected_size = N_LEN + header_bytes.len() + offset;
     let mut buffer: Vec<u8> = Vec::with_capacity(expected_size);
     buffer.extend(n.to_le_bytes());
@@ -318,6 +322,10 @@ where
         tensors,
     ) = prepare(data, data_info)?;
 
+    if n > MAX_HEADER_SIZE as u64 {
+        return Err(SafeTensorError::HeaderTooLarge);
+    }
+
     let mut f = std::io::BufWriter::new(std::fs::File::create(filename)?);
     f.write_all(n.to_le_bytes().as_ref())?;
     f.write_all(&header_bytes)?;
@@ -1494,4 +1502,20 @@ mod tests {
             _ => panic!("This should not be able to be deserialized"),
         }
     }
+
+    #[test]
+    fn test_invalid_header_size_serialization() {
+        let mut data_info = HashMap::<String, String>::new();
+        let tensors: HashMap<String, TensorView> = HashMap::new();
+
+        // a char is 1 byte in utf-8, so we can just repeat 'a' to get large metadata
+        let very_large_metadata = "a".repeat(MAX_HEADER_SIZE);
+        data_info.insert("very_large_metadata".to_string(), very_large_metadata);
+        match serialize(&tensors, Some(data_info)) {
+            Err(SafeTensorError::HeaderTooLarge) => {
+                // Yes we have the correct error
+            }
+            _ => panic!("This should not be able to be serialized"),
+        }
+    }
 }
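
A caller-side sketch of handling the new error path, assuming the crate-root re-exports of `serialize` and `SafeTensorError` (used the same way as in the test above); the helper name `save_with_metadata` is hypothetical:

use std::collections::HashMap;

use safetensors::{serialize, SafeTensorError};

// Hypothetical helper: serialization now fails fast when user metadata
// pushes the JSON header past MAX_HEADER_SIZE, instead of writing a file
// that deserialization would later reject.
fn save_with_metadata(meta: HashMap<String, String>) -> Result<Vec<u8>, SafeTensorError> {
    let tensors: HashMap<String, safetensors::tensor::TensorView> = HashMap::new();
    match serialize(&tensors, Some(meta)) {
        Err(SafeTensorError::HeaderTooLarge) => {
            eprintln!("metadata too large for the safetensors header");
            Err(SafeTensorError::HeaderTooLarge)
        }
        other => other,
    }
}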
