Commit 176efbe

Adds engine option to LazyReferenceMapper (#1692)
1 parent f2c7717 commit 176efbe

2 files changed: +45 −11 lines


Diff for: fsspec/implementations/reference.py

+34 −8
@@ -7,7 +7,7 @@
 import os
 from itertools import chain
 from functools import lru_cache
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal
 
 import fsspec.core
 
@@ -104,7 +104,13 @@ def pd(self):
         return pd
 
     def __init__(
-        self, root, fs=None, out_root=None, cache_size=128, categorical_threshold=10
+        self,
+        root,
+        fs=None,
+        out_root=None,
+        cache_size=128,
+        categorical_threshold=10,
+        engine: Literal["fastparquet", "pyarrow"] = "fastparquet",
     ):
         """
 
"""
110116
@@ -126,16 +132,25 @@ def __init__(
             Encode urls as pandas.Categorical to reduce memory footprint if the ratio
             of the number of unique urls to total number of refs for each variable
             is greater than or equal to this number. (default 10)
+        engine: Literal["fastparquet","pyarrow"]
+            Engine choice for reading parquet files. (default is "fastparquet")
         """
+
         self.root = root
         self.chunk_sizes = {}
         self.out_root = out_root or self.root
         self.cat_thresh = categorical_threshold
+        self.engine = engine
         self.cache_size = cache_size
         self.url = self.root + "/{field}/refs.{record}.parq"
         # TODO: derive fs from `root`
         self.fs = fsspec.filesystem("file") if fs is None else fs
 
+        from importlib.util import find_spec
+
+        if self.engine == "pyarrow" and find_spec("pyarrow") is None:
+            raise ImportError("engine choice `pyarrow` is not installed.")
+
     def __getattr__(self, item):
         if item in ("_items", "record_size", "zmetadata"):
             self.setup()
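
The guard added at the end of `__init__` relies on `importlib.util.find_spec`, which returns None when a package cannot be located, so the mapper can fail early with a clear message instead of a late ImportError deep inside pandas. A minimal sketch of the same pattern in isolation (the function name is illustrative, not part of the commit):

from importlib.util import find_spec


def check_engine(engine: str) -> None:
    # find_spec probes for the package without importing it;
    # None means pandas would fail later when loading the engine.
    if engine == "pyarrow" and find_spec("pyarrow") is None:
        raise ImportError("engine choice `pyarrow` is not installed.")


check_engine("pyarrow")  # raises unless pyarrow is installed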
@@ -158,7 +173,7 @@ def open_refs(field, record):
         """cached parquet file loader"""
         path = self.url.format(field=field, record=record)
         data = io.BytesIO(self.fs.cat_file(path))
-        df = self.pd.read_parquet(data, engine="fastparquet")
+        df = self.pd.read_parquet(data, engine=self.engine)
         refs = {c: df[c].to_numpy() for c in df.columns}
         return refs
 
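
Since `pandas.read_parquet` accepts both engine names, the read path only needs to forward `self.engine`. A rough standalone equivalent of `open_refs`, assuming a refs partition exists at the hypothetical local path below and that the chosen engine is installed:

import io

import pandas as pd

# Load one refs partition into plain numpy arrays, as open_refs does;
# the engine string selects fastparquet or pyarrow inside pandas.
with open("out.parq/data/refs.0.parq", "rb") as f:  # hypothetical path
    data = io.BytesIO(f.read())
df = pd.read_parquet(data, engine="pyarrow")
refs = {c: df[c].to_numpy() for c in df.columns}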

@@ -463,18 +478,28 @@ def write(self, field, record, base_url=None, storage_options=None):
 
         fn = f"{base_url or self.out_root}/{field}/refs.{record}.parq"
         self.fs.mkdirs(f"{base_url or self.out_root}/{field}", exist_ok=True)
+
+        if self.engine == "pyarrow":
+            df_backend_kwargs = {"write_statistics": False}
+        elif self.engine == "fastparquet":
+            df_backend_kwargs = {
+                "stats": False,
+                "object_encoding": object_encoding,
+                "has_nulls": has_nulls,
+            }
+        else:
+            raise NotImplementedError(f"{self.engine} not supported")
+
         df.to_parquet(
             fn,
-            engine="fastparquet",
+            engine=self.engine,
             storage_options=storage_options
             or getattr(self.fs, "storage_options", None),
             compression="zstd",
             index=False,
-            stats=False,
-            object_encoding=object_encoding,
-            has_nulls=has_nulls,
-            # **kwargs,
+            **df_backend_kwargs,
         )
+
         partition.clear()
         self._items.pop((field, record))
 
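
The per-engine dispatch is needed because the two backends spell their writer options differently: fastparquet takes `stats`, `object_encoding` and `has_nulls`, while pyarrow disables statistics via `write_statistics`. pandas forwards unrecognized keyword arguments to whichever engine it drives, as this small sketch suggests (file names are illustrative; assumes both engines are installed):

import pandas as pd

df = pd.DataFrame({"path": ["a.bin", "b.bin"], "offset": [0, 10], "size": [10, 5]})

# The same "no column statistics" intent, spelled per backend:
df.to_parquet("refs.pa.parq", engine="pyarrow", write_statistics=False)
df.to_parquet("refs.fp.parq", engine="fastparquet", stats=False)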

@@ -486,6 +511,7 @@ def flush(self, base_url=None, storage_options=None):
         base_url: str
             Location of the output
         """
+
         # write what we have so far and clear sub chunks
         for thing in list(self._items):
             if isinstance(thing, tuple):

Diff for: fsspec/implementations/tests/test_reference.py

+11 −3
@@ -761,25 +761,33 @@ def test_append_parquet(lazy_refs, m):
     assert lazy2["data/1"] == b"Adata"
 
 
-def test_deep_parq(m):
+@pytest.mark.parametrize("engine", ["fastparquet", "pyarrow"])
+def test_deep_parq(m, engine):
     pytest.importorskip("kerchunk")
     zarr = pytest.importorskip("zarr")
+
     lz = fsspec.implementations.reference.LazyReferenceMapper.create(
-        "memory://out.parq", fs=m
+        "memory://out.parq",
+        fs=m,
+        engine=engine,
     )
     g = zarr.open_group(lz, mode="w")
+
     g2 = g.create_group("instant")
     g2.create_dataset(name="one", data=[1, 2, 3])
     lz.flush()
 
-    lz = fsspec.implementations.reference.LazyReferenceMapper("memory://out.parq", fs=m)
+    lz = fsspec.implementations.reference.LazyReferenceMapper(
+        "memory://out.parq", fs=m, engine=engine
+    )
     g = zarr.open_group(lz)
     assert g.instant.one[:].tolist() == [1, 2, 3]
     assert sorted(_["name"] for _ in lz.ls("")) == [".zgroup", ".zmetadata", "instant"]
     assert sorted(_["name"] for _ in lz.ls("instant")) == [
         "instant/.zgroup",
         "instant/one",
     ]
+
     assert sorted(_["name"] for _ in lz.ls("instant/one")) == [
         "instant/one/.zarray",
         "instant/one/0",
