Skip to content

Commit 0778f33

Browse files
committed
more tests
1 parent cb00999 commit 0778f33

4 files changed

Lines changed: 84 additions & 63 deletions

File tree

src/valor_lite/common/ephemeral.py

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,36 +6,32 @@
66

77

88
class MemoryCache:
9-
def __init__(self, table: pa.Table):
10-
self._table = table
11-
12-
def count_tables(self) -> int:
13-
"""Count the number of tables in the cache."""
14-
return 1
15-
16-
def count_rows(self) -> int:
17-
"""Count the number of rows in the cache."""
18-
return self._table.num_rows
19-
20-
21-
class MemoryCacheReader(MemoryCache):
229
def __init__(
2310
self,
2411
table: pa.Table,
2512
batch_size: int,
2613
):
27-
super().__init__(table)
28-
self._schema = self._table.schema
14+
self._table = table
2915
self._batch_size = batch_size
3016

3117
@property
3218
def schema(self) -> pa.Schema:
33-
return self._schema
19+
return self._table.schema
3420

3521
@property
3622
def batch_size(self) -> int:
3723
return self._batch_size
3824

25+
def count_tables(self) -> int:
26+
"""Count the number of tables in the cache."""
27+
return 1
28+
29+
def count_rows(self) -> int:
30+
"""Count the number of rows in the cache."""
31+
return self._table.num_rows
32+
33+
34+
class MemoryCacheReader(MemoryCache):
3935
def iterate_tables(
4036
self,
4137
columns: list[str] | None = None,
@@ -56,9 +52,10 @@ def __init__(
5652
table: pa.Table,
5753
batch_size: int,
5854
):
59-
super().__init__(table)
60-
self._schema = table.schema
61-
self._batch_size = batch_size
55+
super().__init__(
56+
table=table,
57+
batch_size=batch_size,
58+
)
6259

6360
# internal state
6461
self._buffer = []
@@ -98,7 +95,7 @@ def write_rows(
9895
"""
9996
if not rows:
10097
return
101-
batch = pa.RecordBatch.from_pylist(rows, schema=self._schema)
98+
batch = pa.RecordBatch.from_pylist(rows, schema=self.schema)
10299
self.write_batch(batch)
103100

104101
def write_columns(
@@ -143,10 +140,10 @@ def write_batch(
143140
self._buffer.append(batch)
144141
combined_arrays = [
145142
pa.concat_arrays([b.column(name) for b in self._buffer])
146-
for name in self._schema.names
143+
for name in self.schema.names
147144
]
148145
batch = pa.RecordBatch.from_arrays(
149-
combined_arrays, schema=self._schema
146+
combined_arrays, schema=self.schema
150147
)
151148
self._buffer = []
152149

@@ -172,10 +169,10 @@ def flush(self):
172169
if self._buffer:
173170
combined_arrays = [
174171
pa.concat_arrays([b.column(name) for b in self._buffer])
175-
for name in self._schema.names
172+
for name in self.schema.names
176173
]
177174
batch = pa.RecordBatch.from_arrays(
178-
combined_arrays, schema=self._schema
175+
combined_arrays, schema=self.schema
179176
)
180177
self._table = pa.concat_tables(
181178
[self._table, pa.Table.from_batches([batch])]

src/valor_lite/common/persistent.py

Lines changed: 35 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,40 @@
1313

1414

1515
class FileCache:
16-
def __init__(self, path: str | Path):
16+
def __init__(
17+
self,
18+
path: str | Path,
19+
schema: pa.Schema,
20+
batch_size: int,
21+
rows_per_file: int,
22+
compression: str,
23+
):
1724
self._path = Path(path)
25+
self._schema = schema
26+
self._batch_size = batch_size
27+
self._rows_per_file = rows_per_file
28+
self._compression = compression
1829

1930
@property
2031
def path(self) -> Path:
2132
return self._path
2233

34+
@property
35+
def schema(self) -> pa.Schema:
36+
return self._schema
37+
38+
@property
39+
def batch_size(self) -> int:
40+
return self._batch_size
41+
42+
@property
43+
def rows_per_file(self) -> int:
44+
return self._rows_per_file
45+
46+
@property
47+
def compression(self) -> str:
48+
return self._compression
49+
2350
@staticmethod
2451
def _generate_config_path(path: str | Path) -> Path:
2552
"""Generate cache configuration path."""
@@ -84,36 +111,6 @@ def get_dataset_files(self) -> list[Path]:
84111

85112

86113
class FileCacheReader(FileCache):
87-
def __init__(
88-
self,
89-
path: str | Path,
90-
schema: pa.Schema,
91-
batch_size: int,
92-
rows_per_file: int,
93-
compression: str,
94-
):
95-
super().__init__(path)
96-
self._schema = schema
97-
self._batch_size = batch_size
98-
self._rows_per_file = rows_per_file
99-
self._compression = compression
100-
101-
@property
102-
def schema(self) -> pa.Schema:
103-
return self._schema
104-
105-
@property
106-
def batch_size(self) -> int:
107-
return self._batch_size
108-
109-
@property
110-
def rows_per_file(self) -> int:
111-
return self._rows_per_file
112-
113-
@property
114-
def compression(self) -> str:
115-
return self._compression
116-
117114
@classmethod
118115
def load(cls, path: str | Path | FileCache):
119116
"""
@@ -192,11 +189,13 @@ def __init__(
192189
rows_per_file: int,
193190
compression: str,
194191
):
195-
super().__init__(path)
196-
self._schema = schema
197-
self._batch_size = batch_size
198-
self._rows_per_file = rows_per_file
199-
self._compression = compression
192+
super().__init__(
193+
path=path,
194+
schema=schema,
195+
batch_size=batch_size,
196+
rows_per_file=rows_per_file,
197+
compression=compression,
198+
)
200199

201200
# internal state
202201
self._writer = None

tests/common/test_ephemeral_cache.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,18 @@ def test_cache_reader():
2929
)
3030
writer.write_table(tbl)
3131
assert writer.count_rows() == 101
32+
assert writer.count_tables() == 1
3233

3334
reader = writer.to_reader()
3435
assert reader.count_rows() == 101
36+
assert reader.count_tables() == 1
3537
for tbl in reader.iterate_tables():
3638
assert tbl["some_int"].to_pylist() == [i for i in range(101)]
3739
assert tbl["some_float"].to_pylist() == [i for i in range(101)]
3840
assert tbl["some_str"].to_pylist() == [f"str{i}" for i in range(101)]
3941
assert reader.count_rows() == 101
40-
assert reader._schema == schema
42+
assert reader.count_tables() == 1
43+
assert reader.schema == schema
4144

4245

4346
def test_cache_write_columns():
@@ -63,6 +66,7 @@ def test_cache_write_columns():
6366

6467
reader = writer.to_reader()
6568
assert reader.count_rows() == 1000
69+
assert reader.count_tables() == 1
6670
for tbl in reader.iterate_tables():
6771
assert tbl["some_int"].to_pylist() == [i for i in range(1000)]
6872
assert tbl["some_float"].to_pylist() == [i for i in range(1000)]
@@ -98,6 +102,7 @@ def test_cache_write_rows():
98102

99103
reader = writer.to_reader()
100104
assert reader.count_rows() == 1000
105+
assert reader.count_tables() == 1
101106
for tbl in reader.iterate_tables():
102107
assert tbl["some_int"].to_pylist() == [i for i in range(1000)]
103108
assert tbl["some_float"].to_pylist() == [i for i in range(1000)]
@@ -127,12 +132,15 @@ def test_cache_write_table():
127132
]
128133
)
129134
assert writer.count_rows() == 0
135+
assert writer.count_tables() == 1
130136

