Skip to content

Commit 8929452

Browse files
authored
Add progress flags (#83)
Add flags for turning off progress bars in more functions
1 parent 02b2576 commit 8929452

2 files changed

Lines changed: 37 additions & 23 deletions

File tree

src/semra/api.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,9 @@ def assert_projection(mappings: list[Mapping]) -> None:
508508
)
509509

510510

511-
def prioritize(mappings: list[Mapping], priority: list[str]) -> list[Mapping]:
511+
def prioritize(
512+
mappings: list[Mapping], priority: list[str], *, progress: bool = True
513+
) -> list[Mapping]:
512514
"""Get a priority star graph.
513515
514516
:param mappings: An iterable of mappings
@@ -535,7 +537,9 @@ def prioritize(mappings: list[Mapping], priority: list[str]) -> list[Mapping]:
535537

536538
graph = to_digraph(mappings).to_undirected()
537539
rv: list[Mapping] = []
538-
for component in tqdm(nx.connected_components(graph), unit="component", unit_scale=True):
540+
for component in tqdm(
541+
nx.connected_components(graph), unit="component", unit_scale=True, disable=not progress
542+
):
539543
o = get_priority_reference(component, priority)
540544
if o is None:
541545
continue

src/semra/pipeline.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,7 @@ def get_mappings(
422422
refresh_processed: bool = False,
423423
refresh_source: bool = False,
424424
return_type: GetMappingReturnType = GetMappingReturnType.none,
425+
progress: bool = True,
425426
) -> list[Mapping] | MappingPack | None:
426427
"""Run assembly based on this configuration."""
427428
return get_priority_mappings_from_config( # type:ignore[no-any-return,call-overload]
@@ -430,6 +431,7 @@ def get_mappings(
430431
refresh_raw=refresh_raw,
431432
refresh_processed=refresh_processed,
432433
return_type=return_type,
434+
progress=progress,
433435
)
434436

435437
def read_raw_mappings(self, *, show_progress: bool = False) -> list[Mapping]:
@@ -710,6 +712,7 @@ def get_priority_mappings_from_config(
710712
refresh_raw: bool = False,
711713
refresh_processed: bool = False,
712714
return_type: GetMappingReturnType = GetMappingReturnType.none,
715+
progress: bool = True,
713716
) -> None | list[Mapping] | MappingPack:
714717
"""Get prioritized mappings based on an assembly configuration."""
715718
if refresh_source:
@@ -727,9 +730,9 @@ def get_priority_mappings_from_config(
727730
if not configuration.has_processed_path():
728731
raise FileNotFoundError
729732
return MappingPack(
730-
raw=configuration.read_raw_mappings(show_progress=True),
731-
processed=configuration.read_processed_mappings(show_progress=True),
732-
priority=configuration.read_priority_mappings(show_progress=True),
733+
raw=configuration.read_raw_mappings(show_progress=progress),
734+
processed=configuration.read_processed_mappings(show_progress=progress),
735+
priority=configuration.read_priority_mappings(show_progress=progress),
733736
)
734737
case GetMappingReturnType.priority:
735738
return configuration.read_priority_mappings()
@@ -749,11 +752,13 @@ def get_priority_mappings_from_config(
749752
configuration.configuration_path.write_text(
750753
configuration.model_dump_json(exclude_none=True, exclude_unset=True, indent=2)
751754
)
752-
raw_mappings = get_raw_mappings(configuration, refresh_source=refresh_source)
755+
raw_mappings = get_raw_mappings(
756+
configuration, refresh_source=refresh_source, show_progress=progress
757+
)
753758
if not raw_mappings:
754759
raise ValueError(f"no raw mappings found for configuration: {configuration.name}")
755760
if configuration.validate_raw:
756-
validate_mappings(raw_mappings)
761+
validate_mappings(raw_mappings, progress=progress)
757762

758763
# TODO stream?
759764
write_sssom(
@@ -764,6 +769,7 @@ def get_priority_mappings_from_config(
764769
write_jsonl(
765770
raw_mappings,
766771
configuration.raw_jsonl_path,
772+
show_progress=progress,
767773
)
768774
if configuration.write_raw_neo4j:
769775
write_neo4j(
@@ -772,6 +778,7 @@ def get_priority_mappings_from_config(
772778
docker_name=configuration.raw_neo4j_name,
773779
add_labels=False, # configuration.add_labels,
774780
compress=configuration.neo4j_gzip,
781+
use_tqdm=progress,
775782
)
776783

777784
# click.echo(semra.api.str_source_target_counts(mappings, minimum=20))
@@ -786,9 +793,10 @@ def get_priority_mappings_from_config(
786793
post_keep_prefixes=configuration.post_keep_prefixes,
787794
remove_imprecise=configuration.remove_imprecise,
788795
subsets=configuration.get_hydrated_subsets(),
796+
progress=progress,
789797
)
790798

791-
prioritized_mappings = prioritize(processed_mappings, configuration.priority)
799+
prioritized_mappings = prioritize(processed_mappings, configuration.priority, progress=progress)
792800
equivalence_classes = _get_equivalence_classes(processed_mappings, prioritized_mappings)
793801
write_sssom(
794802
processed_mappings,
@@ -798,6 +806,7 @@ def get_priority_mappings_from_config(
798806
write_jsonl(
799807
processed_mappings,
800808
configuration.processed_jsonl_path,
809+
show_progress=progress,
801810
)
802811
write_neo4j(
803812
processed_mappings,
@@ -806,12 +815,10 @@ def get_priority_mappings_from_config(
806815
equivalence_classes=equivalence_classes,
807816
add_labels=configuration.add_labels,
808817
compress=configuration.neo4j_gzip,
818+
use_tqdm=progress,
809819
)
810820

811-
write_jsonl(
812-
prioritized_mappings,
813-
configuration.priority_jsonl_path,
814-
)
821+
write_jsonl(prioritized_mappings, configuration.priority_jsonl_path, show_progress=progress)
815822
write_sssom(
816823
prioritized_mappings,
817824
configuration.priority_sssom_path,
@@ -921,15 +928,16 @@ def process(
921928
subsets: SubsetConfiguration | None = None,
922929
*,
923930
remove_imprecise: bool = True,
931+
progress: bool = True,
924932
) -> list[Mapping]:
925933
"""Run a full deduplication, reasoning, and inference pipeline over a set of mappings."""
926934
from semra.sources.biopragmatics import from_biomappings_negative
927935

928936
if keep_prefix_set:
929-
mappings = keep_prefixes(mappings, keep_prefix_set)
937+
mappings = keep_prefixes(mappings, keep_prefix_set, progress=progress)
930938

931939
if remove_prefix_set:
932-
mappings = filter_prefixes(mappings, remove_prefix_set)
940+
mappings = filter_prefixes(mappings, remove_prefix_set, progress=progress)
933941

934942
if subsets:
935943
mappings = list(filter_subsets(mappings, subsets))
@@ -940,13 +948,13 @@ def process(
940948

941949
before = len(mappings)
942950
start = time.time()
943-
mappings = filter_mappings(mappings, negatives)
951+
mappings = filter_mappings(mappings, negatives, progress=progress)
944952
_log_diff(before, mappings, verb="Filtered negative mappings", elapsed=time.time() - start)
945953

946954
# deduplicate
947955
before = len(mappings)
948956
start = time.time()
949-
mappings = assemble_evidences(mappings)
957+
mappings = assemble_evidences(mappings, progress=progress)
950958
_log_diff(before, mappings, verb="Assembled", elapsed=time.time() - start)
951959

952960
# only keep relevant prefixes
@@ -967,7 +975,9 @@ def process(
967975
# resources to each other are exact matches, rewrite the prefixes
968976
before = len(mappings)
969977
start = time.time()
970-
mappings = infer_mutual_dbxref_mutations(mappings, upgrade_prefixes, confidence=0.95)
978+
mappings = infer_mutual_dbxref_mutations(
979+
mappings, upgrade_prefixes, confidence=0.95, progress=progress
980+
)
971981
_log_diff(before, mappings, verb="Inferred upgrades", elapsed=time.time() - start)
972982

973983
# remove database cross-references
@@ -982,30 +992,30 @@ def process(
982992
logger.info("Inferring reverse mappings")
983993
before = len(mappings)
984994
start = time.time()
985-
mappings = infer_reversible(mappings)
995+
mappings = infer_reversible(mappings, progress=progress)
986996
_log_diff(before, mappings, verb="Inferred", elapsed=time.time() - start)
987997

988998
logger.info("Inferring based on chains")
989999
before = len(mappings)
9901000
time.time()
991-
mappings = infer_chains(mappings)
1001+
mappings = infer_chains(mappings, progress=progress)
9921002
_log_diff(before, mappings, verb="Inferred", elapsed=time.time() - start)
9931003

9941004
# 4/5. Filtering negative
9951005
logger.info("Filtering out negative mappings")
9961006
before = len(mappings)
9971007
start = time.time()
998-
mappings = filter_mappings(mappings, negatives)
1008+
mappings = filter_mappings(mappings, negatives, progress=progress)
9991009
_log_diff(before, mappings, verb="Filtered negative mappings", elapsed=time.time() - start)
10001010

10011011
# filter out self mappings again, just in case
1002-
mappings = filter_self_matches(mappings)
1012+
mappings = filter_self_matches(mappings, progress=progress)
10031013

10041014
if post_keep_prefixes:
1005-
mappings = keep_prefixes(mappings, post_keep_prefixes)
1015+
mappings = keep_prefixes(mappings, post_keep_prefixes, progress=progress)
10061016

10071017
if post_remove_prefixes:
1008-
mappings = filter_prefixes(mappings, post_remove_prefixes)
1018+
mappings = filter_prefixes(mappings, post_remove_prefixes, progress=progress)
10091019

10101020
return mappings
10111021

0 commit comments

Comments
 (0)