Skip to content

Commit 6e1264b

Browse files
committed
update
1 parent 9ddeebc commit 6e1264b

5 files changed

Lines changed: 121 additions & 66 deletions

File tree

src/valor_lite/common/ephemeral.py

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,36 +6,32 @@
66

77

88
class MemoryCache:
9-
def __init__(self, table: pa.Table):
10-
self._table = table
11-
12-
def count_tables(self) -> int:
13-
"""Count the number of tables in the cache."""
14-
return 1
15-
16-
def count_rows(self) -> int:
17-
"""Count the number of rows in the cache."""
18-
return self._table.num_rows
19-
20-
21-
class MemoryCacheReader(MemoryCache):
229
def __init__(
2310
self,
2411
table: pa.Table,
2512
batch_size: int,
2613
):
27-
super().__init__(table)
28-
self._schema = self._table.schema
14+
self._table = table
2915
self._batch_size = batch_size
3016

3117
@property
3218
def schema(self) -> pa.Schema:
33-
return self._schema
19+
return self._table.schema
3420

3521
@property
3622
def batch_size(self) -> int:
3723
return self._batch_size
3824

25+
def count_tables(self) -> int:
26+
"""Count the number of tables in the cache."""
27+
return 1
28+
29+
def count_rows(self) -> int:
30+
"""Count the number of rows in the cache."""
31+
return self._table.num_rows
32+
33+
34+
class MemoryCacheReader(MemoryCache):
3935
def iterate_tables(
4036
self,
4137
columns: list[str] | None = None,
@@ -56,9 +52,10 @@ def __init__(
5652
table: pa.Table,
5753
batch_size: int,
5854
):
59-
super().__init__(table)
60-
self._schema = table.schema
61-
self._batch_size = batch_size
55+
super().__init__(
56+
table=table,
57+
batch_size=batch_size,
58+
)
6259

6360
# internal state
6461
self._buffer = []
@@ -98,7 +95,7 @@ def write_rows(
9895
"""
9996
if not rows:
10097
return
101-
batch = pa.RecordBatch.from_pylist(rows, schema=self._schema)
98+
batch = pa.RecordBatch.from_pylist(rows, schema=self.schema)
10299
self.write_batch(batch)
103100

104101
def write_columns(
@@ -143,10 +140,10 @@ def write_batch(
143140
self._buffer.append(batch)
144141
combined_arrays = [
145142
pa.concat_arrays([b.column(name) for b in self._buffer])
146-
for name in self._schema.names
143+
for name in self.schema.names
147144
]
148145
batch = pa.RecordBatch.from_arrays(
149-
combined_arrays, schema=self._schema
146+
combined_arrays, schema=self.schema
150147
)
151148
self._buffer = []
152149

@@ -172,10 +169,10 @@ def flush(self):
172169
if self._buffer:
173170
combined_arrays = [
174171
pa.concat_arrays([b.column(name) for b in self._buffer])
175-
for name in self._schema.names
172+
for name in self.schema.names
176173
]
177174
batch = pa.RecordBatch.from_arrays(
178-
combined_arrays, schema=self._schema
175+
combined_arrays, schema=self.schema
179176
)
180177
self._table = pa.concat_tables(
181178
[self._table, pa.Table.from_batches([batch])]

src/valor_lite/common/persistent.py

Lines changed: 36 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,40 @@
1313

1414

1515
class FileCache:
16-
def __init__(self, path: str | Path):
16+
def __init__(
17+
self,
18+
path: str | Path,
19+
schema: pa.Schema,
20+
batch_size: int,
21+
rows_per_file: int,
22+
compression: str,
23+
):
1724
self._path = Path(path)
25+
self._schema = schema
26+
self._batch_size = batch_size
27+
self._rows_per_file = rows_per_file
28+
self._compression = compression
1829

1930
@property
2031
def path(self) -> Path:
2132
return self._path
2233

34+
@property
35+
def schema(self) -> pa.Schema:
36+
return self._schema
37+
38+
@property
39+
def batch_size(self) -> int:
40+
return self._batch_size
41+
42+
@property
43+
def rows_per_file(self) -> int:
44+
return self._rows_per_file
45+
46+
@property
47+
def compression(self) -> str:
48+
return self._compression
49+
2350
@staticmethod
2451
def _generate_config_path(path: str | Path) -> Path:
2552
"""Generate cache configuration path."""
@@ -84,38 +111,8 @@ def get_dataset_files(self) -> list[Path]:
84111

85112

86113
class FileCacheReader(FileCache):
87-
def __init__(
88-
self,
89-
path: str | Path,
90-
schema: pa.Schema,
91-
batch_size: int,
92-
rows_per_file: int,
93-
compression: str,
94-
):
95-
super().__init__(path)
96-
self._schema = schema
97-
self._batch_size = batch_size
98-
self._rows_per_file = rows_per_file
99-
self._compression = compression
100-
101-
@property
102-
def schema(self) -> pa.Schema:
103-
return self._schema
104-
105-
@property
106-
def batch_size(self) -> int:
107-
return self._batch_size
108-
109-
@property
110-
def rows_per_file(self) -> int:
111-
return self._rows_per_file
112-
113-
@property
114-
def compression(self) -> str:
115-
return self._compression
116-
117114
@classmethod
118-
def load(cls, path: str | Path | FileCache):
115+
def load(cls, path: str | Path):
119116
"""
120117
Load cache from disk.
121118
@@ -124,8 +121,6 @@ def load(cls, path: str | Path | FileCache):
124121
path : str | Path
125122
Where the cache is stored.
126123
"""
127-
if isinstance(path, FileCache):
128-
path = path.path
129124
path = Path(path)
130125
if not path.exists():
131126
raise FileNotFoundError(f"Directory does not exist: {path}")
@@ -192,11 +187,13 @@ def __init__(
192187
rows_per_file: int,
193188
compression: str,
194189
):
195-
super().__init__(path)
196-
self._schema = schema
197-
self._batch_size = batch_size
198-
self._rows_per_file = rows_per_file
199-
self._compression = compression
190+
super().__init__(
191+
path=path,
192+
schema=schema,
193+
batch_size=batch_size,
194+
rows_per_file=rows_per_file,
195+
compression=compression,
196+
)
200197

201198
# internal state
202199
self._writer = None

tests/common/test_ephemeral_cache.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,18 @@ def test_cache_reader():
2929
)
3030
writer.write_table(tbl)
3131
assert writer.count_rows() == 101
32+
assert writer.count_tables() == 1
3233

3334
reader = writer.to_reader()
3435
assert reader.count_rows() == 101
36+
assert reader.count_tables() == 1
3537
for tbl in reader.iterate_tables():
3638
assert tbl["some_int"].to_pylist() == [i for i in range(101)]
3739
assert tbl["some_float"].to_pylist() == [i for i in range(101)]
3840
assert tbl["some_str"].to_pylist() == [f"str{i}" for i in range(101)]
3941
assert reader.count_rows() == 101
40-
assert reader._schema == schema
42+
assert reader.count_tables() == 1
43+
assert reader.schema == schema
4144

4245

4346
def test_cache_write_columns():
@@ -60,9 +63,11 @@ def test_cache_write_columns():
6063
"some_str": pa.array([f"str{i}"]),
6164
}
6265
)
66+
writer.write_columns({})
6367

6468
reader = writer.to_reader()
6569
assert reader.count_rows() == 1000
70+
assert reader.count_tables() == 1
6671
for tbl in reader.iterate_tables():
6772
assert tbl["some_int"].to_pylist() == [i for i in range(1000)]
6873
assert tbl["some_float"].to_pylist() == [i for i in range(1000)]
@@ -98,6 +103,7 @@ def test_cache_write_rows():
98103

99104
reader = writer.to_reader()
100105
assert reader.count_rows() == 1000
106+
assert reader.count_tables() == 1
101107
for tbl in reader.iterate_tables():
102108
assert tbl["some_int"].to_pylist() == [i for i in range(1000)]
103109
assert tbl["some_float"].to_pylist() == [i for i in range(1000)]
@@ -127,12 +133,15 @@ def test_cache_write_table():
127133
]
128134
)
129135
assert writer.count_rows() == 0
136+
assert writer.count_tables() == 1
130137

131138
writer.write_table(tbl)
132139
assert writer.count_rows() == 101
140+
assert writer.count_tables() == 1
133141

134142
writer.write_table(tbl)
135143
assert writer.count_rows() == 202
144+
assert writer.count_tables() == 1
136145

137146
reader = writer.to_reader()
138147

tests/common/test_exceptions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ def test_empty_evaluator_exc():
99
assert "no data" in str(e)
1010

1111

12+
def test_empty_cache_error_exc():
13+
with pytest.raises(exc.EmptyCacheError) as e:
14+
raise exc.EmptyCacheError()
15+
assert "no data" in str(e)
16+
17+
1218
def test_cache_error_exc():
1319
with pytest.raises(exc.InternalCacheError) as e:
1420
raise exc.InternalCacheError("custom message")

tests/common/test_persistent_cache.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,13 @@
1313

1414

1515
def test_cache_files_empty(tmp_path: Path):
16-
cache = FileCache(tmp_path)
16+
cache = FileCache(
17+
tmp_path,
18+
schema=None, # type: ignore - testing
19+
batch_size=None, # type: ignore - testing
20+
rows_per_file=None, # type: ignore - testing
21+
compression=None, # type: ignore - testing
22+
)
1723
assert cache._path == tmp_path
1824
assert cache.get_files() == []
1925
assert cache.get_dataset_files() == []
@@ -22,7 +28,13 @@ def test_cache_files_empty(tmp_path: Path):
2228

2329
def test_cache_files_does_not_exist(tmp_path: Path):
2430
path = tmp_path / "does_not_exist"
25-
cache = FileCache(path)
31+
cache = FileCache(
32+
path,
33+
schema=None, # type: ignore - testing
34+
batch_size=None, # type: ignore - testing
35+
rows_per_file=None, # type: ignore - testing
36+
compression=None, # type: ignore - testing
37+
)
2638
assert cache._path == path
2739
assert cache.get_files() == []
2840
assert cache.get_dataset_files() == []
@@ -93,6 +105,7 @@ def test_cache_reader(tmp_path: Path):
93105
]
94106

95107
assert reader.count_rows() == 202
108+
assert reader.count_tables() == 2
96109
assert reader._schema == schema
97110

98111

@@ -167,11 +180,13 @@ def test_cache_write_columns(tmp_path: Path):
167180
"some_str": pa.array([f"str{i}"]),
168181
}
169182
)
183+
writer.write_columns({})
170184
writer.flush()
171185
assert len(writer.get_files()) == 11
172186

173187
reader = FileCacheReader.load(writer.path)
174188
assert reader.count_rows() == 1000
189+
assert reader.count_tables() == 10
175190
for idx, tbl in enumerate(reader.iterate_tables()):
176191
assert tbl["some_int"].to_pylist() == [
177192
i
@@ -221,6 +236,7 @@ def test_cache_write_rows(tmp_path: Path):
221236

222237
reader = FileCacheReader.load(writer.path)
223238
assert reader.count_rows() == 1000
239+
assert reader.count_tables() == 10
224240
for idx, tbl in enumerate(reader.iterate_tables()):
225241
assert tbl["some_int"].to_pylist() == [
226242
i
@@ -269,6 +285,7 @@ def test_cache_write_table(tmp_path: Path):
269285

270286
reader = FileCacheReader.load(writer.path)
271287
assert reader.count_rows() == 202
288+
assert reader.count_tables() == 2
272289
for tbl in reader.iterate_tables():
273290
assert tbl["some_int"].to_pylist() == [i for i in range(101)]
274291
assert tbl["some_float"].to_pylist() == [i for i in range(101)]
@@ -281,6 +298,33 @@ def test_cache_delete(tmp_path: Path):
281298
batch_size = 10
282299
rows_per_file = 100
283300
path = tmp_path / "cache"
301+
302+
FileCacheWriter.create(
303+
path=path,
304+
schema=pa.schema(
305+
[
306+
("some_int", pa.int64()),
307+
("some_float", pa.float64()),
308+
("some_str", pa.string()),
309+
]
310+
),
311+
batch_size=batch_size,
312+
rows_per_file=rows_per_file,
313+
)
314+
with pytest.raises(FileExistsError):
315+
FileCacheWriter.create(
316+
path=path,
317+
schema=pa.schema(
318+
[
319+
("some_int", pa.int64()),
320+
("some_float", pa.float64()),
321+
("some_str", pa.string()),
322+
]
323+
),
324+
batch_size=batch_size,
325+
rows_per_file=rows_per_file,
326+
)
327+
284328
with FileCacheWriter.create(
285329
path=path,
286330
schema=pa.schema(
@@ -292,6 +336,7 @@ def test_cache_delete(tmp_path: Path):
292336
),
293337
batch_size=batch_size,
294338
rows_per_file=rows_per_file,
339+
delete_if_exists=True,
295340
) as writer:
296341
tbl = pa.Table.from_pylist(
297342
[
@@ -311,6 +356,7 @@ def test_cache_delete(tmp_path: Path):
311356

312357
reader = FileCacheReader.load(path=writer.path)
313358
assert reader.count_rows() == 202
359+
assert reader.count_tables() == 2
314360
for tbl in reader.iterate_tables():
315361
assert tbl["some_int"].to_pylist() == [i for i in range(101)]
316362
assert tbl["some_float"].to_pylist() == [i for i in range(101)]

0 commit comments

Comments
 (0)