Skip to content

Commit

Permalink
Cope with zarr3 Buffers in referenceFS (#1784)
Browse files Browse the repository at this point in the history
  • Loading branch information
martindurant authored Jan 29, 2025
1 parent fe59f48 commit 08d1e49
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions fsspec/implementations/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,10 +394,14 @@ def __setitem__(self, key, value):
self.write(field, record)
else:
# metadata or top-level
self._items[key] = value
new_value = json.loads(
value.decode() if isinstance(value, bytes) else value
)
if hasattr(value, "to_bytes"):
val = value.to_bytes().decode()
elif isinstance(value, bytes):
val = value.decode()
else:
val = value
self._items[key] = val
new_value = json.loads(val)
self.zmetadata[key] = {**self.zmetadata.get(key, {}), **new_value}

@staticmethod
Expand Down Expand Up @@ -606,6 +610,7 @@ class ReferenceFileSystem(AsyncFileSystem):
"""

protocol = "reference"
cachable = False

def __init__(
self,
Expand Down Expand Up @@ -762,6 +767,11 @@ def __init__(
for k, f in self.fss.items():
if not f.async_impl:
self.fss[k] = AsyncFileSystemWrapper(f)
elif self.asynchronous ^ f.asynchronous:
raise ValueError(
"Reference-FS's target filesystem must have same value"
"of asynchronous"
)

def _cat_common(self, path, start=None, end=None):
path = self._strip_protocol(path)
Expand All @@ -772,6 +782,8 @@ def _cat_common(self, path, start=None, end=None):
raise FileNotFoundError(path) from exc
if isinstance(part, str):
part = part.encode()
if hasattr(part, "to_bytes"):
part = part.to_bytes()
if isinstance(part, bytes):
logger.debug(f"Reference: {path}, type bytes")
if part.startswith(b"base64:"):
Expand Down Expand Up @@ -1073,7 +1085,7 @@ def _dircache_from_items(self):
self.dircache = {"": []}
it = self.references.items()
for path, part in it:
if isinstance(part, (bytes, str)):
if isinstance(part, (bytes, str)) or hasattr(part, "to_bytes"):
size = len(part)
elif len(part) == 1:
size = None
Expand Down Expand Up @@ -1104,6 +1116,7 @@ def _open(self, path, mode="rb", block_size=None, cache_options=None, **kwargs):
return io.BytesIO(data)

def ls(self, path, detail=True, **kwargs):
logger.debug("list %s", path)
path = self._strip_protocol(path)
if isinstance(self.references, LazyReferenceMapper):
try:
Expand Down

0 comments on commit 08d1e49

Please sign in to comment.