Skip to content

Commit df63d57

Browse files
authored
Merge pull request #766 from broadinstitute/kl/row_len
Edits to check_global_and_row_annot_lengths for efficiency
2 parents 1e459ab + a758187 commit df63d57

File tree

2 files changed

+79
-12
lines changed

2 files changed

+79
-12
lines changed

gnomad/assessment/validity_checks.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,35 +1117,49 @@ def check_global_and_row_annot_lengths(
11171117
t = t.rows() if isinstance(t, hl.MatrixTable) else t
11181118
if not check_all_rows:
11191119
t = t.head(1)
1120+
1121+
n_rows = t.count()
1122+
1123+
global_lengths = {
1124+
global_field: hl.eval(hl.len(t.index_globals()[global_field]))
1125+
for row_field, global_fields in row_to_globals_check.items()
1126+
for global_field in global_fields
1127+
}
1128+
1129+
row_length_counts = {
1130+
row_field: t.aggregate(hl.agg.counter(hl.len(t[row_field])))
1131+
for row_field in row_to_globals_check.keys()
1132+
}
1133+
11201134
for row_field, global_fields in row_to_globals_check.items():
11211135
if not check_all_rows:
11221136
logger.info(
11231137
"Checking length of %s in first row against length of globals: %s",
11241138
row_field,
11251139
global_fields,
11261140
)
1141+
1142+
row_lengths = row_length_counts[row_field]
1143+
11271144
for global_field in global_fields:
1128-
global_len = hl.eval(hl.len(t[global_field]))
1129-
row_len_expr = hl.len(t[row_field])
1130-
failed_rows = t.aggregate(
1131-
hl.struct(
1132-
n_fail=hl.agg.count_where(row_len_expr != global_len),
1133-
row_len=hl.agg.counter(row_len_expr),
1134-
)
1145+
global_len = global_lengths[global_field]
1146+
failed_rows = sum(
1147+
count for length, count in row_lengths.items() if length != global_len
11351148
)
1136-
outcome = "Failed" if failed_rows["n_fail"] > 0 else "Passed"
1137-
n_rows = t.count()
1149+
1150+
outcome = "Failed" if failed_rows > 0 else "Passed"
1151+
11381152
logger.info(
11391153
"%s global and row lengths comparison: Length of %s in"
1140-
" globals (%d) does %smatch length of %s in %d out of %d rows (%s)",
1154+
" globals (%d) does %smatch length of %s in %d out of %d rows (row length counter: %s)",
11411155
outcome,
11421156
global_field,
11431157
global_len,
11441158
"NOT " if outcome == "Failed" else "",
11451159
row_field,
1146-
failed_rows["n_fail"] if outcome == "Failed" else n_rows,
1160+
failed_rows if outcome == "Failed" else n_rows,
11471161
n_rows,
1148-
failed_rows["row_len"],
1162+
row_lengths,
11491163
)
11501164

11511165

tests/assessment/test_validity_checks.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytest
88

99
from gnomad.assessment.validity_checks import (
10+
check_global_and_row_annot_lengths,
1011
check_missingness_of_struct,
1112
check_raw_and_adj_callstats,
1213
check_sex_chr_metrics,
@@ -502,6 +503,58 @@ def test_sum_group_callstats(ht_for_group_sums, caplog) -> None:
502503
), f"Expected phrase missing: {log_phrase}"
503504

504505

506+
@pytest.fixture
def ht_for_check_global_and_row_annot_lengths() -> hl.Table:
    """Build the Hail Table used to exercise check_global_and_row_annot_lengths.

    The table has two rows. Both rows carry a three-element ``freq`` array,
    while ``faf`` has two elements in the first row and three in the second.
    The attached globals have length 3 for the ``freq_*`` fields and length 2
    for the ``faf_*`` fields, so ``faf`` disagrees with its globals in exactly
    one row.
    """
    row_schema = hl.tstruct(freq=hl.tarray(hl.tfloat64), faf=hl.tarray(hl.tfloat64))
    row_values = [
        {"freq": [0.1, 0.2, 0.3], "faf": [0.01, 0.02]},
        {"freq": [0.8, 0.4, 0.5], "faf": [0.03, 0.04, 0.05]},
    ]
    table = hl.Table.parallelize(row_values, row_schema)

    # Globals whose lengths are compared against the row arrays above.
    global_annotations = {
        "freq_meta": ["A", "B", "C"],
        "freq_index_dict": {"A": 0, "B": 1, "C": 2},
        "freq_meta_sample_count": [100, 200, 300],
        "faf_meta": ["D", "E"],
        "faf_index_dict": {"D": 0, "E": 1},
    }
    return table.annotate_globals(**global_annotations)
524+
525+
526+
def test_check_global_and_row_annot_lengths(
    ht_for_check_global_and_row_annot_lengths, caplog
) -> None:
    """Check the log lines emitted by check_global_and_row_annot_lengths.

    Runs the check over all rows of the fixture table and asserts that one
    "Passed" message is logged for each ``freq`` global and one "Failed"
    message for each ``faf`` global, including the row length counter.
    """
    table = ht_for_check_global_and_row_annot_lengths

    # Map each row array field to the globals whose length it must match.
    globals_by_row_field = {
        "freq": ["freq_meta", "freq_index_dict", "freq_meta_sample_count"],
        "faf": ["faf_meta", "faf_index_dict"],
    }

    # Capture INFO-level output from the module under test only.
    with caplog.at_level(logging.INFO, logger="gnomad.assessment.validity_checks"):
        check_global_and_row_annot_lengths(
            table, globals_by_row_field, check_all_rows=True
        )

    captured_messages = [record.message for record in caplog.records]

    expected_logs = [
        "Passed global and row lengths comparison: Length of freq_meta in globals (3) does match length of freq in 2 out of 2 rows (row length counter: {3: 2})",
        "Passed global and row lengths comparison: Length of freq_index_dict in globals (3) does match length of freq in 2 out of 2 rows (row length counter: {3: 2})",
        "Passed global and row lengths comparison: Length of freq_meta_sample_count in globals (3) does match length of freq in 2 out of 2 rows (row length counter: {3: 2})",
        "Failed global and row lengths comparison: Length of faf_meta in globals (2) does NOT match length of faf in 1 out of 2 rows (row length counter: {2: 1, 3: 1})",
        "Failed global and row lengths comparison: Length of faf_index_dict in globals (2) does NOT match length of faf in 1 out of 2 rows (row length counter: {2: 1, 3: 1})",
    ]

    # Every expected message must appear verbatim among the captured records.
    for msg in expected_logs:
        assert msg in captured_messages, f"Expected log message is missing: {msg}"
557+
505558
@pytest.fixture
506559
def ht_for_check_raw_and_adj_callstats() -> hl.Table:
507560
"""Fixture to create a Hail Table with the expected structure and test values for check_raw_and_adj_callstats, using underscore as the delimiter."""

0 commit comments

Comments
 (0)