Skip to content

Commit 76051c1

Browse files
committed
add cache src
1 parent 75d3c19 commit 76051c1

3 files changed

Lines changed: 396 additions & 1 deletion

File tree

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,5 @@ site/*
2929
*.pt
3030
*.png
3131
*.jpg
32-
*.parquet
32+
*.parquet
33+
*.valor

src/valor_lite/cache.py

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
import glob
2+
import json
3+
from datetime import datetime
4+
from enum import StrEnum
5+
from pathlib import Path
6+
from typing import Any
7+
8+
import numpy as np
9+
import pyarrow as pa
10+
import pyarrow.dataset as ds
11+
import pyarrow.parquet as pq
12+
13+
14+
class DataType(StrEnum):
15+
INTEGER = "int"
16+
FLOAT = "float"
17+
STRING = "string"
18+
TIMESTAMP = "timestamp"
19+
20+
def to_py(self):
21+
match self:
22+
case DataType.INTEGER:
23+
return int
24+
case DataType.FLOAT:
25+
return float
26+
case DataType.STRING:
27+
return str
28+
case DataType.TIMESTAMP:
29+
return datetime
30+
31+
def to_arrow(self):
32+
match self:
33+
case DataType.INTEGER:
34+
return pa.int64()
35+
case DataType.FLOAT:
36+
return pa.float64()
37+
case DataType.STRING:
38+
return pa.string()
39+
case DataType.TIMESTAMP:
40+
return pa.timestamp("us")
41+
42+
43+
class CacheReader:
44+
def __init__(self, where: str | Path):
45+
self._dir = Path(where)
46+
self._cfg = self._dir / ".cfg"
47+
48+
with open(self._cfg, "r") as f:
49+
cfg = json.load(f)
50+
self._batch_size = cfg.get("batch_size")
51+
self._rows_per_file = cfg.get("rows_per_file")
52+
self._compression = cfg.get("compression")
53+
54+
@property
55+
def files(self) -> list[str]:
56+
return glob.glob(f"{self._dir}/*")
57+
58+
@property
59+
def num_files(self) -> int:
60+
return len(self.files)
61+
62+
@property
63+
def dataset_files(self) -> list[str]:
64+
return glob.glob(f"{self._dir}/*.parquet")
65+
66+
@property
67+
def num_dataset_files(self) -> int:
68+
return len(self.dataset_files)
69+
70+
@property
71+
def dataset(self):
72+
return ds.dataset(
73+
self._dir,
74+
format="parquet",
75+
)
76+
77+
@property
78+
def schema(self):
79+
return self.dataset.schema
80+
81+
@property
82+
def batch_size(self) -> int:
83+
return self._batch_size
84+
85+
@property
86+
def rows_per_file(self) -> int:
87+
return self._rows_per_file
88+
89+
@property
90+
def compression(self) -> str:
91+
return self._compression
92+
93+
94+
class CacheWriter(CacheReader):
    """Buffered writer that stores record batches as numbered parquet files.

    Incoming batches are buffered until at least ``batch_size`` rows are
    available, then written to the currently open parquet file. Once
    ``rows_per_file`` rows have been written to a file it is closed, and
    the next write opens a new file named with the next free
    zero-padded index.

    The writer can be used as a context manager; leaving the ``with``
    block flushes buffered rows and closes the open file.
    """

    def __init__(
        self,
        where: str | Path,
        schema: pa.Schema,
        batch_size: int = 1000,
        rows_per_file: int = 10000,
        compression: str = "snappy",
        delete_if_exists: bool = True,
    ):
        """Create (or reuse) the cache directory and record its settings.

        Parameters
        ----------
        where : str | Path
            Directory in which parquet files and the ``.cfg`` file live.
        schema : pa.Schema
            Schema used for every file written by this writer.
        batch_size : int, default=1000
            Minimum number of rows accumulated before writing to disk.
        rows_per_file : int, default=10000
            Row count at which the current file is closed.
        compression : str, default="snappy"
            Parquet compression codec.
        delete_if_exists : bool, default=True
            Remove existing ``*.parquet`` files from the directory first.
        """
        # CacheReader.__init__ is not called here: it would read the
        # `.cfg` file, which this constructor (re)writes below.
        self._dir = Path(where)
        self._cfg = self._dir / ".cfg"

        self._schema = schema
        self._batch_size = batch_size
        self._rows_per_file = rows_per_file
        self._compression = compression

        if delete_if_exists:
            self.delete_files()
        self._dir.mkdir(parents=True, exist_ok=True)

        # Internal state
        self._writer = None  # open pq.ParquetWriter, if any
        self._buffer = []  # record batches not yet written
        self._count = 0  # rows written to the currently open file

        with open(self._cfg, "w", encoding="utf-8") as f:
            info = dict(
                batch_size=batch_size,
                rows_per_file=rows_per_file,
                compression=compression,
            )
            json.dump(info, f, indent=2)

    @property
    def schema(self):
        """Schema supplied at construction time."""
        return self._schema

    @property
    def dataset(self):
        """A pyarrow dataset over the cache directory using the writer's schema."""
        return ds.dataset(
            self._dir,
            format="parquet",
            schema=self.schema,
        )

    def delete_files(self):
        """Delete every ``*.parquet`` file in the cache directory."""
        for file in self.dataset_files:
            Path(file).unlink()

    @property
    def next_index(self):
        """Index for the next parquet file: one past the largest existing stem.

        Raises
        ------
        ValueError
            If an existing parquet file has a non-integer stem.
        """
        files = self.dataset_files
        if not files:
            return 0
        return max(int(Path(f).stem) for f in files) + 1

    def write_rows(
        self,
        rows: list[dict[str, Any]],
    ):
        """Write a list of row dictionaries; an empty list is a no-op."""
        if not rows:
            return
        batch = pa.RecordBatch.from_pylist(rows, schema=self.schema)
        self.write_batch(batch)

    def write_batch(
        self,
        batch: pa.RecordBatch | dict[str, list | np.ndarray | pa.Array],
    ):
        """Buffer a batch and write to disk once ``batch_size`` rows are ready.

        Parameters
        ----------
        batch : pa.RecordBatch | dict
            A record batch, or a mapping of column name to column values.
        """
        if isinstance(batch, dict):
            # Build with the writer's schema (as write_rows does) so types and
            # column order match the parquet writer instead of being inferred.
            batch = pa.RecordBatch.from_pydict(batch, schema=self.schema)

        size = batch.num_rows  # type: ignore - pyarrow typing
        size += sum(b.num_rows for b in self._buffer)

        # Not enough rows yet: keep buffering.
        if size < self.batch_size and self._count < self.rows_per_file:
            self._buffer.append(batch)
            return

        if self._buffer:
            self._buffer.append(batch)
            batch = self._combine_buffer()

        # write batch
        writer = self._get_or_create_writer()
        writer.write_batch(batch)

        # Close the file once it holds enough rows.
        self._count += size
        if self._count >= self.rows_per_file:
            self.flush()

    def write_table(
        self,
        table: pa.Table,
    ):
        """Flush pending data, then write ``table`` as its own parquet file."""
        self.flush()
        # Use the configured codec so this file matches the batch-written ones.
        pq.write_table(
            table,
            where=self._next_filename(),
            compression=self.compression,
        )

    def flush(self):
        """Write any buffered rows and close the current parquet file."""
        if self._buffer:
            batch = self._combine_buffer()
            writer = self._get_or_create_writer()
            writer.write_batch(batch)
        self._count = 0
        self._close_writer()

    def _combine_buffer(self) -> pa.RecordBatch:
        """Merge all buffered batches into one record batch and empty the buffer."""
        combined_arrays = [
            pa.concat_arrays([b.column(name) for b in self._buffer])
            for name in self.schema.names
        ]
        batch = pa.RecordBatch.from_arrays(combined_arrays, schema=self.schema)
        self._buffer = []
        return batch

    def _next_filename(self) -> Path:
        """Path of the next parquet file, named by a six-digit zero-padded index."""
        return self._dir / f"{self.next_index:06d}.parquet"

    def _get_or_create_writer(self) -> pq.ParquetWriter:
        """Return the open parquet writer, opening a new file if none is open."""
        if self._writer is not None:
            return self._writer
        self._writer = pq.ParquetWriter(
            where=self._next_filename(),
            schema=self.schema,
            compression=self.compression,
        )
        return self._writer

    def _close_writer(self) -> None:
        """Close the current parquet file."""
        if self._writer is not None:
            self._writer.close()
            self._writer = None

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit - ensures data is flushed."""
        self.flush()

0 commit comments

Comments
 (0)