Skip to content

Commit 318d923

Browse files
committed
updated
1 parent cb0e957 commit 318d923

4 files changed

Lines changed: 190 additions & 173 deletions

File tree

src/valor_lite/common/ephemeral.py

Lines changed: 43 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,61 @@
22

33
import numpy as np
44
import pyarrow as pa
5+
import pyarrow.compute as pc
56

67

78
class MemoryCache:
89
def __init__(self, table: pa.Table):
910
self._table = table
1011

12+
def count_tables(self) -> int:
13+
"""Count the number of tables in the cache."""
14+
return 1
15+
1116
def count_rows(self) -> int:
1217
"""Count the number of rows in the cache."""
1318
return self._table.num_rows
1419

1520

21+
class MemoryCacheReader(MemoryCache):
22+
def __init__(
23+
self,
24+
table: pa.Table,
25+
batch_size: int,
26+
):
27+
super().__init__(table)
28+
self._schema = self._table.schema
29+
self._batch_size = batch_size
30+
31+
@property
32+
def schema(self) -> pa.Schema:
33+
return self._schema
34+
35+
@property
36+
def batch_size(self) -> int:
37+
return self._batch_size
38+
39+
def iterate_tables(
40+
self,
41+
columns: list[str] | None = None,
42+
filter: pc.Expression | None = None,
43+
):
44+
"""Iterate over tables within the cache."""
45+
table = self._table
46+
if filter is not None:
47+
table = table.filter(filter)
48+
if columns is not None:
49+
table = table.select(columns)
50+
yield table
51+
52+
1653
class MemoryCacheWriter(MemoryCache):
1754
def __init__(
1855
self,
1956
table: pa.Table,
2057
batch_size: int,
2158
):
22-
self._table = table
59+
super().__init__(table)
2360
self._schema = table.schema
2461
self._batch_size = batch_size
2562

@@ -47,13 +84,6 @@ def create(
4784
batch_size=batch_size,
4885
)
4986

