Skip to content

Commit da22674

Browse files
committed
cvt to new cache
1 parent ac7a77e commit da22674

6 files changed

Lines changed: 979 additions & 203 deletions

File tree

src/valor_lite/common/datatype.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from datetime import datetime
2+
from enum import StrEnum
3+
4+
import pyarrow as pa
5+
import pyarrow.lib as pl
6+
7+
8+
class DataType(StrEnum):
9+
INTEGER = "int"
10+
FLOAT = "float"
11+
STRING = "string"
12+
TIMESTAMP = "timestamp"
13+
14+
def to_py(self):
15+
"""Get python type."""
16+
match self:
17+
case DataType.INTEGER:
18+
return int
19+
case DataType.FLOAT:
20+
return float
21+
case DataType.STRING:
22+
return str
23+
case DataType.TIMESTAMP:
24+
return datetime
25+
26+
def to_arrow(self):
27+
"""Get arrow type."""
28+
match self:
29+
case DataType.INTEGER:
30+
return pa.int64()
31+
case DataType.FLOAT:
32+
return pa.float64()
33+
case DataType.STRING:
34+
return pa.string()
35+
case DataType.TIMESTAMP:
36+
return pa.timestamp("us")
37+
38+
39+
def convert_type_mapping_to_schema(
40+
type_mapping: dict[str, DataType] | None
41+
) -> list[tuple[str, pl.DataType]]:
42+
"""
43+
Convert type mapping to a pyarrow schema input.
44+
45+
Parameters
46+
----------
47+
type_mapping : dict[str, DataType] | None
48+
A map from string key to datatype. Treats input of `None` as empty mapping.
49+
50+
Returns
51+
-------
52+
list[tuple[str, pyarrow.lib.DataType]]
53+
A list of field name, field type pairs that can be used as input to pyarrow.schema.
54+
"""
55+
if not type_mapping:
56+
return []
57+
return [(k, DataType(v).to_arrow()) for k, v in type_mapping.items()]

src/valor_lite/common/ephemeral.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
from typing import Any
2+
3+
import numpy as np
4+
import pyarrow as pa
5+
6+
7+
class MemoryCache:
8+
def __init__(self, table: pa.Table):
9+
self._table = table
10+
11+
def count_rows(self) -> int:
12+
"""Count the number of rows in the cache."""
13+
return self._table.num_rows
14+
15+
16+
class MemoryCacheWriter(MemoryCache):
17+
def __init__(
18+
self,
19+
table: pa.Table,
20+
batch_size: int,
21+
):
22+
self._table = table
23+
self._schema = table.schema
24+
self._batch_size = batch_size
25+
26+
# internal state
27+
self._buffer = []
28+
29+
@classmethod
30+
def create(
31+
cls,
32+
schema: pa.Schema,
33+
batch_size: int,
34+
):
35+
"""
36+
Create a cache.
37+
38+
Parameters
39+
----------
40+
schema : pa.Schema
41+
Cache schema.
42+
batch_size : int, default=1_000
43+
Target batch size when writing chunks.
44+
"""
45+
return cls(
46+
table=schema.empty_table(),
47+
batch_size=batch_size,
48+
)
49+
50+
def delete(self):
51+
"""
52+
Delete any existing cache data.
53+
"""
54+
self._buffer = []
55+
self._table = self._table.schema.empty_table()
56+
57+
def write_rows(
58+
self,
59+
rows: list[dict[str, Any]],
60+
):
61+
"""
62+
Write rows to cache.
63+
64+
Parameters
65+
----------
66+
rows : list[dict[str, Any]]
67+
A list of rows represented by dictionaries mapping fields to values.
68+
"""
69+
if not rows:
70+
return
71+
batch = pa.RecordBatch.from_pylist(rows, schema=self._schema)
72+
self.write_batch(batch)
73+
74+
def write_batch(
75+
self,
76+
batch: pa.RecordBatch | dict[str, list | np.ndarray | pa.Array],
77+
):
78+
"""
79+
Write a batch to cache.
80+
81+
Parameters
82+
----------
83+
batch : pa.RecordBatch | dict[str, list | np.ndarray | pa.Array]
84+
A batch of columnar data.
85+
"""
86+
if isinstance(batch, dict):
87+
batch = pa.RecordBatch.from_pydict(batch)
88+
89+
size = batch.num_rows # type: ignore - pyarrow typing
90+
if self._buffer:
91+
size += sum([b.num_rows for b in self._buffer])
92+
93+
# check size
94+
if size < self._batch_size:
95+
self._buffer.append(batch)
96+
return
97+
98+
if self._buffer:
99+
self._buffer.append(batch)
100+
combined_arrays = [
101+
pa.concat_arrays([b.column(name) for b in self._buffer])
102+
for name in self._schema.names
103+
]
104+
batch = pa.RecordBatch.from_arrays(
105+
combined_arrays, schema=self._schema
106+
)
107+
self._buffer = []
108+
109+
# write batch
110+
self.write_table(pa.Table.from_batches([batch]))
111+
112+
def write_table(
113+
self,
114+
table: pa.Table,
115+
):
116+
"""
117+
Write a table directly to cache.
118+
119+
Parameters
120+
----------
121+
table : pa.Table
122+
A populated table.
123+
"""
124+
self._table = pa.concat_tables([self._table, table])
125+
126+
def flush(self):
127+
"""Flush the cache buffer."""
128+
if self._buffer:
129+
combined_arrays = [
130+
pa.concat_arrays([b.column(name) for b in self._buffer])
131+
for name in self._schema.names
132+
]
133+
batch = pa.RecordBatch.from_arrays(
134+
combined_arrays, schema=self._schema
135+
)
136+
self._table = pa.concat_tables(
137+
[self._table, pa.Table.from_batches([batch])]
138+
)
139+
self._buffer = []
140+
141+
def __enter__(self):
142+
"""Context manager entry."""
143+
return self
144+
145+
def __exit__(self, exc_type, exc_val, exc_tb):
146+
"""Context manager exit - ensures data is flushed."""
147+
self.flush()
148+
149+
150+
class MemoryCacheReader:
151+
def __init__(
152+
self,
153+
cache: MemoryCacheWriter,
154+
):
155+
self._cache = cache
156+
self._schema = self._cache._schema
157+
158+
@classmethod
159+
def load(cls, cache: MemoryCacheWriter):
160+
"""
161+
Load cache from table.
162+
163+
Parameters
164+
----------
165+
cache : MemoryCacheWriter
166+
A cache writer containing the ephemeral cache.
167+
"""
168+
return cls(cache=cache)
169+
170+
def iterate_tables(self):
171+
"""Iterate over tables within the cache."""
172+
yield self._cache._table
173+
174+
def count_rows(self) -> int:
175+
"""Count the number of rows in the cache."""
176+
return self._cache.count_rows()

0 commit comments

Comments
 (0)