import os
from itertools import chain
from functools import lru_cache
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal

import fsspec.core
@@ -104,7 +104,13 @@ def pd(self):
        return pd

    def __init__(
-        self, root, fs=None, out_root=None, cache_size=128, categorical_threshold=10
+        self,
+        root,
+        fs=None,
+        out_root=None,
+        cache_size=128,
+        categorical_threshold=10,
+        engine: Literal["fastparquet", "pyarrow"] = "fastparquet",
    ):
        """
@@ -126,16 +132,25 @@ def __init__(
            Encode urls as pandas.Categorical to reduce memory footprint if the ratio
            of the number of unique urls to total number of refs for each variable
            is greater than or equal to this number. (default 10)
+        engine: Literal["fastparquet", "pyarrow"]
+            Engine choice for reading parquet files. (default is "fastparquet")
        """
+
        self.root = root
        self.chunk_sizes = {}
        self.out_root = out_root or self.root
        self.cat_thresh = categorical_threshold
+        self.engine = engine
        self.cache_size = cache_size
        self.url = self.root + "/{field}/refs.{record}.parq"
        # TODO: derive fs from `root`
        self.fs = fsspec.filesystem("file") if fs is None else fs

+        from importlib.util import find_spec
+
+        if self.engine == "pyarrow" and find_spec("pyarrow") is None:
+            raise ImportError("engine choice `pyarrow` is not installed.")
+
    def __getattr__(self, item):
        if item in ("_items", "record_size", "zmetadata"):
            self.setup()
@@ -158,7 +173,7 @@ def open_refs(field, record):
            """cached parquet file loader"""
            path = self.url.format(field=field, record=record)
            data = io.BytesIO(self.fs.cat_file(path))
-            df = self.pd.read_parquet(data, engine="fastparquet")
+            df = self.pd.read_parquet(data, engine=self.engine)
            refs = {c: df[c].to_numpy() for c in df.columns}
            return refs

@@ -463,18 +478,28 @@ def write(self, field, record, base_url=None, storage_options=None):

        fn = f"{base_url or self.out_root}/{field}/refs.{record}.parq"
        self.fs.mkdirs(f"{base_url or self.out_root}/{field}", exist_ok=True)
+
+        if self.engine == "pyarrow":
+            df_backend_kwargs = {"write_statistics": False}
+        elif self.engine == "fastparquet":
+            df_backend_kwargs = {
+                "stats": False,
+                "object_encoding": object_encoding,
+                "has_nulls": has_nulls,
+            }
+        else:
+            raise NotImplementedError(f"{self.engine} not supported")
+
        df.to_parquet(
            fn,
-            engine="fastparquet",
+            engine=self.engine,
            storage_options=storage_options
            or getattr(self.fs, "storage_options", None),
            compression="zstd",
            index=False,
-            stats=False,
-            object_encoding=object_encoding,
-            has_nulls=has_nulls,
-            # **kwargs,
+            **df_backend_kwargs,
        )
+
        partition.clear()
        self._items.pop((field, record))

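Note: pandas.DataFrame.to_parquet forwards keyword arguments it does not recognise to the selected engine, which is why the statistics toggle above is spelled differently per backend (pyarrow's write_table(write_statistics=...) versus fastparquet's write(stats=...)). A minimal sketch of the two calls in isolation; the DataFrame contents and file names are made up for illustration:

import pandas as pd

df = pd.DataFrame({"path": ["a.bin", "b.bin"], "offset": [0, 100], "size": [100, 50]})

# pyarrow spells the statistics toggle write_statistics
df.to_parquet("refs_pa.parq", engine="pyarrow", index=False,
              compression="zstd", write_statistics=False)

# fastparquet spells it stats (and also accepts object_encoding / has_nulls)
df.to_parquet("refs_fp.parq", engine="fastparquet", index=False,
              compression="zstd", stats=False)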
@@ -486,6 +511,7 @@ def flush(self, base_url=None, storage_options=None):
        base_url: str
            Location of the output
        """
+
        # write what we have so far and clear sub chunks
        for thing in list(self._items):
            if isinstance(thing, tuple):
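For orientation, a minimal usage sketch of the new engine option. It assumes the class being modified here is fsspec's LazyReferenceMapper (the diff does not name it) and that a parquet reference store already exists under root; both names are illustrative, not taken from the change itself:

import fsspec
from fsspec.implementations.reference import LazyReferenceMapper

fs = fsspec.filesystem("file")
root = "/tmp/refs"  # hypothetical path to an existing parquet reference store

# default: fastparquet is used for both reading and writing refs
refs_default = LazyReferenceMapper(root, fs=fs)

# new in this change: opt into pyarrow; an ImportError is raised at
# construction time if pyarrow is not installed
refs_pyarrow = LazyReferenceMapper(root, fs=fs, engine="pyarrow")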