|
2 | 2 | import struct |
3 | 3 | from collections.abc import Iterable, Mapping |
4 | 4 | from pathlib import Path |
5 | | -from typing import Any, Dict, Optional |
| 5 | +from typing import TYPE_CHECKING, Any, Dict, Optional |
6 | 6 | from urllib.parse import urlparse |
7 | 7 |
|
8 | 8 | import numpy as np |
9 | | -from obstore.store import ( |
10 | | - HTTPStore, |
11 | | - LocalStore, |
12 | | - ObjectStore, # type: ignore[import-not-found] |
13 | | -) |
14 | 9 | from xarray import Dataset, Index |
15 | 10 |
|
16 | 11 | from virtualizarr.manifests import ( |
|
26 | 21 | from virtualizarr.readers.api import VirtualBackend |
27 | 22 | from virtualizarr.types import ChunkKey |
28 | 23 |
|
| 24 | +if TYPE_CHECKING: |
| 25 | + from obstore.store import ( |
| 26 | + ObjectStore, # type: ignore[import-not-found] |
| 27 | + ) |
| 28 | + |
29 | 29 |
|
30 | 30 | class SafeTensorsVirtualBackend(VirtualBackend): |
31 | 31 | """ |
@@ -67,7 +67,7 @@ class SafeTensorsVirtualBackend(VirtualBackend): |
67 | 67 |
|
68 | 68 | @staticmethod |
69 | 69 | def _parse_safetensors_header( |
70 | | - filepath: str, store: ObjectStore |
| 70 | + filepath: str, store: "ObjectStore" |
71 | 71 | ) -> tuple[dict[str, Any], int]: |
72 | 72 | """ |
73 | 73 | Parse the header of a SafeTensors file to extract metadata. |
@@ -131,7 +131,7 @@ def _parse_safetensors_header( |
131 | 131 | def _create_manifest_group( |
132 | 132 | filepath: str, |
133 | 133 | drop_variables: list, |
134 | | - store: ObjectStore, |
| 134 | + store: "ObjectStore", |
135 | 135 | dimension_names: Optional[Dict[str, list[str]]] = None, |
136 | 136 | ) -> ManifestGroup: |
137 | 137 | """ |
@@ -207,8 +207,11 @@ def _create_manifest_group( |
207 | 207 |
|
208 | 208 | data_start = 8 + header_size |
209 | 209 |
|
| 210 | + def should_skip_tensor(tensor_name: str, drop_variables: list) -> bool: |
| 211 | + return tensor_name == "__metadata__" or tensor_name in drop_variables |
| 212 | + |
210 | 213 | for tensor_name, tensor_info in header.items(): |
211 | | - if tensor_name == "__metadata__" or tensor_name in drop_variables: |
| 214 | + if should_skip_tensor(tensor_name, drop_variables): |
212 | 215 | continue |
213 | 216 |
|
214 | 217 | dtype_str = tensor_info["dtype"] |
@@ -328,6 +331,11 @@ def _create_manifest_store( |
328 | 331 | ... revision="v2.0" |
329 | 332 | ... ) |
330 | 333 | """ |
| 334 | + from obstore.store import ( |
| 335 | + HTTPStore, |
| 336 | + LocalStore, |
| 337 | + ) |
| 338 | + |
331 | 339 | store_registry = ObjectStoreRegistry() |
332 | 340 | store = default_object_store(filepath) |
333 | 341 |
|
@@ -518,6 +526,20 @@ def _create_chunk_manifest( |
518 | 526 | a chunk manifest that points to the exact location of a tensor within the file, |
519 | 527 | treating the entire tensor as a single chunk for efficient memory mapping. |
520 | 528 |
|
| 529 | + The structure of the variable names within a SafeTensors file often reflects a |
| 530 | + hierarchical organization, commonly represented using a dot separator (e.g., |
| 531 | + 'a.b.c'). While this structure could naturally map to a nested format like Zarr |
| 532 | + groups (e.g., a/b/c), the dominant framework for using these models, PyTorch, |
| 533 | + utilizes a flattened dictionary structure (a 'state dict') where these dot-separated |
| 534 | + names serve as keys. |
| 535 | +
|
| 536 | + To ease integration with PyTorch's expected format, ChunkManifests are currently |
| 537 | + organized as a flattened dictionary whose keys are the dot-separated variable names. |
| 538 | +
|
| 539 | + Further consideration could be given to optionally returning the data as an |
| 540 | + xarray.DataTree to better represent the inherent hierarchical structure, but |
| 541 | + this has been deferred to prioritize compatibility with PyTorch workflows. |
| 542 | +
|
521 | 543 | Parameters |
522 | 544 | ---------- |
523 | 545 | filepath : str |
|
0 commit comments