Commit 818d39f

mt merge final version (#1802)

1 parent fbcdc4b commit 818d39f

10 files changed: +737 −61 lines

Lines changed: 89 additions & 0 deletions
# Finalize covdb scaling plan (535k samples)

## Purpose

This document summarizes the current behavior of `determine_hom_refs_from_covdb`, the scaling bottlenecks observed at very large sample counts, and the recommended changes that make the step tractable for ~535k samples while preserving semantics.

## Context

The finalize step updates a sparse mtDNA MatrixTable by using per‑sample coverage from `coverage.h5` to distinguish between **missing** and **hom‑ref** genotypes. The inputs are extremely sparse (observed missing fraction ~0.9993), which means any strategy that scans the entire entry space repeatedly or attempts per‑entry random lookups will not scale.

---

## What the current code does (summary)

The current `determine_hom_refs_from_covdb` implementation:

1. Maps each MT sample to a row index in `coverage.h5` and each MT position to a column index in `coverage.h5`.
2. Splits positions into blocks (`position_block_size`).
3. For each block:
   - Reads coverage for **all samples** × **positions in the block** from HDF5.
   - Builds a per‑block literal and annotates entries with `__cov` using:
     - `mt = mt.annotate_entries(__cov=hl.if_else(mt.__block == block_id, cov_expr, mt.__cov))`
4. After all blocks, applies the hom‑ref logic:
   - If `HL` is missing and `DP` > threshold → set `HL=0`, `FT=PASS`, `DP=coverage`.
   - Otherwise keep the entry missing.

### Why this becomes slow

The per‑block `annotate_entries` updates are evaluated **across all entries** each time, even though only one block should change. With ~16–20 blocks, this effectively multiplies entry‑level work ~16–20×. This behavior is the main scaling problem at 535k samples.
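The multiplication described above can be made concrete with a back‑of‑the‑envelope cost model. This is plain Python for illustration, not Hail; the block count (20) and sample count come from this document, and the mtDNA length of 16,569 positions is assumed as the upper bound on columns:

```python
import math

# Illustrative cost model; numbers are assumptions, not measurements.
n_samples = 535_000
n_positions = 16_569   # mtDNA reference length (assumed upper bound)
n_blocks = 20          # "~16–20 blocks" in the text

total_entries = n_samples * n_positions

# Current pattern: every annotate_entries pass re-evaluates ALL entries,
# once per block.
entries_evaluated_current = n_blocks * total_entries

# Block-local pattern: each entry is evaluated exactly once.
entries_evaluated_blocklocal = total_entries

# The slowdown factor is simply the number of blocks.
print(entries_evaluated_current // entries_evaluated_blocklocal)  # 20
```

The point is not the absolute numbers but the shape of the cost: the per‑block update pattern scales with `n_blocks × total_entries`, which the block‑local rewrite below reduces to `total_entries`.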
---
## Recommended plan (scalable approach)

### A) Keep the same semantics, but avoid repeated full‑MT entry scans

**Key change:** apply entry annotations only to the rows in the current block, then recombine.

**High‑level pattern:**

- Compute `__block` once per row.
- For each block:
  - `mt_b = mt.filter_rows(mt.__block == block_id)`
  - Compute coverage for that block and annotate **only** `mt_b` entries.
  - Apply the hom‑ref logic on `mt_b`.
  - Checkpoint `mt_b`.
- Union blocks via a **small fan‑in tree** (not a long linear chain).

This preserves sparsity and ensures each row/entry is processed **once**, not once per block.
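The fan‑in recombination can be sketched generically. In the sketch below, plain Python lists stand in for the checkpointed block MTs, and `union` stands in for whatever pairwise combiner is used (with Hail this would wrap `union_rows`); the helper name `fanin_union` is hypothetical:

```python
from functools import reduce


def fanin_union(parts, union, fan_in=4):
    """Combine `parts` with a small fan-in tree instead of a linear chain.

    `union` is any associative pairwise combiner. A linear chain of N-1
    unions produces a plan of depth N-1; a fan-in tree keeps the depth
    to roughly log_fan_in(N), which keeps the query IR shallow.
    """
    while len(parts) > 1:
        # Combine groups of `fan_in` neighbors per round.
        parts = [
            reduce(union, parts[i : i + fan_in])
            for i in range(0, len(parts), fan_in)
        ]
    return parts[0]


# Demo with lists standing in for block MTs:
blocks = [[i] for i in range(10)]
merged = fanin_union(blocks, lambda a, b: a + b)
print(merged)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

Because each block was checkpointed before the union, each round of the tree reads materialized data rather than re‑deriving the whole upstream plan.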
### B) Remove unnecessary shuffles

The current `mt.repartition(n_blocks, shuffle=True)` forces a global shuffle. In the block‑local pattern it is not needed and should be removed, or replaced with a non‑shuffle repartition only when required for output sizing.

### C) Avoid a global `__cov` entry field

Only create `__cov` inside each block MT (`mt_b`). This avoids inflating the entry schema of the full MT and reduces IR size.

### D) Combine with **sample sharding**

At 535k samples, even a perfect block‑local refactor can still be heavy. The strongest scaling improvement comes from **sharding by samples**: run finalize per shard, then `union_cols` the shards and compute cohort statistics afterwards.

Recommended structure:

1. **Split columns** into N shards (e.g., 20–24).
2. For each shard, run block‑local hom‑ref imputation (A–C above).
3. **Union columns** across shards.
4. Run cohort‑wide row annotations (AC/AN/AF, histograms, hap/pop splits) once.
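As a quick check on step 1, the shard count implied by ~535k samples at 25,000 samples per shard (the sharding script's default `--shard-size`) lands inside the 20–24 range:

```python
import math

n_samples = 535_000
shard_size = 25_000  # default --shard-size in the sharding script

# Last shard may be partial, hence the ceiling.
n_shards = math.ceil(n_samples / shard_size)
print(n_shards)  # 22
```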
---
## Why this scales to 535k samples

1. **No repeated full‑matrix entry scans**: each row/entry is processed once.
2. **Sparse‑preserving**: we never densify the MT; missing entries stay missing unless coverage supports hom‑ref.
3. **Reduced driver pressure**: block literals stay small and per‑shard.
4. **Parallelism**: sample shards scale horizontally; each shard's workload is smaller and independent.
5. **Correctness preserved**: the hom‑ref logic is unchanged; only the execution plan is optimized.
---
## Before vs. After (behavioral differences)

| Aspect | Current code | Recommended code |
|---|---|---|
| Entry‑level evaluation | Full MT per block | Only rows in block |
| Shuffling | Global shuffle (`shuffle=True`) | Removed or minimized |
| `__cov` creation | Whole MT | Block‑local only |
| Recombination | Single MT updated per block | Union of block MTs (fan‑in) |
| Scaling | Work ~ (#blocks × all entries) | Work ~ (all entries once) |
---
## Final notes

- The hom‑ref logic is **identical** to current behavior.
- The plan avoids any "lazy per‑entry lookup" that would cause billions of random HDF5 reads.
- The strategy remains compatible with downstream `add_annotations` and other cohort‑wide statistics, which should be computed **after** unioning shards.

If needed, a follow‑on doc can include concrete code changes or a WDL wiring plan for shard/fan‑in orchestration.
Lines changed: 95 additions & 0 deletions

```python
#!/usr/bin/env python3
"""Shard a MatrixTable by columns (samples).

This reads a single MT and writes multiple MT shards, each containing a subset
of samples. Shards are written to --out-dir as shard_XXXXX.mt directories.
"""

from __future__ import annotations

import argparse
import logging
import math
import os
from typing import List

import hail as hl


def _parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(description="Shard a MatrixTable by columns")
    p.add_argument("--in-mt", required=True, help="Input MT path")
    p.add_argument("--out-dir", required=True, help="Output directory for shard MTs")
    p.add_argument("--temp-dir", required=True, help="Temp directory for Hail/Spark scratch")
    p.add_argument("--shard-size", type=int, default=25000, help="Samples per shard")
    p.add_argument(
        "--n-final-partitions",
        type=int,
        default=256,
        help="Target partitions per shard",
    )
    p.add_argument("--overwrite", action="store_true")
    return p.parse_args()


def _chunk_list(items: List[str], chunk_size: int) -> List[List[str]]:
    return [items[i : i + chunk_size] for i in range(0, len(items), chunk_size)]


def main(args: argparse.Namespace) -> None:
    logging.basicConfig(
        level=logging.INFO,
        format="%(levelname)s: %(asctime)s: %(message)s",
        datefmt="%m/%d/%Y %I:%M:%S %p",
    )
    logger = logging.getLogger("shard_mt_by_samples")

    if args.shard_size <= 0:
        raise ValueError("shard-size must be > 0")

    os.makedirs(args.out_dir, exist_ok=True)

    try:
        hl.current_backend()  # reuse an already-initialized backend if present
    except Exception:
        hl.init(tmp_dir=args.temp_dir)

    logger.info("Reading MT: %s", args.in_mt)
    mt = hl.read_matrix_table(args.in_mt)

    # Collect sample IDs in column order so shard membership is deterministic.
    mt = mt.add_col_index(name="__col_idx")
    ht_cols = mt.cols()
    col_rows = ht_cols.select(idx=ht_cols.__col_idx).collect()
    col_rows.sort(key=lambda r: int(r.idx))
    samples = [r.s for r in col_rows]

    if not samples:
        raise ValueError("MT has 0 columns; cannot shard")

    n_shards = int(math.ceil(len(samples) / args.shard_size))
    logger.info("Sharding %d samples into %d shards", len(samples), n_shards)

    for shard_id, shard_samples in enumerate(_chunk_list(samples, args.shard_size)):
        shard_name = f"shard_{shard_id:05d}.mt"
        out_mt = os.path.join(args.out_dir, shard_name)

        if hl.hadoop_exists(f"{out_mt}/_SUCCESS") and not args.overwrite:
            logger.info("Shard exists, skipping: %s", out_mt)
            continue

        sample_set = hl.literal(set(shard_samples))
        mt_shard = mt.filter_cols(sample_set.contains(mt.s)).drop("__col_idx")
        # naive_coalesce avoids a shuffle, per the plan's recommendation (B).
        mt_shard = mt_shard.naive_coalesce(args.n_final_partitions)

        logger.info(
            "Writing shard %d/%d (%d samples) -> %s",
            shard_id + 1,
            n_shards,
            len(shard_samples),
            out_mt,
        )
        mt_shard.checkpoint(out_mt, overwrite=args.overwrite)


if __name__ == "__main__":
    main(_parse_args())
```
Lines changed: 71 additions & 0 deletions

```python
#!/usr/bin/env python3
"""Union MT shards by columns with row-key alignment.

Inputs
------
* --mt-list-tsv: TSV with header 'mt_path' and one MT path per line.

Output
------
* --out-mt: merged MT directory.
"""

from __future__ import annotations

import argparse
import logging

import hail as hl


def _parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(description="Union MT shards by columns")
    p.add_argument("--mt-list-tsv", required=True, help="TSV with header mt_path")
    p.add_argument("--out-mt", required=True, help="Output MT path")
    p.add_argument("--temp-dir", required=True, help="Temp directory for Hail/Spark scratch")
    p.add_argument(
        "--n-final-partitions",
        type=int,
        default=1000,
        help="Target partitions for the merged MT",
    )
    p.add_argument("--overwrite", action="store_true")
    return p.parse_args()


def main(args: argparse.Namespace) -> None:
    logging.basicConfig(
        level=logging.INFO,
        format="%(levelname)s: %(asctime)s: %(message)s",
        datefmt="%m/%d/%Y %I:%M:%S %p",
    )
    logger = logging.getLogger("union_mt_shards")

    hl.init(tmp_dir=args.temp_dir)

    if hl.hadoop_exists(f"{args.out_mt}/_SUCCESS") and not args.overwrite:
        logger.info("Output exists, reading: %s", args.out_mt)
        _ = hl.read_matrix_table(args.out_mt)
        return

    ht = hl.import_table(args.mt_list_tsv, impute=False, types={"mt_path": hl.tstr})
    paths = ht.aggregate(hl.agg.collect(ht.mt_path))
    paths = [p for p in paths if p is not None and len(p) > 0]

    if len(paths) == 0:
        raise ValueError("mt-list-tsv contained 0 mt_path entries")

    logger.info("Unioning %d MT shards", len(paths))
    mts = [hl.read_matrix_table(p) for p in paths]

    # Outer row join keeps rows present in any shard; entries for shards
    # lacking a row stay missing, preserving sparsity.
    merged = mts[0]
    for mt in mts[1:]:
        merged = merged.union_cols(mt, row_join_type="outer")

    logger.info("Checkpoint merged MT: %s", args.out_mt)
    merged = merged.repartition(args.n_final_partitions)
    merged.checkpoint(args.out_mt, overwrite=args.overwrite)


if __name__ == "__main__":
    main(_parse_args())
```

all_of_us/mitochondria/mtSwirl_refactor/add_annotations.py

Lines changed: 12 additions & 7 deletions

```diff
@@ -9,8 +9,6 @@
 import sys, os
 sys.path.append('/home/jupyter/')
 
-hl.init(log='annotations_logging.log')
-
 from collections import Counter
 from textwrap import dedent
@@ -57,11 +55,7 @@
 logger = logging.getLogger("add annotations")
 logger.setLevel(logging.INFO)
 
-if int(hl.version().split('-')[0].split('.')[2]) >= 75:  # only use this if using hail 0.2.75 or greater
-    logger.info("Setting hail flag to avoid array index out of bounds error...")
-    # Setting this flag isn't generally recommended, but is needed (since at least Hail version 0.2.75) to avoid an array index out of bounds error until changes are made in future versions of Hail
-    # TODO: reassess if this flag is still needed for future versions of Hail
-    hl._set_flags(no_whole_stage_codegen="1")
+_HAIL_VERSION_MINOR = int(hl.version().split('-')[0].split('.')[2])
@@ -2219,6 +2213,17 @@ def main(args):  # noqa: D103
     mt_path = args.mt_path
     output_dir = args.output_dir
     temp_dir = args.temp_dir
+    hl.init(
+        log='annotations_logging.log',
+        tmp_dir=f"file://{os.path.abspath(temp_dir)}",
+        local_tmpdir=f"file://{os.path.abspath(temp_dir)}",
+        spark_conf={"spark.local.dir": os.path.abspath(temp_dir)},
+    )
+    if _HAIL_VERSION_MINOR >= 75:  # only use this if using hail 0.2.75 or greater
+        logger.info("Setting hail flag to avoid array index out of bounds error...")
+        # Setting this flag isn't generally recommended, but is needed (since at least Hail version 0.2.75) to avoid an array index out of bounds error until changes are made in future versions of Hail
+        # TODO: reassess if this flag is still needed for future versions of Hail
+        hl._set_flags(no_whole_stage_codegen="1")
     participant_data = args.participant_data
     vep_results = args.vep_results
     min_hom_threshold = args.min_hom_threshold
```