Skip to content

Commit 9a421fa

Browse files
committed
added cache tests
1 parent 53108f2 commit 9a421fa

2 files changed

Lines changed: 87 additions & 6 deletions

File tree

src/valor_lite/cache.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import glob
22
import json
3+
import os
34
from datetime import datetime
45
from enum import StrEnum
56
from pathlib import Path
@@ -75,7 +76,12 @@ def __init__(self, where: str | Path):
7576

7677
@property
7778
def files(self) -> list[str]:
78-
return glob.glob(f"{self._dir}/*")
79+
files = []
80+
for entry in os.listdir(self._dir):
81+
full_path = os.path.join(self._dir, entry)
82+
if os.path.isfile(full_path):
83+
files.append(full_path)
84+
return files
7985

8086
@property
8187
def num_files(self) -> int:

tests/common/test_cache.py

Lines changed: 80 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,12 @@
55
import numpy as np
66
import pyarrow as pa
77

8-
from valor_lite.cache import CacheWriter, DataType
8+
from valor_lite.cache import (
9+
CacheReader,
10+
CacheWriter,
11+
DataType,
12+
convert_type_mapping_to_schema,
13+
)
914

1015

1116
def test_datatype_casting_to_arrow():
@@ -22,6 +27,22 @@ def test_datatype_casting_to_python():
2227
assert DataType.TIMESTAMP.to_py() is datetime
2328

2429

30+
def test_convert_type_mapping_to_schema():
31+
x = convert_type_mapping_to_schema(
32+
{
33+
"a": DataType.FLOAT,
34+
"b": DataType.STRING,
35+
}
36+
)
37+
assert x == [
38+
("a", pa.float64()),
39+
("b", pa.string()),
40+
]
41+
42+
assert convert_type_mapping_to_schema({}) == []
43+
assert convert_type_mapping_to_schema(None) == []
44+
45+
2546
def test_cache_write_batch():
2647
batch_size = 10
2748
rows_per_file = 100
@@ -47,7 +68,7 @@ def test_cache_write_batch():
4768
}
4869
)
4970
cache.flush()
50-
assert cache.num_files == 10
71+
assert cache.num_files == 11
5172
for idx, fragment in enumerate(cache.dataset.get_fragments()):
5273
tbl = fragment.to_table()
5374
assert tbl["some_int"].to_pylist() == [
@@ -91,7 +112,7 @@ def test_cache_write_rows():
91112
]
92113
)
93114
cache.flush()
94-
assert cache.num_files == 10
115+
assert cache.num_files == 11
95116
for idx, fragment in enumerate(cache.dataset.get_fragments()):
96117
tbl = fragment.to_table()
97118
assert tbl["some_int"].to_pylist() == [
@@ -135,9 +156,9 @@ def test_cache_write_table():
135156
]
136157
)
137158
cache.write_table(tbl)
138-
assert cache.num_files == 1
139-
cache.write_table(tbl)
140159
assert cache.num_files == 2
160+
cache.write_table(tbl)
161+
assert cache.num_files == 3
141162
cache.flush()
142163
for _, fragment in enumerate(cache.dataset.get_fragments()):
143164
tbl = fragment.to_table()
@@ -146,3 +167,57 @@ def test_cache_write_table():
146167
assert tbl["some_str"].to_pylist() == [
147168
f"str{i}" for i in range(101)
148169
]
170+
171+
172+
def test_cache_reader():
173+
batch_size = 10
174+
rows_per_file = 100
175+
with tempfile.TemporaryDirectory() as tmpdir:
176+
cache = CacheWriter(
177+
where=Path(tmpdir),
178+
schema=pa.schema(
179+
[
180+
("some_int", pa.int64()),
181+
("some_float", pa.float64()),
182+
("some_str", pa.string()),
183+
]
184+
),
185+
batch_size=batch_size,
186+
rows_per_file=rows_per_file,
187+
)
188+
tbl = pa.Table.from_pylist(
189+
[
190+
{
191+
"some_int": i,
192+
"some_float": np.float64(i),
193+
"some_str": f"str{i}",
194+
}
195+
for i in range(101)
196+
]
197+
)
198+
cache.write_table(tbl)
199+
200+
readonly_cache = CacheReader(where=Path(tmpdir))
201+
assert readonly_cache.num_files == 2
202+
assert readonly_cache.files == [
203+
tmpdir + "/000000.parquet",
204+
tmpdir + "/.cfg",
205+
]
206+
assert readonly_cache.num_dataset_files == 1
207+
assert readonly_cache.dataset_files == [
208+
tmpdir + "/000000.parquet",
209+
]
210+
211+
cache.write_table(tbl)
212+
assert readonly_cache.num_files == 3
213+
assert readonly_cache.num_dataset_files == 2
214+
215+
for _, fragment in enumerate(readonly_cache.dataset.get_fragments()):
216+
tbl = fragment.to_table()
217+
assert tbl["some_int"].to_pylist() == [i for i in range(101)]
218+
assert tbl["some_float"].to_pylist() == [i for i in range(101)]
219+
assert tbl["some_str"].to_pylist() == [
220+
f"str{i}" for i in range(101)
221+
]
222+
223+
assert readonly_cache.dataset.count_rows() == 202

0 commit comments

Comments
 (0)