Skip to content

Commit dd9c479

Browse files
committed
broke out caching
1 parent 75d3c19 commit dd9c479

4 files changed

Lines changed: 760 additions & 0 deletions

File tree

src/valor_lite/cache.py

Lines changed: 380 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,380 @@
1+
import glob
2+
import json
3+
import os
4+
from datetime import datetime
5+
from enum import StrEnum
6+
from pathlib import Path
7+
from typing import Any
8+
9+
import numpy as np
10+
import pyarrow as pa
11+
import pyarrow.dataset as ds
12+
import pyarrow.lib as pl
13+
import pyarrow.parquet as pq
14+
15+
16+
class DataType(StrEnum):
17+
INTEGER = "int"
18+
FLOAT = "float"
19+
STRING = "string"
20+
TIMESTAMP = "timestamp"
21+
22+
def to_py(self):
23+
match self:
24+
case DataType.INTEGER:
25+
return int
26+
case DataType.FLOAT:
27+
return float
28+
case DataType.STRING:
29+
return str
30+
case DataType.TIMESTAMP:
31+
return datetime
32+
33+
def to_arrow(self):
34+
match self:
35+
case DataType.INTEGER:
36+
return pa.int64()
37+
case DataType.FLOAT:
38+
return pa.float64()
39+
case DataType.STRING:
40+
return pa.string()
41+
case DataType.TIMESTAMP:
42+
return pa.timestamp("us")
43+
44+
45+
def convert_type_mapping_to_schema(
    type_mapping: dict[str, DataType] | None
) -> list[tuple[str, pl.DataType]]:
    """
    Build pyarrow schema input from a mapping of field names to data types.

    Parameters
    ----------
    type_mapping : dict[str, DataType] | None
        Maps each field name to its datatype. `None` (or an empty mapping)
        produces an empty result.

    Returns
    -------
    list[tuple[str, pyarrow.lib.DataType]]
        Pairs of (field name, arrow type), suitable as input to pyarrow.schema.
    """
    schema_fields = []
    for field_name, field_type in (type_mapping or {}).items():
        # DataType(...) also accepts the raw string value of a member
        arrow_type = DataType(field_type).to_arrow()
        schema_fields.append((field_name, arrow_type))
    return schema_fields
64+
65+
66+
class CacheFiles:
67+
def __init__(self, path: str | Path):
68+
self._path = Path(path)
69+
70+
@property
71+
def path(self) -> Path:
72+
return self._path
73+
74+
@property
75+
def files(self) -> list[Path]:
76+
if not self.path.exists():
77+
return []
78+
files = []
79+
for entry in os.listdir(self._path):
80+
full_path = os.path.join(self._path, entry)
81+
if os.path.isfile(full_path):
82+
files.append(Path(full_path))
83+
return files
84+
85+
@property
86+
def num_files(self) -> int:
87+
return len(self.files)
88+
89+
@property
90+
def dataset_files(self) -> list[Path]:
91+
if not self.path.exists():
92+
return []
93+
return [
94+
Path(filepath) for filepath in glob.glob(f"{self._path}/*.parquet")
95+
]
96+
97+
@property
98+
def num_dataset_files(self) -> int:
99+
return len(self.dataset_files)
100+
101+
@staticmethod
102+
def _generate_config_path(path: str | Path) -> Path:
103+
return Path(path) / ".cfg"
104+
105+
106+
class CacheReader(CacheFiles):
    """
    Read access to a parquet cache directory and its stored configuration.

    Parameters
    ----------
    path : str | Path
        Location of the cache directory.
    batch_size : int
        Batch size recorded in the cache configuration.
    rows_per_file : int
        Rows-per-file setting recorded in the cache configuration.
    compression : str
        Compression setting recorded in the cache configuration.
    """

    def __init__(
        self,
        path: str | Path,
        batch_size: int,
        rows_per_file: int,
        compression: str,
    ):
        super().__init__(path)
        self._batch_size = batch_size
        self._rows_per_file = rows_per_file
        self._compression = compression

    @classmethod
    def load(cls, path: str | Path):
        """
        Create a reader from an existing cache directory.

        The `batch_size`, `rows_per_file` and `compression` settings are read
        from the JSON `.cfg` file stored inside the directory.

        Parameters
        ----------
        path : str | Path
            Location of the cache directory.

        Raises
        ------
        FileNotFoundError
            If `path` does not exist.
        NotADirectoryError
            If `path` exists but is not a directory.
        KeyError
            If a required setting is absent (or null) in the config file.
        """
        path = Path(path)

        # validate path
        if not path.exists():
            raise FileNotFoundError(f"Directory does not exist: {path}")
        elif not path.is_dir():
            raise NotADirectoryError(
                f"Path exists but is not a directory: {path}"
            )

        cfg_path = cls._generate_config_path(path)

        def _retrieve(config: dict, key: str):
            # Only an absent or null entry counts as undefined; a falsy but
            # present value such as 0 is returned to the caller as-is.
            value = config.get(key)
            if value is None:
                raise KeyError(f"'{key}' is not defined within {cfg_path}")
            return value

        with open(cfg_path, "r", encoding="utf-8") as f:
            cfg = json.load(f)
            batch_size = _retrieve(cfg, "batch_size")
            rows_per_file = _retrieve(cfg, "rows_per_file")
            compression = _retrieve(cfg, "compression")

        return cls(
            path=path,
            batch_size=batch_size,
            rows_per_file=rows_per_file,
            compression=compression,
        )

    @property
    def dataset(self) -> ds.Dataset:
        """A pyarrow dataset opened on the cache directory in parquet format."""
        return ds.dataset(self._path, format="parquet")

    @property
    def schema(self) -> pa.Schema:
        """Schema reported by the pyarrow dataset for this cache."""
        return self.dataset.schema

    @property
    def batch_size(self) -> int:
        """Batch size this reader was configured with."""
        return self._batch_size

    @property
    def rows_per_file(self) -> int:
        """Rows-per-file setting this reader was configured with."""
        return self._rows_per_file

    @property
    def compression(self) -> str:
        """Compression setting this reader was configured with."""
        return self._compression
