|
1 | | -import heapq |
2 | 1 | import json |
3 | 2 | from pathlib import Path |
4 | 3 |
|
|
7 | 6 | from numpy.typing import NDArray |
8 | 7 | from tqdm import tqdm |
9 | 8 |
|
10 | | -from valor_lite.common.datatype import DataType, convert_type_mapping_to_schema |
11 | | -from valor_lite.common.ephemeral import MemoryCacheReader, MemoryCacheWriter |
12 | | -from valor_lite.common.persistent import FileCacheWriter |
| 9 | +from valor_lite.cache import ( |
| 10 | + DataType, |
| 11 | + FileCacheWriter, |
| 12 | + MemoryCacheWriter, |
| 13 | + convert_type_mapping_to_fields, |
| 14 | + heapsort, |
| 15 | +) |
13 | 16 | from valor_lite.exceptions import EmptyCacheError |
14 | 17 | from valor_lite.object_detection.annotation import ( |
15 | 18 | Bitmask, |
@@ -58,13 +61,13 @@ def in_memory( |
58 | 61 | groundtruth_metadata_types: dict[str, DataType] | None = None, |
59 | 62 | prediction_metadata_types: dict[str, DataType] | None = None, |
60 | 63 | ): |
61 | | - datum_metadata_fields = convert_type_mapping_to_schema( |
| 64 | + datum_metadata_fields = convert_type_mapping_to_fields( |
62 | 65 | datum_metadata_types |
63 | 66 | ) |
64 | | - groundtruth_metadata_fields = convert_type_mapping_to_schema( |
| 67 | + groundtruth_metadata_fields = convert_type_mapping_to_fields( |
65 | 68 | groundtruth_metadata_types |
66 | 69 | ) |
67 | | - prediction_metadata_fields = convert_type_mapping_to_schema( |
| 70 | + prediction_metadata_fields = convert_type_mapping_to_fields( |
68 | 71 | prediction_metadata_types |
69 | 72 | ) |
70 | 73 |
|
@@ -108,13 +111,13 @@ def persistent( |
108 | 111 | if delete_if_exists and path.exists(): |
109 | 112 | cls.delete(path) |
110 | 113 |
|
111 | | - datum_metadata_fields = convert_type_mapping_to_schema( |
| 114 | + datum_metadata_fields = convert_type_mapping_to_fields( |
112 | 115 | datum_metadata_types |
113 | 116 | ) |
114 | | - groundtruth_metadata_fields = convert_type_mapping_to_schema( |
| 117 | + groundtruth_metadata_fields = convert_type_mapping_to_fields( |
115 | 118 | groundtruth_metadata_types |
116 | 119 | ) |
117 | | - prediction_metadata_fields = convert_type_mapping_to_schema( |
| 120 | + prediction_metadata_fields = convert_type_mapping_to_fields( |
118 | 121 | prediction_metadata_types |
119 | 122 | ) |
120 | 123 |
|
@@ -488,88 +491,47 @@ def rank( |
488 | 491 | for field in self._ranked_writer.schema |
489 | 492 | if field.name not in {"high_score", "iou_prev"} |
490 | 493 | ] |
491 | | - if ( |
492 | | - isinstance(detailed_reader, MemoryCacheReader) |
493 | | - or detailed_reader.count_tables() == 1 |
494 | | - ): |
495 | | - for tbl in detailed_reader.iterate_tables(columns=subset_columns): |
496 | | - ranked_tbl = rank_table(tbl, n_labels) |
497 | | - self._ranked_writer.write_table(ranked_tbl) |
498 | | - elif isinstance(self._ranked_writer, FileCacheWriter): |
| 494 | + if isinstance(self._ranked_writer, FileCacheWriter): |
499 | 495 | if not self._path: |
500 | 496 | raise ValueError( |
501 | 497 | "missing path definition in file-based loader" |
502 | 498 | ) |
503 | 499 | path = self._generate_temporary_cache_path(self._path) |
504 | | - with FileCacheWriter.create( |
| 500 | + tmp_sink = FileCacheWriter.create( |
505 | 501 | path=path, |
506 | | - schema=self._ranked_writer._schema, |
| 502 | + schema=self._ranked_writer.schema, |
507 | 503 | batch_size=self._ranked_writer._batch_size, |
508 | 504 | rows_per_file=self._ranked_writer._rows_per_file, |
509 | 505 | compression=self._ranked_writer._compression, |
510 | 506 | delete_if_exists=True, |
511 | | - ) as tmp_writer: |
512 | | - |
513 | | - # rank individual files |
514 | | - for tbl in detailed_reader.iterate_tables( |
515 | | - columns=subset_columns |
516 | | - ): |
517 | | - ranked_tbl = rank_table(tbl, n_labels) |
518 | | - tmp_writer.write_table(ranked_tbl) |
519 | | - |
520 | | - tmp_reader = tmp_writer.to_reader() |
521 | | - |
522 | | - def generate_heap_item(batches, batch_idx, row_idx) -> tuple: |
523 | | - score = batches[batch_idx]["score"][row_idx].as_py() |
524 | | - iou = batches[batch_idx]["iou"][row_idx].as_py() |
525 | | - return ( |
526 | | - -score, |
527 | | - -iou, |
528 | | - batch_idx, |
529 | | - row_idx, |
530 | | - ) |
531 | | - |
532 | | - # merge sorted rows |
533 | | - heap = [] |
534 | | - batch_iterators = [] |
535 | | - batches = [] |
536 | | - for batch_idx, batch_fragment in enumerate( |
537 | | - tmp_reader.iterate_fragments() |
538 | | - ): |
539 | | - batch_iter = batch_fragment.to_batches(batch_size=batch_size) |
540 | | - batch_iterators.append(batch_iter) |
541 | | - batches.append(next(batch_iterators[batch_idx], None)) |
542 | | - if ( |
543 | | - batches[batch_idx] is not None |
544 | | - and len(batches[batch_idx]) > 0 |
545 | | - ): |
546 | | - heapq.heappush( |
547 | | - heap, generate_heap_item(batches, batch_idx, 0) |
548 | | - ) |
549 | | - |
550 | | - while heap: |
551 | | - _, _, batch_idx, row_idx = heapq.heappop(heap) |
552 | | - row_table = batches[batch_idx].slice(row_idx, 1) |
553 | | - self._ranked_writer.write_batch(row_table) |
554 | | - row_idx += 1 |
555 | | - if row_idx < len(batches[batch_idx]): |
556 | | - heapq.heappush( |
557 | | - heap, |
558 | | - generate_heap_item(batches, batch_idx, row_idx), |
559 | | - ) |
560 | | - else: |
561 | | - batches[batch_idx] = next(batch_iterators[batch_idx], None) |
562 | | - if ( |
563 | | - batches[batch_idx] is not None |
564 | | - and len(batches[batch_idx]) > 0 |
565 | | - ): |
566 | | - heapq.heappush( |
567 | | - heap, |
568 | | - generate_heap_item(batches, batch_idx, 0), |
569 | | - ) |
| 507 | + ) |
| 508 | + else: |
| 509 | + tmp_sink = MemoryCacheWriter.create( |
| 510 | + schema=self._ranked_writer.schema, |
| 511 | + batch_size=self._ranked_writer._batch_size, |
| 512 | + ) |
| 513 | + |
| 514 | + # rank individual files |
| 515 | + for tbl in detailed_reader.iterate_tables(columns=subset_columns): |
| 516 | + ranked_tbl = rank_table(tbl, n_labels) |
| 517 | + tmp_sink.write_table(ranked_tbl) |
| 518 | + tmp_source = tmp_sink.to_reader() |
| 519 | + |
| 520 | + # sort ranked pairs across all chunks |
| 521 | + heapsort( |
| 522 | + source=tmp_source, |
| 523 | + sink=self._ranked_writer, |
| 524 | + batch_size=batch_size, |
| 525 | + sorting=[ |
| 526 | + ("score", "descending"), |
| 527 | + ("iou", "descending"), |
| 528 | + ], |
| 529 | + ) |
570 | 530 |
|
571 | | - FileCacheWriter.delete(path) |
| 531 | + # clean up |
572 | 532 | self._ranked_writer.flush() |
| 533 | + if isinstance(tmp_sink, FileCacheWriter): |
| 534 | + FileCacheWriter.delete(tmp_sink.path) |
573 | 535 |
|
574 | 536 | def finalize( |
575 | 537 | self, |
|
0 commit comments