Skip to content

Commit 0bb63d6

Browse files
committed
getting obj det working
1 parent da22674 commit 0bb63d6

11 files changed

Lines changed: 1254 additions & 1536 deletions

File tree

src/valor_lite/common/ephemeral.py

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,43 @@ class MemoryCache:
88
def __init__(self, table: pa.Table):
99
self._table = table
1010

11+
def count_tables(self) -> int:
12+
"""Count the number of tables in the cache."""
13+
return 1
14+
1115
def count_rows(self) -> int:
1216
"""Count the number of rows in the cache."""
1317
return self._table.num_rows
1418

1519

20+
class MemoryCacheReader(MemoryCache):
21+
def __init__(
22+
self,
23+
table: pa.Table,
24+
):
25+
self._table = table
26+
self._schema = self._table.schema
27+
28+
@classmethod
29+
def load(cls, table: pa.Table):
30+
"""
31+
Load cache from table.
32+
33+
Parameters
34+
----------
35+
cache : MemoryCacheWriter
36+
A cache writer containing the ephemeral cache.
37+
"""
38+
return cls(table=table)
39+
40+
def iterate_tables(self, columns: list[str] | None = None):
41+
"""Iterate over tables within the cache."""
42+
if columns:
43+
yield self._table.select(columns=columns)
44+
else:
45+
yield self._table
46+
47+
1648
class MemoryCacheWriter(MemoryCache):
1749
def __init__(
1850
self,
@@ -47,13 +79,6 @@ def create(
4779
batch_size=batch_size,
4880
)
4981

50-
def delete(self):
51-
"""
52-
Delete any existing cache data.
53-
"""
54-
self._buffer = []
55-
self._table = self._table.schema.empty_table()
56-
5782
def write_rows(
5883
self,
5984
rows: list[dict[str, Any]],
@@ -146,31 +171,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
146171
"""Context manager exit - ensures data is flushed."""
147172
self.flush()
148173

149-
150-
class MemoryCacheReader:
151-
def __init__(
152-
self,
153-
cache: MemoryCacheWriter,
154-
):
155-
self._cache = cache
156-
self._schema = self._cache._schema
157-
158-
@classmethod
159-
def load(cls, cache: MemoryCacheWriter):
160-
"""
161-
Load cache from table.
162-
163-
Parameters
164-
----------
165-
cache : MemoryCacheWriter
166-
A cache writer containing the ephemeral cache.
167-
"""
168-
return cls(cache=cache)
169-
170-
def iterate_tables(self):
171-
"""Iterate over tables within the cache."""
172-
yield self._cache._table
173-
174-
def count_rows(self) -> int:
175-
"""Count the number of rows in the cache."""
176-
return self._cache.count_rows()
174+
def to_reader(self) -> MemoryCacheReader:
175+
"""Get cache reader."""
176+
return MemoryCacheReader.load(self._table)

src/valor_lite/common/persistent.py

Lines changed: 104 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,18 @@ def _decode_schema(encoded_schema: str) -> pa.Schema:
4646
schema_bytes = base64.b64decode(encoded_schema)
4747
return pa.ipc.read_schema(pa.BufferReader(schema_bytes))
4848

49+
def count_rows(self) -> int:
50+
"""Count the number of rows in the cache."""
51+
dataset = ds.dataset(
52+
source=self._path,
53+
format="parquet",
54+
)
55+
return dataset.count_rows()
56+
57+
def count_tables(self) -> int:
58+
"""Count the number of files in the cache."""
59+
return len(self.get_dataset_files())
60+
4961
def get_files(self) -> list[Path]:
5062
"""
5163
Retrieve all files.
@@ -80,6 +92,74 @@ def get_dataset_files(self) -> list[Path]:
8092
]
8193

8294

95+
class FileCacheReader(FileCache):
96+
def __init__(
97+
self,
98+
path: str | Path,
99+
schema: pa.Schema,
100+
):
101+
self._schema = schema
102+
self._path = Path(path)
103+
104+
@classmethod
105+
def load(cls, path: str | Path | FileCache):
106+
"""
107+
Load cache from disk.
108+
109+
Parameters
110+
----------
111+
path : str | Path
112+
Where the cache is stored.
113+
"""
114+
if isinstance(path, FileCache):
115+
path = path.path
116+
path = Path(path)
117+
if not path.exists():
118+
raise FileNotFoundError(f"Directory does not exist: {path}")
119+
elif not path.is_dir():
120+
raise NotADirectoryError(
121+
f"Path exists but is not a directory: {path}"
122+
)
123+
124+
def _retrieve(config: dict, key: str):
125+
if value := config.get(key, None):
126+
return value
127+
raise KeyError(
128+
f"'{key}' is not defined within {cls._generate_config_path(path)}"
129+
)
130+
131+
# read configuration file
132+
cfg_path = cls._generate_config_path(path)
133+
with open(cfg_path, "r") as f:
134+
cfg = json.load(f)
135+
schema = cls._decode_schema(_retrieve(cfg, "schema"))
136+
137+
return cls(
138+
schema=schema,
139+
path=path,
140+
)
141+
142+
def iterate_tables(self, columns: list[str] | None = None):
143+
"""Iterate over tables within the cache."""
144+
dataset = ds.dataset(
145+
source=self._path,
146+
schema=self._schema,
147+
format="parquet",
148+
)
149+
for fragment in dataset.get_fragments():
150+
yield fragment.to_table(columns=columns)
151+
152+
def iterate_fragments(self):
153+
"""Iterate over fragments within the file-based cache."""
154+
dataset = ds.dataset(
155+
source=self._path,
156+
schema=self._schema,
157+
format="parquet",
158+
)
159+
for fragment in dataset.get_fragments():
160+
yield fragment
161+
162+
83163
class FileCacheWriter(FileCache):
84164
def __init__(
85165
self,
@@ -108,6 +188,7 @@ def create(
108188
batch_size: int,
109189
rows_per_file: int,
110190
compression: str = "snappy",
191+
delete_if_exists: bool = False,
111192
):
112193
"""
113194
Create a cache on disk.
@@ -124,7 +205,12 @@ def create(
124205
Target number of rows to store per file.
125206
compression : str, default="snappy"
126207
Compression method to use when storing on disk.
208+
delete_if_exists : bool, default=False
209+
Delete the cache if it already exists.
127210
"""
211+
path = Path(path)
212+
if delete_if_exists and path.exists():
213+
cls.delete(path)
128214
Path(path).mkdir(parents=True, exist_ok=False)
129215

130216
# write configuration file
@@ -146,29 +232,33 @@ def create(
146232
compression=compression,
147233
)
148234

149-
def delete(self):
235+
@classmethod
236+
def delete(cls, path: str | Path):
150237
"""
151-
Delete the cache.
238+
Delete a cache at path.
152239
153240
Parameters
154241
----------
155242
path : str | Path
156243
Where the cache is stored.
157244
"""
158-
if not self._path.exists():
245+
path = Path(path)
246+
if not path.exists():
159247
return
160-
# clear buffer
161-
self.flush()
162-
# delete config file
163-
cfg_path = self._generate_config_path(self._path)
164-
if cfg_path.exists() and cfg_path.is_file():
165-
cfg_path.unlink()
248+
166249
# delete dataset files
167-
for file in self.get_dataset_files():
250+
reader = FileCacheReader.load(path)
251+
for file in reader.get_dataset_files():
168252
if file.exists() and file.is_file() and file.suffix == ".parquet":
169253
file.unlink()
254+
255+
# delete config file
256+
cfg_path = cls._generate_config_path(path)
257+
if cfg_path.exists() and cfg_path.is_file():
258+
cfg_path.unlink()
259+
170260
# delete empty cache directory
171-
self._path.rmdir()
261+
path.rmdir()
172262

173263
def write_rows(
174264
self,
@@ -297,69 +387,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
297387
"""Context manager exit - ensures data is flushed."""
298388
self.flush()
299389

300-
301-
class FileCacheReader(FileCache):
302-
def __init__(
303-
self,
304-
path: str | Path,
305-
schema: pa.Schema,
306-
):
307-
self._schema = schema
308-
self._path = Path(path)
309-
310-
@classmethod
311-
def load(cls, path: str | Path | FileCache):
312-
"""
313-
Load cache from disk.
314-
315-
Parameters
316-
----------
317-
path : str | Path
318-
Where the cache is stored.
319-
"""
320-
if isinstance(path, FileCache):
321-
path = path.path
322-
path = Path(path)
323-
if not path.exists():
324-
raise FileNotFoundError(f"Directory does not exist: {path}")
325-
elif not path.is_dir():
326-
raise NotADirectoryError(
327-
f"Path exists but is not a directory: {path}"
328-
)
329-
330-
def _retrieve(config: dict, key: str):
331-
if value := config.get(key, None):
332-
return value
333-
raise KeyError(
334-
f"'{key}' is not defined within {cls._generate_config_path(path)}"
335-
)
336-
337-
# read configuration file
338-
cfg_path = cls._generate_config_path(path)
339-
with open(cfg_path, "r") as f:
340-
cfg = json.load(f)
341-
schema = cls._decode_schema(_retrieve(cfg, "schema"))
342-
343-
return cls(
344-
schema=schema,
345-
path=path,
346-
)
347-
348-
def count_rows(self) -> int:
349-
"""Count the number of rows in the cache."""
350-
dataset = ds.dataset(
351-
source=self._path,
352-
schema=self._schema,
353-
format="parquet",
354-
)
355-
return dataset.count_rows()
356-
357-
def iterate_tables(self):
358-
"""Iterate over tables within the cache."""
359-
dataset = ds.dataset(
360-
source=self._path,
361-
schema=self._schema,
362-
format="parquet",
363-
)
364-
for fragment in dataset.get_fragments():
365-
yield fragment.to_table()
390+
def to_reader(self) -> FileCacheReader:
391+
"""Get cache reader."""
392+
return FileCacheReader.load(path=self.path)
Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .annotation import Bitmask, BoundingBox, Detection, Polygon
2-
from .evaluator import DataType, Filter
3-
from .legacy import DataLoader, Evaluator, Metadata
2+
from .evaluator import DataType, Evaluator, Filter
3+
from .loader import Loader
44
from .metric import Metric, MetricType
55

66
__all__ = [
@@ -10,9 +10,8 @@
1010
"Polygon",
1111
"Metric",
1212
"MetricType",
13-
"DataLoader",
13+
"Loader",
1414
"Evaluator",
1515
"Filter",
1616
"DataType",
17-
"Metadata",
1817
]

src/valor_lite/object_detection/computation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -285,14 +285,14 @@ def rank_table(tbl: pa.Table, number_of_labels: int) -> pa.Table:
285285
lower_iou_bound, winning_predictions = calculate_ranking_boundaries(
286286
pairs, number_of_labels=number_of_labels
287287
)
288-
ranked_tbl = ranked_tbl.append_column(
289-
pa.field("iou_prev", pa.float64()),
290-
pa.array(lower_iou_bound, type=pa.float64()),
291-
)
292288
ranked_tbl = ranked_tbl.append_column(
293289
pa.field("high_score", pa.bool_()),
294290
pa.array(winning_predictions, type=pa.bool_()),
295291
)
292+
ranked_tbl = ranked_tbl.append_column(
293+
pa.field("iou_prev", pa.float64()),
294+
pa.array(lower_iou_bound, type=pa.float64()),
295+
)
296296
ranked_tbl = ranked_tbl.sort_by(sorting_args)
297297
return ranked_tbl
298298

0 commit comments

Comments
 (0)