171+
172+
173+
class CacheWriter(CacheFiles):
    """
    Buffered writer that stores record batches as parquet files in a directory.

    Incoming batches are accumulated in memory until at least `batch_size`
    rows are available, then written to the currently open parquet file. Once
    the number of rows written to that file reaches `rows_per_file` the file
    is closed, and the next write opens a new, sequentially numbered file.

    Parameters
    ----------
    path : str | Path
        Location of the cache directory.
    schema : pa.Schema
        Schema used for every parquet file written by this writer.
    batch_size : int
        Minimum number of buffered rows before data is written to disk.
    rows_per_file : int
        Row count at which the current parquet file is closed.
    compression : str
        Compression passed to pyarrow when writing parquet files.
    """

    def __init__(
        self,
        path: str | Path,
        schema: pa.Schema,
        batch_size: int,
        rows_per_file: int,
        compression: str,
    ):
        super().__init__(path)
        self._schema = schema
        self._batch_size = batch_size
        self._rows_per_file = rows_per_file
        self._compression = compression

        # internal state
        self._writer = None  # open ParquetWriter for the current file, if any
        self._buffer = []  # record batches accepted but not yet written
        self._count = 0  # rows written to the current file

    @classmethod
    def create(
        cls,
        path: str | Path,
        schema: pa.Schema,
        batch_size: int = 1000,
        rows_per_file: int = 10000,
        compression: str = "snappy",
    ):
        """
        Create a new cache directory, store its configuration and return a writer.

        Parameters
        ----------
        path : str | Path
            Location of the new cache directory. Missing parents are created.
        schema : pa.Schema
            Schema of the data that will be written.
        batch_size : int, default=1000
            Minimum number of buffered rows before data is written to disk.
        rows_per_file : int, default=10000
            Row count at which the current parquet file is closed.
        compression : str, default="snappy"
            Compression passed to pyarrow when writing parquet files.

        Raises
        ------
        FileExistsError
            If `path` already exists.
        """
        Path(path).mkdir(parents=True, exist_ok=False)
        cfg_path = cls._generate_config_path(path)
        cfg = dict(
            batch_size=batch_size,
            rows_per_file=rows_per_file,
            compression=compression,
        )
        with open(cfg_path, "w", encoding="utf-8") as f:
            json.dump(cfg, f, indent=2)
        return cls(
            path=path,
            schema=schema,
            batch_size=batch_size,
            rows_per_file=rows_per_file,
            compression=compression,
        )

    @classmethod
    def load(cls, path: str | Path):
        """
        Create a writer for an existing cache directory.

        The schema is taken from the parquet dataset found in the directory
        and the remaining settings from the JSON `.cfg` file; every key in
        that file is forwarded to the constructor as a keyword argument.

        Parameters
        ----------
        path : str | Path
            Location of the cache directory.

        Raises
        ------
        FileNotFoundError
            If `path` does not exist.
        NotADirectoryError
            If `path` exists but is not a directory.
        """
        path = Path(path)
        # validate path
        if not path.exists():
            raise FileNotFoundError(f"Directory does not exist: {path}")
        elif not path.is_dir():
            raise NotADirectoryError(
                f"Path exists but is not a directory: {path}"
            )

        cfg_path = cls._generate_config_path(path)
        dataset = ds.dataset(path, format="parquet")
        with open(cfg_path, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        return cls(
            path=path,
            schema=dataset.schema,
            **cfg,
        )

    @classmethod
    def delete(cls, path: str | Path):
        """
        Delete the config file, the parquet files and the cache directory.

        Does nothing when `path` does not exist. The cache is loaded first,
        so errors raised by `load` propagate before anything is removed. The
        final `rmdir` only succeeds when no other entries remain.

        Parameters
        ----------
        path : str | Path
            Location of the cache directory.
        """
        path = Path(path)
        if not path.exists():
            return
        cache = cls.load(path)
        # delete config file
        cfg_path = cls._generate_config_path(path)
        if cfg_path.exists() and cfg_path.is_file():
            cfg_path.unlink()
        # delete parquet files
        for file in cache.dataset_files:
            if file.exists() and file.is_file() and file.suffix == ".parquet":
                file.unlink()
        # delete empty cache directory
        path.rmdir()

    def write_rows(
        self,
        rows: list[dict[str, Any]],
    ):
        """
        Write a list of row dictionaries to the cache.

        Parameters
        ----------
        rows : list[dict[str, Any]]
            One dictionary per row. An empty list is ignored.
        """
        if not rows:
            return
        batch = pa.RecordBatch.from_pylist(rows, schema=self.schema)
        self.write_batch(batch)

    def write_batch(
        self,
        batch: pa.RecordBatch | dict[str, list | np.ndarray | pa.Array],
    ):
        """
        Buffer a batch and write to disk once enough rows have accumulated.

        Parameters
        ----------
        batch : pa.RecordBatch | dict[str, list | np.ndarray | pa.Array]
            Data to write. A dictionary of columns is converted to a record
            batch using the cache schema.
        """
        if isinstance(batch, dict):
            # Build against the cache schema (as write_rows does) so column
            # types are not left to inference and match the parquet file.
            batch = pa.RecordBatch.from_pydict(batch, schema=self.schema)

        # total rows available: the new batch plus everything buffered
        size = batch.num_rows  # type: ignore - pyarrow typing
        size += sum(b.num_rows for b in self._buffer)

        # check size
        if size < self.batch_size and self._count < self.rows_per_file:
            self._buffer.append(batch)
            return

        if self._buffer:
            self._buffer.append(batch)
            batch = self._drain_buffer()

        # write batch
        writer = self._get_or_create_writer()
        writer.write_batch(batch)

        # check file size
        self._count += size
        if self._count >= self.rows_per_file:
            self.flush()

    def write_table(
        self,
        table: pa.Table,
    ):
        """
        Write a table to its own parquet file.

        Pending buffered data is flushed and the current file closed first.

        Parameters
        ----------
        table : pa.Table
            Table to write.
        """
        self.flush()
        # Pass the configured compression explicitly; otherwise pyarrow's own
        # default is used regardless of how this cache was configured.
        pq.write_table(
            table,
            where=self._next_filename(),
            compression=self.compression,
        )

    def flush(self):
        """Write any buffered rows, reset the row count and close the file."""
        if self._buffer:
            batch = self._drain_buffer()
            writer = self._get_or_create_writer()
            writer.write_batch(batch)
        self._count = 0
        self._close_writer()

    def _drain_buffer(self) -> pa.RecordBatch:
        """Merge all buffered batches into one record batch and empty the buffer."""
        combined_arrays = [
            pa.concat_arrays([b.column(name) for b in self._buffer])
            for name in self.schema.names
        ]
        batch = pa.RecordBatch.from_arrays(
            combined_arrays, schema=self.schema
        )
        # cleared only after the merge succeeded
        self._buffer = []
        return batch

    def _next_filename(self) -> Path:
        """Return the path of the next parquet file.

        Files are named by a zero-padded index one greater than the largest
        index among the existing parquet files, starting at 0.
        """
        files = self.dataset_files
        if not files:
            next_index = 0
        else:
            next_index = max(int(Path(f).stem) for f in files) + 1
        return self._path / f"{next_index:06d}.parquet"

    def _get_or_create_writer(self) -> pq.ParquetWriter:
        """Return the open parquet writer, opening a new file if none is open."""
        if self._writer is not None:
            return self._writer
        self._writer = pq.ParquetWriter(
            where=self._next_filename(),
            schema=self.schema,
            compression=self.compression,
        )
        return self._writer

    def _close_writer(self) -> None:
        """Close the current parquet file."""
        if self._writer is not None:
            self._writer.close()
            self._writer = None

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit - ensures data is flushed."""
        self.flush()

    @property
    def schema(self) -> pa.Schema:
        """Schema used for the parquet files written by this writer."""
        return self._schema

    @property
    def dataset(self) -> ds.Dataset:
        """A pyarrow dataset opened on the cache directory with the writer's schema."""
        return ds.dataset(
            self._path,
            format="parquet",
            schema=self.schema,
        )

    @property
    def batch_size(self) -> int:
        """Minimum number of buffered rows before data is written to disk."""
        return self._batch_size

    @property
    def rows_per_file(self) -> int:
        """Row count at which the current parquet file is closed."""
        return self._rows_per_file

    @property
    def compression(self) -> str:
        """Compression passed to pyarrow when writing parquet files."""
        return self._compression

src/valor_lite/exceptions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@ def __init__(self):
55
)
66

77

8+
class EmptyCacheError(Exception):
    """Exception carrying a fixed message stating that the cache holds no data."""

    def __init__(self):
        message = "cache contains no data"
        super().__init__(message)
11+
12+
813
class EmptyFilterError(Exception):
914
def __init__(self, message: str):
1015
super().__init__(message)

0 commit comments

Comments
 (0)