Skip to content

Commit f482a33

Browse files
committed
pulled sorting out of loader
1 parent 0456bd3 commit f482a33

11 files changed

Lines changed: 151 additions & 92 deletions

File tree

src/valor_lite/cache/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from .compute import heapsort
2+
from .datatype import DataType, convert_type_mapping_to_fields
3+
from .ephemeral import MemoryCacheReader, MemoryCacheWriter
4+
from .persistent import FileCacheReader, FileCacheWriter
5+
6+
__all__ = [
7+
"DataType",
8+
"convert_type_mapping_to_fields",
9+
"FileCacheReader",
10+
"FileCacheWriter",
11+
"MemoryCacheReader",
12+
"MemoryCacheWriter",
13+
"heapsort",
14+
]

src/valor_lite/cache/compute.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import heapq
2+
3+
import pyarrow as pa
4+
5+
from valor_lite.cache.ephemeral import MemoryCacheReader, MemoryCacheWriter
6+
from valor_lite.cache.persistent import FileCacheReader, FileCacheWriter
7+
8+
9+
def heapsort(
10+
source: MemoryCacheReader | FileCacheReader,
11+
sink: MemoryCacheWriter | FileCacheWriter,
12+
batch_size: int,
13+
sorting: list[tuple[str, str]],
14+
):
15+
"""
16+
Perform heapsort on a cache object.
17+
18+
Parameters
19+
----------
20+
source : MemoryCacheReader | FileCacheReader
21+
The read-only source cache.
22+
sink : MemoryCacheWriter | FileCacheWriter
23+
The cache where sorted data will be written.
24+
batch_size : int
25+
Maximum number of rows allowed to be read into memory per cache file.
26+
sorting : list[tuple[str, str]]
27+
Sorting arguments in PyArrow format (e.g. [('a', 'ascending'), ('b', 'descending')]).
28+
"""
29+
if source.count_tables() == 1 or isinstance(source, MemoryCacheReader):
30+
for tbl in source.iterate_tables():
31+
sorted_tbl = tbl.sort_by(sorting)
32+
sink.write_table(sorted_tbl)
33+
else:
34+
35+
def create_sort_key(
36+
batches: list[pa.RecordBatch],
37+
batch_idx: int,
38+
row_idx: int,
39+
):
40+
args = [
41+
-batches[batch_idx][name][row_idx].as_py()
42+
if direction == "descending"
43+
else batches[batch_idx][name][row_idx].as_py()
44+
for name, direction in sorting
45+
]
46+
return (
47+
*args,
48+
batch_idx,
49+
row_idx,
50+
)
51+
52+
# merge sorted rows
53+
heap = []
54+
batch_iterators = []
55+
batches = []
56+
for batch_idx, batch_fragment in enumerate(source.iterate_fragments()):
57+
batch_iter = batch_fragment.to_batches(batch_size=batch_size)
58+
batch_iterators.append(batch_iter)
59+
batches.append(next(batch_iterators[batch_idx], None))
60+
if batches[batch_idx] is not None and len(batches[batch_idx]) > 0:
61+
heapq.heappush(heap, create_sort_key(batches, batch_idx, 0))
62+
63+
while heap:
64+
_, _, batch_idx, row_idx = heapq.heappop(heap)
65+
row_table = batches[batch_idx].slice(row_idx, 1)
66+
sink.write_batch(row_table)
67+
row_idx += 1
68+
if row_idx < len(batches[batch_idx]):
69+
heapq.heappush(
70+
heap,
71+
create_sort_key(batches, batch_idx, row_idx),
72+
)
73+
else:
74+
batches[batch_idx] = next(batch_iterators[batch_idx], None)
75+
if (
76+
batches[batch_idx] is not None
77+
and len(batches[batch_idx]) > 0
78+
):
79+
heapq.heappush(
80+
heap,
81+
create_sort_key(batches, batch_idx, 0),
82+
)
83+
84+
# flush any buffers
85+
sink.flush()
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def to_arrow(self):
3636
return pa.timestamp("us")
3737

3838

