Skip to content

Commit 547d0a7

Browse files
authored
Support the complex64 data type (#660)
* Support the complex64 data type * Add more tests, paddle support * Work around an old Torch bug where complex <- scalar segfaults * Fix big endian lookup * Complex64 -> C64 * More Complex64 -> C64 * Fixes
1 parent c6a98fa commit 547d0a7

File tree

11 files changed

+45
-0
lines changed

11 files changed

+45
-0
lines changed

bindings/python/py_src/safetensors/numpy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def load_file(filename: Union[str, os.PathLike]) -> Dict[str, np.ndarray]:
154154
"I8": np.int8,
155155
"U8": np.uint8,
156156
"BOOL": bool,
157+
"C64": np.complex64,
157158
}
158159

159160

bindings/python/py_src/safetensors/paddle.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ def _paddle2np(paddle_dict: Dict[str, paddle.Tensor]) -> Dict[str, np.array]:
168168
paddle.float64: 8,
169169
paddle.float8_e4m3fn: 1,
170170
paddle.float8_e5m2: 1,
171+
paddle.complex64: 8,
171172
# XXX: These are not supported yet in paddle
172173
# paddle.uint64: 8,
173174
# paddle.uint32: 4,

bindings/python/py_src/safetensors/torch.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,7 @@ def load(data: bytes) -> Dict[str, torch.Tensor]:
383383
torch.int8: 1,
384384
torch.bool: 1,
385385
torch.float64: 8,
386+
torch.complex64: 8,
386387
_float8_e4m3fn: 1,
387388
_float8_e5m2: 1,
388389
_float8_e8m0: 1,
@@ -410,6 +411,7 @@ def load(data: bytes) -> Dict[str, torch.Tensor]:
410411
"BOOL": torch.bool,
411412
"F8_E4M3": _float8_e4m3fn,
412413
"F8_E5M2": _float8_e5m2,
414+
"C64": torch.complex64,
413415
}
414416
if Version(torch.__version__) >= Version("2.3.0"):
415417
_TYPES.update(
@@ -493,6 +495,7 @@ def _tobytes(tensor: torch.Tensor, name: str) -> bytes:
493495
# XXX: This is ok because both have the same width and byteswap is a no-op anyway
494496
_float8_e4m3fn: np.uint8,
495497
_float8_e5m2: np.uint8,
498+
torch.complex64: np.complex64,
496499
}
497500
npdtype = NPDTYPES[tensor.dtype]
498501
# Not in place as that would potentially modify a live running model