50-
def delete(self):
51-
"""
52-
Delete any existing cache data.
53-
"""
54-
self._buffer = []
55-
self._table = self._table.schema.empty_table()
56-
5787
def write_rows(
5888
self,
5989
rows: list[dict[str, Any]],
@@ -146,31 +176,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
146176
"""Context manager exit - ensures data is flushed."""
147177
self.flush()
148178

149-
150-
class MemoryCacheReader:
151-
def __init__(
152-
self,
153-
cache: MemoryCacheWriter,
154-
):
155-
self._cache = cache
156-
self._schema = self._cache._schema
157-
158-
@classmethod
159-
def load(cls, cache: MemoryCacheWriter):
160-
"""
161-
Load cache from table.
162-
163-
Parameters
164-
----------
165-
cache : MemoryCacheWriter
166-
A cache writer containing the ephemeral cache.
167-
"""
168-
return cls(cache=cache)
169-
170-
def iterate_tables(self):
171-
"""Iterate over tables within the cache."""
172-
yield self._cache._table
173-
174-
def count_rows(self) -> int:
175-
"""Count the number of rows in the cache."""
176-
return self._cache.count_rows()
179+
def to_reader(self) -> MemoryCacheReader:
180+
"""Get cache reader."""
181+
return MemoryCacheReader(
182+
table=self._table, batch_size=self._batch_size
183+
)

src/valor_lite/common/persistent.py

Lines changed: 138 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import numpy as np
99
import pyarrow as pa
10+
import pyarrow.compute as pc
1011
import pyarrow.dataset as ds
1112
import pyarrow.parquet as pq
1213

@@ -46,6 +47,18 @@ def _decode_schema(encoded_schema: str) -> pa.Schema:
4647
schema_bytes = base64.b64decode(encoded_schema)
4748
return pa.ipc.read_schema(pa.BufferReader(schema_bytes))
4849

50+
def count_rows(self) -> int:
51+
"""Count the number of rows in the cache."""
52+
dataset = ds.dataset(
53+
source=self._path,
54+
format="parquet",
55+
)
56+
return dataset.count_rows()
57+
58+
def count_tables(self) -> int:
59+
"""Count the number of files in the cache."""
60+
return len(self.get_dataset_files())
61+
4962
def get_files(self) -> list[Path]:
5063
"""
5164
Retrieve all files.
@@ -80,6 +93,106 @@ def get_dataset_files(self) -> list[Path]:
8093
]
8194

8295

96+
class FileCacheReader(FileCache):
97+
def __init__(
98+
self,
99+
path: str | Path,
100+
schema: pa.Schema,
101+
batch_size: int,
102+
rows_per_file: int,
103+
compression: str,
104+
):
105+
super().__init__(path)
106+
self._schema = schema
107+
self._batch_size = batch_size
108+
self._rows_per_file = rows_per_file
109+
self._compression = compression
110+
111+
@property
112+
def schema(self) -> pa.Schema:
113+
return self._schema
114+
115+
@property
116+
def batch_size(self) -> int:
117+
return self._batch_size
118+
119+
@property
120+
def rows_per_file(self) -> int:
121+
return self._rows_per_file
122+
123+
@property
124+
def compression(self) -> str:
125+
return self._compression
126+
127+
@classmethod
128+
def load(cls, path: str | Path | FileCache):
129+
"""
130+
Load cache from disk.
131+
132+
Parameters
133+
----------
134+
path : str | Path
135+
Where the cache is stored.
136+
"""
137+
if isinstance(path, FileCache):
138+
path = path.path
139+
path = Path(path)
140+
if not path.exists():
141+
raise FileNotFoundError(f"Directory does not exist: {path}")
142+
elif not path.is_dir():
143+
raise NotADirectoryError(
144+
f"Path exists but is not a directory: {path}"
145+
)
146+
147+
def _retrieve(config: dict, key: str):
148+
if value := config.get(key, None):
149+
return value
150+
raise KeyError(
151+
f"'{key}' is not defined within {cls._generate_config_path(path)}"
152+
)
153+
154+
# read configuration file
155+
cfg_path = cls._generate_config_path(path)
156+
with open(cfg_path, "r") as f:
157+
cfg = json.load(f)
158+
batch_size = _retrieve(cfg, "batch_size")
159+
rows_per_file = _retrieve(cfg, "rows_per_file")
160+
compression = _retrieve(cfg, "compression")
161+
schema = cls._decode_schema(_retrieve(cfg, "schema"))
162+
163+
return cls(
164+
schema=schema,
165+
path=path,
166+
batch_size=batch_size,
167+
rows_per_file=rows_per_file,
168+
compression=compression,
169+
)
170+
171+
def iterate_tables(
172+
self,
173+
columns: list[str] | None = None,
174+
filter: pc.Expression | None = None,
175+
):
176+
"""Iterate over tables within the cache."""
177+
dataset = ds.dataset(
178+
source=self._path,
179+
schema=self._schema,
180+
format="parquet",
181+
)
182+
for fragment in dataset.get_fragments(filter=filter):
183+
yield fragment.to_table(columns=columns)
184+
185+
def iterate_fragments(self):
186+
"""Iterate over fragments within the file-based cache."""
187+
dataset = ds.dataset(
188+
source=self._path,
189+
schema=self._schema,
190+
format="parquet",
191+
)
192+
for fragment in dataset.get_fragments():
193+
yield fragment
194+
195+
83196
class FileCacheWriter(FileCache):
84197
def __init__(
85198
self,
@@ -89,7 +202,7 @@ def __init__(
89202
rows_per_file: int,
90203
compression: str,
91204
):
92-
self._path = Path(path)
205+
super().__init__(path)
93206
self._schema = schema
94207
self._batch_size = batch_size
95208
self._rows_per_file = rows_per_file
@@ -108,6 +221,7 @@ def create(
108221
batch_size: int,
109222
rows_per_file: int,
110223
compression: str = "snappy",
224+
delete_if_exists: bool = False,
111225
):
112226
"""
113227
Create a cache on disk.
@@ -124,7 +238,12 @@ def create(
124238
Target number of rows to store per file.
125239
compression : str, default="snappy"
126240
Compression method to use when storing on disk.
241+
delete_if_exists : bool, default=False
242+
Delete the cache if it already exists.
127243
"""
244+
path = Path(path)
245+
if delete_if_exists and path.exists():
246+
cls.delete(path)
128247
Path(path).mkdir(parents=True, exist_ok=False)
129248

130249
# write configuration file
@@ -146,29 +265,33 @@ def create(
146265
compression=compression,
147266
)
148267

149-
def delete(self):
268+
@classmethod
269+
def delete(cls, path: str | Path):
150270
"""
151-
Delete the cache.
271+
Delete a cache at path.
152272
153273
Parameters
154274
----------
155275
path : str | Path
156276
Where the cache is stored.
157277
"""
158-
if not self._path.exists():
278+
path = Path(path)
279+
if not path.exists():
159280
return
160-
# clear buffer
161-
self.flush()
162-
# delete config file
163-
cfg_path = self._generate_config_path(self._path)
164-
if cfg_path.exists() and cfg_path.is_file():
165-
cfg_path.unlink()
281+
166282
# delete dataset files
167-
for file in self.get_dataset_files():
283+
reader = FileCacheReader.load(path)
284+
for file in reader.get_dataset_files():
168285
if file.exists() and file.is_file() and file.suffix == ".parquet":
169286
file.unlink()
287+
288+
# delete config file
289+
cfg_path = cls._generate_config_path(path)
290+
if cfg_path.exists() and cfg_path.is_file():
291+
cfg_path.unlink()
292+
170293
# delete empty cache directory
171-
self._path.rmdir()
294+
path.rmdir()
172295

173296
def write_rows(
174297
self,
@@ -297,69 +420,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
297420
"""Context manager exit - ensures data is flushed."""
298421
self.flush()
299422

300-
301-
class FileCacheReader(FileCache):
302-
def __init__(
303-
self,
304-
path: str | Path,
305-
schema: pa.Schema,
306-
):
307-
self._schema = schema
308-
self._path = Path(path)
309-
310-
@classmethod
311-
def load(cls, path: str | Path | FileCache):
312-
"""
313-
Load cache from disk.
314-
315-
Parameters
316-
----------
317-
path : str | Path
318-
Where the cache is stored.
319-
"""
320-
if isinstance(path, FileCache):
321-
path = path.path
322-
path = Path(path)
323-
if not path.exists():
324-
raise FileNotFoundError(f"Directory does not exist: {path}")
325-
elif not path.is_dir():
326-
raise NotADirectoryError(
327-
f"Path exists but is not a directory: {path}"
328-
)
329-
330-
def _retrieve(config: dict, key: str):
331-
if value := config.get(key, None):
332-
return value
333-
raise KeyError(
334-
f"'{key}' is not defined within {cls._generate_config_path(path)}"
335-
)
336-
337-
# read configuration file
338-
cfg_path = cls._generate_config_path(path)
339-
with open(cfg_path, "r") as f:
340-
cfg = json.load(f)
341-
schema = cls._decode_schema(_retrieve(cfg, "schema"))
342-
343-
return cls(
344-
schema=schema,
345-
path=path,
346-
)
347-
348-
def count_rows(self) -> int:
349-
"""Count the number of rows in the cache."""
350-
dataset = ds.dataset(
351-
source=self._path,
352-
schema=self._schema,
353-
format="parquet",
354-
)
355-
return dataset.count_rows()
356-
357-
def iterate_tables(self):
358-
"""Iterate over tables within the cache."""
359-
dataset = ds.dataset(
360-
source=self._path,
361-
schema=self._schema,
362-
format="parquet",
363-
)
364-
for fragment in dataset.get_fragments():
365-
yield fragment.to_table()
423+
def to_reader(self) -> FileCacheReader:
424+
"""Get cache reader."""
425+
return FileCacheReader.load(path=self.path)

0 commit comments

Comments
 (0)