131137
writer.write_table(tbl)
132138
assert writer.count_rows() == 101
139+
assert writer.count_tables() == 1
133140

134141
writer.write_table(tbl)
135142
assert writer.count_rows() == 202
143+
assert writer.count_tables() == 1
136144

137145
reader = writer.to_reader()
138146

tests/common/test_persistent_cache.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,13 @@
1313

1414

1515
def test_cache_files_empty(tmp_path: Path):
16-
cache = FileCache(tmp_path)
16+
cache = FileCache(
17+
tmp_path,
18+
schema=None, # type: ignore - testing
19+
batch_size=None, # type: ignore - testing
20+
rows_per_file=None, # type: ignore - testing
21+
compression=None, # type: ignore - testing
22+
)
1723
assert cache._path == tmp_path
1824
assert cache.get_files() == []
1925
assert cache.get_dataset_files() == []
@@ -22,7 +28,13 @@ def test_cache_files_empty(tmp_path: Path):
2228

2329
def test_cache_files_does_not_exist(tmp_path: Path):
2430
path = tmp_path / "does_not_exist"
25-
cache = FileCache(path)
31+
cache = FileCache(
32+
path,
33+
schema=None, # type: ignore - testing
34+
batch_size=None, # type: ignore - testing
35+
rows_per_file=None, # type: ignore - testing
36+
compression=None, # type: ignore - testing
37+
)
2638
assert cache._path == path
2739
assert cache.get_files() == []
2840
assert cache.get_dataset_files() == []
@@ -93,6 +105,7 @@ def test_cache_reader(tmp_path: Path):
93105
]
94106

95107
assert reader.count_rows() == 202
108+
assert reader.count_tables() == 2
96109
assert reader._schema == schema
97110

98111

@@ -172,6 +185,7 @@ def test_cache_write_columns(tmp_path: Path):
172185

173186
reader = FileCacheReader.load(writer.path)
174187
assert reader.count_rows() == 1000
188+
assert reader.count_tables() == 10
175189
for idx, tbl in enumerate(reader.iterate_tables()):
176190
assert tbl["some_int"].to_pylist() == [
177191
i
@@ -221,6 +235,7 @@ def test_cache_write_rows(tmp_path: Path):
221235

222236
reader = FileCacheReader.load(writer.path)
223237
assert reader.count_rows() == 1000
238+
assert reader.count_tables() == 10
224239
for idx, tbl in enumerate(reader.iterate_tables()):
225240
assert tbl["some_int"].to_pylist() == [
226241
i
@@ -269,6 +284,7 @@ def test_cache_write_table(tmp_path: Path):
269284

270285
reader = FileCacheReader.load(writer.path)
271286
assert reader.count_rows() == 202
287+
assert reader.count_tables() == 2
272288
for tbl in reader.iterate_tables():
273289
assert tbl["some_int"].to_pylist() == [i for i in range(101)]
274290
assert tbl["some_float"].to_pylist() == [i for i in range(101)]
@@ -311,6 +327,7 @@ def test_cache_delete(tmp_path: Path):
311327

312328
reader = FileCacheReader.load(path=writer.path)
313329
assert reader.count_rows() == 202
330+
assert reader.count_tables() == 2
314331
for tbl in reader.iterate_tables():
315332
assert tbl["some_int"].to_pylist() == [i for i in range(101)]
316333
assert tbl["some_float"].to_pylist() == [i for i in range(101)]

0 commit comments

Comments
 (0)