77
88import numpy as np
99import pyarrow as pa
10+ import pyarrow .compute as pc
1011import pyarrow .dataset as ds
1112import pyarrow .parquet as pq
1213
@@ -46,6 +47,18 @@ def _decode_schema(encoded_schema: str) -> pa.Schema:
4647 schema_bytes = base64 .b64decode (encoded_schema )
4748 return pa .ipc .read_schema (pa .BufferReader (schema_bytes ))
4849
def count_rows(self) -> int:
    """
    Return the total number of rows stored in the cache.

    The cache directory is opened as a parquet dataset and the count is
    delegated to that dataset.
    """
    return ds.dataset(source=self._path, format="parquet").count_rows()
57+
def count_tables(self) -> int:
    """Return how many dataset files the cache currently holds."""
    dataset_files = self.get_dataset_files()
    return len(dataset_files)
61+
4962 def get_files (self ) -> list [Path ]:
5063 """
5164 Retrieve all files.
@@ -80,6 +93,106 @@ def get_dataset_files(self) -> list[Path]:
8093 ]
8194
8295
class FileCacheReader(FileCache):
    """
    Reader for a parquet-backed cache directory.

    Keeps the schema and the settings read from the cache configuration
    file (batch size, rows per file, compression) and exposes iterators
    over the parquet data stored under the cache path.
    """

    def __init__(
        self,
        path: str | Path,
        schema: pa.Schema,
        batch_size: int,
        rows_per_file: int,
        compression: str,
    ):
        """
        Create a reader over an existing cache.

        Parameters
        ----------
        path : str | Path
            Where the cache is stored.
        schema : pa.Schema
            Schema applied when opening the parquet dataset.
        batch_size : int
            Batch size recorded in the cache configuration.
        rows_per_file : int
            Rows-per-file value recorded in the cache configuration.
        compression : str
            Compression method recorded in the cache configuration.
        """
        super().__init__(path)
        self._schema = schema
        self._batch_size = batch_size
        self._rows_per_file = rows_per_file
        self._compression = compression

    @property
    def schema(self) -> pa.Schema:
        """Schema used when opening the dataset."""
        return self._schema

    @property
    def batch_size(self) -> int:
        """Batch size recorded in the cache configuration."""
        return self._batch_size

    @property
    def rows_per_file(self) -> int:
        """Rows-per-file value recorded in the cache configuration."""
        return self._rows_per_file

    @property
    def compression(self) -> str:
        """Compression method recorded in the cache configuration."""
        return self._compression

    @classmethod
    def load(cls, path: str | Path | FileCache):
        """
        Load cache from disk.

        Parameters
        ----------
        path : str | Path | FileCache
            Where the cache is stored, or an existing cache object whose
            path is reused.

        Returns
        -------
        FileCacheReader
            Reader configured from the cache's configuration file.

        Raises
        ------
        FileNotFoundError
            If the cache directory does not exist.
        NotADirectoryError
            If the path exists but is not a directory.
        KeyError
            If a required configuration key is absent or holds a falsy
            value.
        """
        if isinstance(path, FileCache):
            path = path.path
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Directory does not exist: {path}")
        elif not path.is_dir():
            raise NotADirectoryError(
                f"Path exists but is not a directory: {path}"
            )

        cfg_path = cls._generate_config_path(path)

        def _retrieve(config: dict, key: str):
            # NOTE: truthiness check - a key that is present but holds a
            # falsy value (0, "", None) is reported as undefined as well.
            if value := config.get(key, None):
                return value
            raise KeyError(f"'{key}' is not defined within {cfg_path}")

        # read configuration file
        with open(cfg_path, "r") as f:
            cfg = json.load(f)
        batch_size = _retrieve(cfg, "batch_size")
        rows_per_file = _retrieve(cfg, "rows_per_file")
        compression = _retrieve(cfg, "compression")
        schema = cls._decode_schema(_retrieve(cfg, "schema"))

        return cls(
            schema=schema,
            path=path,
            batch_size=batch_size,
            rows_per_file=rows_per_file,
            compression=compression,
        )

    def _open_dataset(self):
        """Open the cache directory as a parquet dataset using the stored schema."""
        return ds.dataset(
            source=self._path,
            schema=self._schema,
            format="parquet",
        )

    def iterate_tables(
        self,
        columns: list[str] | None = None,
        filter: pc.Expression | None = None,
    ):
        """
        Iterate over tables within the cache.

        Parameters
        ----------
        columns : list[str], optional
            Columns to load from each fragment. All columns when None.
        filter : pc.Expression, optional
            Expression used to select fragments and to drop non-matching
            rows from each table that is read.

        Yields
        ------
        pa.Table
            One table per dataset fragment.
        """
        dataset = self._open_dataset()
        for fragment in dataset.get_fragments(filter=filter):
            # get_fragments() only selects whole fragments; the filter has
            # to be handed to to_table() as well to filter individual rows.
            yield fragment.to_table(columns=columns, filter=filter)

    def iterate_fragments(self):
        """
        Iterate over fragments within the file-based cache.

        Yields
        ------
        Fragment
            Each fragment of the parquet dataset, as produced by pyarrow.
        """
        yield from self._open_dataset().get_fragments()
195+
83196class FileCacheWriter (FileCache ):
84197 def __init__ (
85198 self ,
@@ -89,7 +202,7 @@ def __init__(
89202 rows_per_file : int ,
90203 compression : str ,
91204 ):
92- self . _path = Path (path )
205+ super (). __init__ (path )
93206 self ._schema = schema
94207 self ._batch_size = batch_size
95208 self ._rows_per_file = rows_per_file
@@ -108,6 +221,7 @@ def create(
108221 batch_size : int ,
109222 rows_per_file : int ,
110223 compression : str = "snappy" ,
224+ delete_if_exists : bool = False ,
111225 ):
112226 """
113227 Create a cache on disk.
@@ -124,7 +238,12 @@ def create(
124238 Target number of rows to store per file.
125239 compression : str, default="snappy"
126240 Compression method to use when storing on disk.
241+ delete_if_exists : bool, default=False
242+ Delete the cache if it already exists.
127243 """
244+ path = Path (path )
245+ if delete_if_exists and path .exists ():
246+ cls .delete (path )
128247 Path (path ).mkdir (parents = True , exist_ok = False )
129248
130249 # write configuration file
@@ -146,29 +265,33 @@ def create(
146265 compression = compression ,
147266 )
148267
@classmethod
def delete(cls, path: str | Path):
    """
    Delete a cache at path.

    Removes the cache's parquet dataset files, then its configuration
    file, and finally the cache directory itself. Returns without doing
    anything when the path does not exist.

    Parameters
    ----------
    path : str | Path
        Where the cache is stored.
    """
    cache_dir = Path(path)
    if not cache_dir.exists():
        return

    # the list of dataset files is obtained through a reader loaded from disk
    for candidate in FileCacheReader.load(cache_dir).get_dataset_files():
        if not (candidate.exists() and candidate.is_file()):
            continue
        if candidate.suffix == ".parquet":
            candidate.unlink()

    # the configuration file is removed after the dataset files
    config_file = cls._generate_config_path(cache_dir)
    if config_file.exists() and config_file.is_file():
        config_file.unlink()

    # rmdir() only removes a directory that is already empty
    cache_dir.rmdir()
172295
173296 def write_rows (
174297 self ,
@@ -297,69 +420,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
297420 """Context manager exit - ensures data is flushed."""
298421 self .flush ()
299422
300-
301- class FileCacheReader (FileCache ):
302- def __init__ (
303- self ,
304- path : str | Path ,
305- schema : pa .Schema ,
306- ):
307- self ._schema = schema
308- self ._path = Path (path )
309-
310- @classmethod
311- def load (cls , path : str | Path | FileCache ):
312- """
313- Load cache from disk.
314-
315- Parameters
316- ----------
317- path : str | Path
318- Where the cache is stored.
319- """
320- if isinstance (path , FileCache ):
321- path = path .path
322- path = Path (path )
323- if not path .exists ():
324- raise FileNotFoundError (f"Directory does not exist: { path } " )
325- elif not path .is_dir ():
326- raise NotADirectoryError (
327- f"Path exists but is not a directory: { path } "
328- )
329-
330- def _retrieve (config : dict , key : str ):
331- if value := config .get (key , None ):
332- return value
333- raise KeyError (
334- f"'{ key } ' is not defined within { cls ._generate_config_path (path )} "
335- )
336-
337- # read configuration file
338- cfg_path = cls ._generate_config_path (path )
339- with open (cfg_path , "r" ) as f :
340- cfg = json .load (f )
341- schema = cls ._decode_schema (_retrieve (cfg , "schema" ))
342-
343- return cls (
344- schema = schema ,
345- path = path ,
346- )
347-
348- def count_rows (self ) -> int :
349- """Count the number of rows in the cache."""
350- dataset = ds .dataset (
351- source = self ._path ,
352- schema = self ._schema ,
353- format = "parquet" ,
354- )
355- return dataset .count_rows ()
356-
357- def iterate_tables (self ):
358- """Iterate over tables within the cache."""
359- dataset = ds .dataset (
360- source = self ._path ,
361- schema = self ._schema ,
362- format = "parquet" ,
363- )
364- for fragment in dataset .get_fragments ():
365- yield fragment .to_table ()
def to_reader(self) -> FileCacheReader:
    """Return a reader loaded from the cache stored at this object's path."""
    cache_path = self.path
    return FileCacheReader.load(path=cache_path)
0 commit comments