Commit e472167

Create SafeTensors reader implementation

1 parent 4bc70df commit e472167

File tree

7 files changed: +1148 −7 lines changed

docs/safetensors.md

Lines changed: 136 additions & 0 deletions
# SafeTensors Reader User Guide

The SafeTensors reader in VirtualiZarr allows you to reference tensors stored in SafeTensors files. This guide explains how to use the reader effectively.

## What is SafeTensors Format?

SafeTensors is a file format for storing tensors (multidimensional arrays) that offers several advantages:

- Safe: no use of pickle, eliminating arbitrary code execution concerns
- Efficient: zero-copy access for fast loading
- Simple: a straightforward binary format with a JSON header
- Language-agnostic: implementations exist for Python, Rust, C++, and JavaScript

The format consists of:

- 8 bytes: a little-endian uint64 giving the size of the header
- JSON header: metadata for every tensor (shape, dtype, byte offsets)
- Binary data: contiguous tensor data
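The three-part layout above can be parsed with nothing but the standard library. A minimal sketch (the tensor name, shape, and offsets are invented for illustration):

```python
import json
import struct

def read_safetensors_header(buf: bytes) -> dict:
    """Parse the JSON header from an in-memory SafeTensors file."""
    # First 8 bytes: header size as a little-endian uint64
    (header_size,) = struct.unpack("<Q", buf[:8])
    return json.loads(buf[8 : 8 + header_size])

# Build a tiny file in memory: one 2x2 float32 tensor (16 bytes of zeros)
header = {"weight": {"dtype": "F32", "shape": [2, 2], "data_offsets": [0, 16]}}
header_bytes = json.dumps(header).encode("utf-8")
blob = struct.pack("<Q", len(header_bytes)) + header_bytes + bytes(16)

print(read_safetensors_header(blob)["weight"]["shape"])  # [2, 2]
```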
## How VirtualiZarr's SafeTensors Reader Works

VirtualiZarr's SafeTensors reader allows you to:

- Work with tensors as xarray DataArrays with named dimensions
- Access specific slices of tensors from cloud storage
- Preserve metadata from the original SafeTensors file
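Conceptually, creating a virtual reference for a tensor amounts to translating its relative `data_offsets` (which count from the start of the binary data section) into an absolute byte range in the file. A sketch, assuming the header has already been parsed (the helper name is hypothetical):

```python
def absolute_byte_range(header_size: int, data_offsets: list) -> tuple:
    # The binary data section starts right after the 8-byte size prefix
    # and the JSON header itself.
    data_start = 8 + header_size
    begin, end = data_offsets
    return (data_start + begin, data_start + end)

# A tensor at relative offsets [0, 16) behind a 120-byte header occupies
# absolute bytes [128, 144) -- the range a reader would request remotely.
print(absolute_byte_range(120, [0, 16]))  # (128, 144)
```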
## Basic Usage

Opening a SafeTensors file is straightforward:

```python
import virtualizarr as vz

# Open a SafeTensors file
vds = vz.open_virtual_dataset("model.safetensors")

# Access tensors as xarray variables
weight = vds["weight"]
bias = vds["bias"]

# Convert to numpy arrays when needed
weight_array = weight.values
bias_array = bias.values
```
## Custom Dimension Names

By default, dimensions are named generically (e.g., "weight_dim_0", "weight_dim_1"). You can provide custom dimension names for better semantics:

```python
# Define custom dimension names
custom_dims = {
    "weight": ["input_dims", "output_dims"],
    "bias": ["output_dims"]
}

# Open with custom dimension names
vds = vz.open_virtual_dataset(
    "model.safetensors",
    virtual_backend_kwargs={"dimension_names": custom_dims}
)

# Now dimensions have meaningful names
print(vds["weight"].dims)  # ('input_dims', 'output_dims')
print(vds["bias"].dims)    # ('output_dims',)
```
## Loading Specific Variables

You can specify which variables to load as eager (in-memory) arrays instead of virtual references:

```python
# Load specific variables as eager arrays
vds = vz.open_virtual_dataset(
    "model_weights.safetensors",
    loadable_variables=["small_tensor1", "small_tensor2"]
)

# These are loaded as regular numpy-backed arrays
small_tensor1 = vds["small_tensor1"]
# Other tensors remain virtual references
large_tensor = vds["large_tensor"]
```
## Working with Remote Files

The SafeTensors reader supports reading from the HuggingFace Hub:

```python
# HuggingFace Hub
vds = vz.open_virtual_dataset(
    "https://huggingface.co/openai-community/gpt2/model.safetensors",
    virtual_backend_kwargs={"revision": "main"}
)
```

It also supports reading from object storage:

```python
# S3
vds = vz.open_virtual_dataset(
    "s3://my-bucket/model.safetensors",
    reader_options={
        "storage_options": {
            "key": "ACCESS_KEY",
            "secret": "SECRET_KEY",
            "region_name": "us-west-2"
        }
    }
)
```
## Accessing Metadata

SafeTensors files can contain metadata at both the file level and the tensor level:

```python
# Access file-level metadata
print(vds.attrs)

# Access tensor-specific metadata
print(vds["weight"].attrs)

# Access the original SafeTensors dtype information
original_dtype = vds["weight"].attrs["original_safetensors_dtype"]
print(f"Original dtype: {original_dtype}")
```
## Known Limitations

### Performance Considerations

- Very large tensors (>1GB) are treated as a single chunk, which may increase memory usage when accessing small slices
- Files with thousands of tiny tensors may incur overhead from metadata handling

## Best Practices

- **For large tensors**: use slicing to access only the portions you need
- **For remote files**: use appropriate credentials and optimize access patterns
- **For many small tensors**: consider loading them eagerly via `loadable_variables`
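To experiment with the examples in this guide without downloading real model weights, you can hand-assemble a tiny valid SafeTensors file using only the standard library (the file name and tensor contents below are illustrative):

```python
import json
import struct

def write_safetensors(path: str, name: str, shape: list, raw: bytes) -> None:
    # Emit the three sections of the format in order: 8-byte size prefix,
    # JSON header, contiguous binary data.
    header = {name: {"dtype": "F32", "shape": shape, "data_offsets": [0, len(raw)]}}
    header_bytes = json.dumps(header).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))
        f.write(header_bytes)
        f.write(raw)

# A 2x2 float32 tensor of zeros named "weight" (4 elements * 4 bytes)
write_safetensors("tiny.safetensors", "weight", [2, 2], bytes(16))
```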

pyproject.toml

Lines changed: 12 additions & 7 deletions
```diff
@@ -51,6 +51,10 @@ hdf = [
     "imagecodecs",
     "imagecodecs-numcodecs==2024.6.1",
 ]
+safetensors = [
+    "safetensors",
+    "ml-dtypes",
+]

 # kerchunk-based readers
 hdf5 = [
@@ -70,6 +74,7 @@ fits = [
 ]
 all_readers = [
     "virtualizarr[hdf]",
+    "virtualizarr[safetensors]",
     "virtualizarr[hdf5]",
     "virtualizarr[netcdf3]",
     "virtualizarr[fits]",
@@ -175,7 +180,7 @@ rust = "*"
 run-mypy = { cmd = "mypy virtualizarr" }
 # Using '--dist loadscope' (rather than default of '--dist load' when '-n auto'
 # is used), reduces test hangs that appear to be macOS-related.
-run-tests = { cmd = "pytest -n auto --dist loadscope --run-network-tests --verbose" }
+run-tests = { cmd = "pytest -n auto --dist loadscope --run-network-tests --verbose --ignore=codemcp" }
 run-tests-no-network = { cmd = "pytest -n auto --verbose" }
 run-tests-cov = { cmd = "pytest -n auto --run-network-tests --verbose --cov=term-missing" }
 run-tests-xml-cov = { cmd = "pytest -n auto --run-network-tests --verbose --cov-report=xml" }
@@ -185,12 +190,12 @@ run-tests-html-cov = { cmd = "pytest -n auto --run-network-tests --verbose --cov
 [tool.pixi.environments]
 min-deps = ["dev", "test", "hdf", "hdf5", "hdf5-lib"] # VirtualiZarr/conftest.py using h5py, so the minimum set of dependencies for testing still includes hdf libs
 # Inherit from min-deps to get all the test commands, along with optional dependencies
-test = ["dev", "test", "remote", "hdf", "hdf5", "netcdf3", "fits", "icechunk", "kerchunk", "hdf5-lib", "obstore"]
-test-py311 = ["dev", "test", "remote", "hdf", "hdf5", "netcdf3", "fits", "icechunk", "kerchunk", "hdf5-lib", "obstore", "py311"] # test against python 3.11
-test-py312 = ["dev", "test", "remote", "hdf", "hdf5", "netcdf3", "fits", "icechunk", "kerchunk", "hdf5-lib", "obstore", "py312"] # test against python 3.12
-minio = ["dev", "remote", "hdf", "hdf5", "netcdf3", "fits", "icechunk", "kerchunk", "hdf5-lib", "obstore", "py312", "minio"]
-upstream = ["dev", "test", "hdf", "hdf5", "hdf5-lib", "netcdf3", "upstream", "icechunk-dev"]
-all = ["dev", "test", "remote", "hdf", "hdf5", "netcdf3", "fits", "icechunk", "kerchunk", "hdf5-lib", "obstore", "all_readers", "all_writers"]
+test = ["dev", "test", "remote", "hdf", "safetensors", "hdf5", "netcdf3", "fits", "icechunk", "kerchunk", "hdf5-lib", "obstore"]
+test-py311 = ["dev", "test", "remote", "hdf", "safetensors", "hdf5", "netcdf3", "fits", "icechunk", "kerchunk", "hdf5-lib", "obstore", "py311"] # test against python 3.11
+test-py312 = ["dev", "test", "remote", "hdf", "safetensors", "hdf5", "netcdf3", "fits", "icechunk", "kerchunk", "hdf5-lib", "obstore", "py312"] # test against python 3.12
+minio = ["dev", "remote", "hdf", "safetensors", "hdf5", "netcdf3", "fits", "icechunk", "kerchunk", "hdf5-lib", "obstore", "py312", "minio"]
+upstream = ["dev", "test", "hdf", "safetensors", "hdf5", "hdf5-lib", "netcdf3", "upstream", "icechunk-dev"]
+all = ["dev", "test", "remote", "hdf", "safetensors", "hdf5", "netcdf3", "fits", "icechunk", "kerchunk", "hdf5-lib", "obstore", "all_readers", "all_writers"]
 docs = ["docs"]

 # Define commands to run within the docs environment
```

virtualizarr/backend.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -27,6 +27,7 @@
     HDFVirtualBackend,
     KerchunkVirtualBackend,
     NetCDF3VirtualBackend,
+    SafeTensorsVirtualBackend,
     TIFFVirtualBackend,
 )
 from virtualizarr.readers.api import VirtualBackend
@@ -45,6 +46,7 @@
     "kerchunk": KerchunkVirtualBackend,
     "dmrpp": DMRPPVirtualBackend,
     "hdf5": HDFVirtualBackend,
+    "safetensors": SafeTensorsVirtualBackend,
     "netcdf4": HDFVirtualBackend,  # note this is the same as for hdf5
     # all the below call one of the kerchunk backends internally (https://fsspec.github.io/kerchunk/reference.html#file-format-backends)
     "netcdf3": NetCDF3VirtualBackend,
@@ -70,6 +72,7 @@ class FileType(AutoName):
     fits = auto()
     dmrpp = auto()
     kerchunk = auto()
+    safetensors = auto()


 def automatically_determine_filetype(
@@ -89,6 +92,8 @@ def automatically_determine_filetype(
     if Path(filepath).suffix == ".zarr":
         # TODO we could imagine opening an existing zarr store, concatenating it, and writing a new virtual one...
         raise NotImplementedError()
+    elif Path(filepath).suffix.lower() == ".safetensors":
+        return FileType.safetensors

     # Read magic bytes from local or remote file
     fpath = _FsspecFSFromFilepath(
```

virtualizarr/manifests/store.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -150,6 +150,10 @@ def default_object_store(filepath: str) -> ObjectStore:
             virtual_hosted_style_request=False,
             region=_find_bucket_region(bucket),
         )
+    if parsed.scheme == "https" and parsed.netloc == "huggingface.co":
+        # TODO: a timeout can be passed via the client_options kwarg, e.g. {"timeout": "30s"}
+        # TODO: how to pass an HF token with obstore? Requires "authorization": f"Bearer {token}" in the header.
+        return obs.store.HTTPStore.from_url(url=f"{parsed.scheme}://{parsed.netloc}")

     raise NotImplementedError(f"{parsed.scheme} is not yet supported")
```

virtualizarr/readers/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -4,6 +4,7 @@
 from virtualizarr.readers.hdf5 import HDF5VirtualBackend
 from virtualizarr.readers.kerchunk import KerchunkVirtualBackend
 from virtualizarr.readers.netcdf3 import NetCDF3VirtualBackend
+from virtualizarr.readers.safetensors import SafeTensorsVirtualBackend
 from virtualizarr.readers.tiff import TIFFVirtualBackend

 __all__ = [
@@ -13,5 +14,6 @@
     "HDF5VirtualBackend",
     "KerchunkVirtualBackend",
     "NetCDF3VirtualBackend",
+    "SafeTensorsVirtualBackend",
     "TIFFVirtualBackend",
 ]
```