39-
def convert_type_mapping_to_schema(
39+
def convert_type_mapping_to_fields(
4040
type_mapping: dict[str, DataType] | None
4141
) -> list[tuple[str, pl.DataType]]:
4242
"""

src/valor_lite/object_detection/evaluator.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,7 @@
88
import pyarrow.compute as pc
99
from numpy.typing import NDArray
1010

11-
from valor_lite.common.datatype import DataType
12-
from valor_lite.common.ephemeral import MemoryCacheReader
13-
from valor_lite.common.persistent import FileCacheReader
11+
from valor_lite.cache import DataType, FileCacheReader, MemoryCacheReader
1412
from valor_lite.object_detection.computation import (
1513
compute_average_precision,
1614
compute_average_recall,

src/valor_lite/object_detection/loader.py

Lines changed: 42 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import heapq
21
import json
32
from pathlib import Path
43

@@ -7,9 +6,13 @@
76
from numpy.typing import NDArray
87
from tqdm import tqdm
98

10-
from valor_lite.common.datatype import DataType, convert_type_mapping_to_schema
11-
from valor_lite.common.ephemeral import MemoryCacheReader, MemoryCacheWriter
12-
from valor_lite.common.persistent import FileCacheWriter
9+
from valor_lite.cache import (
10+
DataType,
11+
FileCacheWriter,
12+
MemoryCacheWriter,
13+
convert_type_mapping_to_fields,
14+
heapsort,
15+
)
1316
from valor_lite.exceptions import EmptyCacheError
1417
from valor_lite.object_detection.annotation import (
1518
Bitmask,
@@ -58,13 +61,13 @@ def in_memory(
5861
groundtruth_metadata_types: dict[str, DataType] | None = None,
5962
prediction_metadata_types: dict[str, DataType] | None = None,
6063
):
61-
datum_metadata_fields = convert_type_mapping_to_schema(
64+
datum_metadata_fields = convert_type_mapping_to_fields(
6265
datum_metadata_types
6366
)
64-
groundtruth_metadata_fields = convert_type_mapping_to_schema(
67+
groundtruth_metadata_fields = convert_type_mapping_to_fields(
6568
groundtruth_metadata_types
6669
)
67-
prediction_metadata_fields = convert_type_mapping_to_schema(
70+
prediction_metadata_fields = convert_type_mapping_to_fields(
6871
prediction_metadata_types
6972
)
7073

@@ -108,13 +111,13 @@ def persistent(
108111
if delete_if_exists and path.exists():
109112
cls.delete(path)
110113

111-
datum_metadata_fields = convert_type_mapping_to_schema(
114+
datum_metadata_fields = convert_type_mapping_to_fields(
112115
datum_metadata_types
113116
)
114-
groundtruth_metadata_fields = convert_type_mapping_to_schema(
117+
groundtruth_metadata_fields = convert_type_mapping_to_fields(
115118
groundtruth_metadata_types
116119
)
117-
prediction_metadata_fields = convert_type_mapping_to_schema(
120+
prediction_metadata_fields = convert_type_mapping_to_fields(
118121
prediction_metadata_types
119122
)
120123

@@ -488,88 +491,47 @@ def rank(
488491
for field in self._ranked_writer.schema
489492
if field.name not in {"high_score", "iou_prev"}
490493
]
491-
if (
492-
isinstance(detailed_reader, MemoryCacheReader)
493-
or detailed_reader.count_tables() == 1
494-
):
495-
for tbl in detailed_reader.iterate_tables(columns=subset_columns):
496-
ranked_tbl = rank_table(tbl, n_labels)
497-
self._ranked_writer.write_table(ranked_tbl)
498-
elif isinstance(self._ranked_writer, FileCacheWriter):
494+
if isinstance(self._ranked_writer, FileCacheWriter):
499495
if not self._path:
500496
raise ValueError(
501497
"missing path definition in file-based loader"
502498
)
503499
path = self._generate_temporary_cache_path(self._path)
504-
with FileCacheWriter.create(
500+
tmp_sink = FileCacheWriter.create(
505501
path=path,
506-
schema=self._ranked_writer._schema,
502+
schema=self._ranked_writer.schema,
507503
batch_size=self._ranked_writer._batch_size,
508504
rows_per_file=self._ranked_writer._rows_per_file,
509505
compression=self._ranked_writer._compression,
510506
delete_if_exists=True,
511-
) as tmp_writer:
512-
513-
# rank individual files
514-
for tbl in detailed_reader.iterate_tables(
515-
columns=subset_columns
516-
):
517-
ranked_tbl = rank_table(tbl, n_labels)
518-
tmp_writer.write_table(ranked_tbl)
519-
520-
tmp_reader = tmp_writer.to_reader()
521-
522-
def generate_heap_item(batches, batch_idx, row_idx) -> tuple:
523-
score = batches[batch_idx]["score"][row_idx].as_py()
524-
iou = batches[batch_idx]["iou"][row_idx].as_py()
525-
return (
526-
-score,
527-
-iou,
528-
batch_idx,
529-
row_idx,
530-
)
531-
532-
# merge sorted rows
533-
heap = []
534-
batch_iterators = []
535-
batches = []
536-
for batch_idx, batch_fragment in enumerate(
537-
tmp_reader.iterate_fragments()
538-
):
539-
batch_iter = batch_fragment.to_batches(batch_size=batch_size)
540-
batch_iterators.append(batch_iter)
541-
batches.append(next(batch_iterators[batch_idx], None))
542-
if (
543-
batches[batch_idx] is not None
544-
and len(batches[batch_idx]) > 0
545-
):
546-
heapq.heappush(
547-
heap, generate_heap_item(batches, batch_idx, 0)
548-
)
549-
550-
while heap:
551-
_, _, batch_idx, row_idx = heapq.heappop(heap)
552-
row_table = batches[batch_idx].slice(row_idx, 1)
553-
self._ranked_writer.write_batch(row_table)
554-
row_idx += 1
555-
if row_idx < len(batches[batch_idx]):
556-
heapq.heappush(
557-
heap,
558-
generate_heap_item(batches, batch_idx, row_idx),
559-
)
560-
else:
561-
batches[batch_idx] = next(batch_iterators[batch_idx], None)
562-
if (
563-
batches[batch_idx] is not None
564-
and len(batches[batch_idx]) > 0
565-
):
566-
heapq.heappush(
567-
heap,
568-
generate_heap_item(batches, batch_idx, 0),
569-
)
507+
)
508+
else:
509+
tmp_sink = MemoryCacheWriter.create(
510+
schema=self._ranked_writer.schema,
511+
batch_size=self._ranked_writer._batch_size,
512+
)
513+
514+
# rank individual files
515+
for tbl in detailed_reader.iterate_tables(columns=subset_columns):
516+
ranked_tbl = rank_table(tbl, n_labels)
517+
tmp_sink.write_table(ranked_tbl)
518+
tmp_source = tmp_sink.to_reader()
519+
520+
# sort ranked pairs across all chunks
521+
heapsort(
522+
source=tmp_source,
523+
sink=self._ranked_writer,
524+
batch_size=batch_size,
525+
sorting=[
526+
("score", "descending"),
527+
("iou", "descending"),
528+
],
529+
)
570530

571-
FileCacheWriter.delete(path)
531+
# clean up
572532
self._ranked_writer.flush()
533+
if isinstance(tmp_sink, FileCacheWriter):
534+
FileCacheWriter.delete(tmp_sink.path)
573535

574536
def finalize(
575537
self,

tests/common/test_datatype.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pyarrow as pa
44

5-
from valor_lite.common.datatype import DataType, convert_type_mapping_to_schema
5+
from valor_lite.cache.datatype import DataType, convert_type_mapping_to_fields
66

77

88
def test_datatype_casting_to_arrow():
@@ -19,8 +19,8 @@ def test_datatype_casting_to_python():
1919
assert DataType.TIMESTAMP.to_py() is datetime
2020

2121

22-
def test_convert_type_mapping_to_schema():
23-
x = convert_type_mapping_to_schema(
22+
def test_convert_type_mapping_to_fields():
23+
x = convert_type_mapping_to_fields(
2424
{
2525
"a": DataType.FLOAT,
2626
"b": DataType.STRING,
@@ -31,5 +31,5 @@ def test_convert_type_mapping_to_schema():
3131
("b", pa.string()),
3232
]
3333

34-
assert convert_type_mapping_to_schema({}) == []
35-
assert convert_type_mapping_to_schema(None) == []
34+
assert convert_type_mapping_to_fields({}) == []
35+
assert convert_type_mapping_to_fields(None) == []

tests/common/test_ephemeral_cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
import pyarrow as pa
33

4-
from valor_lite.common.ephemeral import MemoryCacheWriter
4+
from valor_lite.cache.ephemeral import MemoryCacheWriter
55

66

77
def test_cache_reader():

tests/common/test_persistent_cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pyarrow as pa
66
import pytest
77

8-
from valor_lite.common.persistent import (
8+
from valor_lite.cache.persistent import (
99
FileCache,
1010
FileCacheReader,
1111
FileCacheWriter,

0 commit comments

Comments
 (0)