Skip to content

Commit f3a1c10

Browse files
committed
data type tests
1 parent f175f52 commit f3a1c10

3 files changed

Lines changed: 20 additions & 7 deletions

File tree

src/valor_lite/cache.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class DataType(StrEnum):
1717
STRING = "string"
1818
TIMESTAMP = "timestamp"
1919

20-
def to_type(self):
20+
def to_py(self):
2121
match self:
2222
case DataType.INTEGER:
2323
return int
@@ -37,7 +37,7 @@ def to_arrow(self):
3737
case DataType.STRING:
3838
return pa.string()
3939
case DataType.TIMESTAMP:
40-
return pa.timestamp()
40+
return pa.timestamp("us")
4141

4242

4343
class CacheReader:

src/valor_lite/object_detection/loader.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,8 @@ def __init__(
2727
self,
2828
directory: str | Path = ".valor",
2929
name: str = "default",
30-
batch_size: int = 1,
31-
rows_per_file: int = 1,
32-
# batch_size: int = 10_000,
33-
# rows_per_file: int = 1_000_000,
30+
batch_size: int = 10_000,
31+
rows_per_file: int = 100_000,
3432
compression: str = "snappy",
3533
datum_metadata_types: dict[str, DataType] | None = None,
3634
groundtruth_metadata_types: dict[str, DataType] | None = None,

tests/common/test_cache.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,25 @@
11
import tempfile
2+
from datetime import datetime
23
from pathlib import Path
34

45
import numpy as np
56
import pyarrow as pa
67

7-
from valor_lite.cache import CacheWriter
8+
from valor_lite.cache import CacheWriter, DataType
9+
10+
11+
def test_datatype_casting_to_arrow():
12+
assert DataType.FLOAT.to_arrow() == pa.float64()
13+
assert DataType.INTEGER.to_arrow() == pa.int64()
14+
assert DataType.STRING.to_arrow() == pa.string()
15+
assert DataType.TIMESTAMP.to_arrow() == pa.timestamp("us")
16+
17+
18+
def test_datatype_casting_to_python():
19+
assert DataType.FLOAT.to_py() is float
20+
assert DataType.INTEGER.to_py() is int
21+
assert DataType.STRING.to_py() is str
22+
assert DataType.TIMESTAMP.to_py() is datetime
823

924

1025
def test_cache_write_batch():

0 commit comments

Comments
 (0)