bindings/python/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ fn prepare(tensor_dict: HashMap<String, PyBound<PyDict>>) -> PyResult<HashMap<St
8484
"float8_e5m2" => Dtype::F8_E5M2,
8585
"float8_e8m0fnu" => Dtype::F8_E8M0,
8686
"float4_e2m1fn_x2" => Dtype::F4,
87+
"complex64" => Dtype::C64,
8788
dtype_str => {
8889
return Err(SafetensorError::new_err(format!(
8990
"dtype {dtype_str} is not covered",
@@ -1467,6 +1468,7 @@ fn get_pydtype(module: &PyBound<'_, PyModule>, dtype: Dtype, is_numpy: bool) ->
14671468
Dtype::F8_E5M2 => module.getattr(intern!(py, "float8_e5m2"))?.into(),
14681469
Dtype::F8_E8M0 => module.getattr(intern!(py, "float8_e8m0fnu"))?.into(),
14691470
Dtype::F4 => module.getattr(intern!(py, "float4_e2m1fn_x2"))?.into(),
1471+
Dtype::C64 => module.getattr(intern!(py, "complex64"))?.into(),
14701472
dtype => {
14711473
return Err(SafetensorError::new_err(format!(
14721474
"Dtype not understood: {dtype}"

bindings/python/src/view.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ pub fn prepare(tensor_dict: HashMap<String, PyBound<PyDict>>) -> PyResult<HashMa
9292
"float8_e5m2" => Dtype::F8_E5M2,
9393
"float8_e8m0fnu" => Dtype::E8M0,
9494
"float4_e2m1fn_x2" => Dtype::F4,
95+
"complex64" => Dtype::C64,
9596
dtype_str => {
9697
return Err(SafetensorError::new_err(format!(
9798
"dtype {dtype_str} is not covered",

bindings/python/tests/test_flax_comparison.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def setUp(self):
2222
"test": random.normal(key, (1024, 1024), dtype=jnp.float32),
2323
"test2": random.normal(key, (1024, 1024), dtype=jnp.float16),
2424
"test3": random.normal(key, (1024, 1024), dtype=jnp.bfloat16),
25+
"test4": random.normal(key, (1024, 1024), dtype=jnp.complex64),
2526
}
2627
self.flax_filename = "./tests/data/flax_load.msgpack"
2728
self.sf_filename = "./tests/data/flax_load.safetensors"

bindings/python/tests/test_mlx_comparison.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ def setUp(self):
2626
"test": mx.random.uniform(shape=(1024, 1024), dtype=mx.float32),
2727
"test2": mx.random.uniform(shape=(1024, 1024), dtype=mx.float32),
2828
"test3": mx.random.uniform(shape=(1024, 1024), dtype=mx.float32),
29+
"test4": mx.random.uniform(shape=(1024, 1024), dtype=mx.float32).astype(
30+
mx.complex64
31+
),
2932
# This doesn't work because bfloat16 is not implemented
3033
# with similar workarounds as jax/tensorflow.
3134
# https://github.com/ml-explore/mlx/issues/1296

bindings/python/tests/test_paddle_comparison.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def setUp(self):
2020
"test": paddle.zeros((1024, 1024), dtype=paddle.float32),
2121
"test2": paddle.zeros((1024, 1024), dtype=paddle.float32),
2222
"test3": paddle.zeros((1024, 1024), dtype=paddle.float32),
23+
"test4": paddle.zeros((1024, 1024), dtype=paddle.complex64),
2324
}
2425
self.paddle_filename = "./tests/data/paddle_load.pdparams"
2526
self.sf_filename = "./tests/data/paddle_load.safetensors"

bindings/python/tests/test_pt_comparison.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def test_odd_dtype(self):
6464
"test": torch.randn((2, 2), dtype=torch.bfloat16),
6565
"test2": torch.randn((2, 2), dtype=torch.float16),
6666
"test3": torch.zeros((2, 2), dtype=torch.bool),
67+
"test4": torch.zeros((2, 2), dtype=torch.complex64),
6768
}
6869

6970
# Modify bool to have both values.
@@ -75,6 +76,31 @@ def test_odd_dtype(self):
7576
self.assertTrue(torch.equal(data["test"], reloaded["test"]))
7677
self.assertTrue(torch.equal(data["test2"], reloaded["test2"]))
7778
self.assertTrue(torch.equal(data["test3"], reloaded["test3"]))
79+
self.assertTrue(torch.equal(data["test4"], reloaded["test4"]))
80+
81+
def test_complex(self):
82+
# Test complex separately. Each value consists of two numbers
83+
# and we want to validate that the representation is the same
84+
# across platforms.
85+
data = torch.zeros((2, 2), dtype=torch.complex64)
86+
out = save({"test": data})
87+
88+
self.assertEqual(
89+
out,
90+
b'@\x00\x00\x00\x00\x00\x00\x00{"test":{"dtype":"C64","shape":[2,2],"data_offsets":[0,32]}} '
91+
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
92+
)
93+
94+
real = torch.tensor([-1.0])
95+
imag = torch.tensor([1.0])
96+
data[1][1] = torch.complex(real, imag)
97+
out = save({"test": data})
98+
99+
self.assertEqual(
100+
out,
101+
b'@\x00\x00\x00\x00\x00\x00\x00{"test":{"dtype":"C64","shape":[2,2],"data_offsets":[0,32]}} '
102+
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80\xbf\x00\x00\x80?",
103+
)
78104

79105
def test_odd_dtype_fp8(self):
80106
if torch.__version__ < "2.1":

bindings/python/tests/test_tf_comparison.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def setUp(self):
3333
"test": tf.zeros((1024, 1024), dtype=tf.float32),
3434
"test2": tf.zeros((1024, 1024), dtype=tf.float32),
3535
"test3": tf.zeros((1024, 1024), dtype=tf.float32),
36+
"test4": tf.zeros((1024, 1024), dtype=tf.complex64),
3637
}
3738
self.tf_filename = "./tests/data/tf_load.h5"
3839
self.sf_filename = "./tests/data/tf_load.safetensors"

0 commit comments

Comments
 (0)