@@ -0,0 +1,89 @@
# Finalize covdb scaling plan (535k samples)

## Purpose
This document summarizes the current behavior of `determine_hom_refs_from_covdb`, the scaling bottlenecks observed at very large sample counts, and the recommended changes that make the step tractable for ~535k samples while preserving semantics.

## Context
The finalize step updates a sparse mtDNA MatrixTable by using per‑sample coverage from `coverage.h5` to distinguish between **missing** and **hom‑ref** genotypes. The inputs are extremely sparse (observed missing fraction ~0.9993), which means any strategy that scans the entire entry space repeatedly or attempts per‑entry random lookups will not scale.
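To make the scale concrete, a back-of-envelope sketch (the ~16,569 bp mtDNA reference length is assumed here; the sample count and missing fraction come from this plan):

```python
# Rough size of the entry space at the target scale.
n_samples = 535_000
n_positions = 16_569           # assumed mtDNA reference (rCRS) length
missing_frac = 0.9993          # observed missing fraction from this plan

total_entries = n_samples * n_positions          # 8,864,415,000 potential entries
observed_entries = total_entries * (1 - missing_frac)  # only ~6.2 million observed

print(total_entries)
```

Any per-entry random lookup strategy would therefore face billions of HDF5 reads, while the actually observed data is roughly three orders of magnitude smaller.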

---

## What the current code does (summary)
The current `determine_hom_refs_from_covdb` implementation:

1. Maps each MT sample to a row index in `coverage.h5` and each MT position to a column index in `coverage.h5`.
2. Splits positions into blocks (`position_block_size`).
3. For each block:
- Reads coverage for **all samples** × **positions in the block** from HDF5.
- Builds a per‑block literal and annotates entries with `__cov` using:
- `mt = mt.annotate_entries(__cov=hl.if_else(mt.__block == block_id, cov_expr, mt.__cov))`
4. After all blocks, applies the hom‑ref logic:
- If `HL` missing and `DP` > threshold → set `HL=0`, `FT=PASS`, `DP=coverage`.
- Otherwise keep missing.

### Why this becomes slow
The per‑block `annotate_entries` updates are evaluated **across all entries** each time, even though only one block should change. With ~16–20 blocks, this effectively multiplies entry‑level work ~16–20×. This behavior is the main scaling problem at 535k samples.
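The cost difference can be illustrated with a toy model (pure Python, not Hail; the entry keys and `block_of` function are made up for illustration) that counts entry-level evaluations under each strategy:

```python
def per_block_update_cost(entry_keys, n_blocks, block_of):
    """Mimics the current pattern: each annotate_entries pass evaluates the
    if_else on EVERY entry, even though only one block's entries change."""
    evaluations = 0
    for block_id in range(n_blocks):
        for key in entry_keys:              # full entry scan on every pass
            evaluations += 1
            _ = (block_of(key) == block_id)  # condition checked, mostly a no-op
    return evaluations

def block_local_update_cost(entry_keys, n_blocks, block_of):
    """Mimics the recommended pattern: restrict to the block first (in Hail this
    is a cheap row-level filter), so entry-level work happens exactly once."""
    evaluations = 0
    for block_id in range(n_blocks):
        for _key in (k for k in entry_keys if block_of(k) == block_id):
            evaluations += 1                # only this block's entries do work
    return evaluations
```

With 1,000 entries and 20 blocks, the current pattern does 20,000 entry evaluations versus 1,000 for the block-local pattern, matching the ~16–20× multiplier described above.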

---

## Recommended plan (scalable approach)
### A) Keep the same semantics, but avoid repeated full‑MT entry scans
**Key change:** apply entry annotations only to the rows in the current block, then recombine.

**High‑level pattern:**

- Compute `__block` once per row.
- For each block:
- `mt_b = mt.filter_rows(mt.__block == block_id)`
- Compute coverage for that block and annotate **only** `mt_b` entries.
- Apply hom‑ref logic on `mt_b`.
- Checkpoint `mt_b`.
- Union blocks via a **small fan‑in tree** (not a long linear chain).

This preserves sparsity and ensures each row/entry is processed **once**, not once per block.
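The "small fan-in tree" recombination can be sketched generically (pure Python; in Hail the `combine` argument would be something like `lambda a, b: a.union_rows(b)` over the per-block MTs — an assumed wiring, not code from this PR):

```python
def fan_in(parts, combine):
    """Pairwise tree reduction: combines N parts in O(log N) rounds,
    keeping the query plan shallow instead of a long linear chain."""
    parts = list(parts)
    while len(parts) > 1:
        parts = [
            combine(parts[i], parts[i + 1]) if i + 1 < len(parts) else parts[i]
            for i in range(0, len(parts), 2)
        ]
    return parts[0]
```

For example, `fan_in([[1], [2], [3], [4], [5]], lambda a, b: a + b)` yields `[1, 2, 3, 4, 5]` in three rounds rather than four chained unions; the depth advantage grows with the block count.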

### B) Remove unnecessary shuffles
The current `mt.repartition(n_blocks, shuffle=True)` forces a global shuffle. In the block‑local pattern it is unnecessary and should be removed, or replaced with a non‑shuffle `naive_coalesce` only where output partition sizing requires it.

### C) Avoid global `__cov` entry field
Only create `__cov` inside each block MT (`mt_b`). This prevents inflating the entry schema for the full MT and reduces IR size.

### D) Combine with **sample sharding**
At 535k samples, even a perfect block‑local refactor can still be heavy. The strongest scaling improvement comes from **sharding by samples** and running finalize per shard, then `union_cols` the shards and compute cohort statistics afterwards.

Recommended structure:

1. **Split columns** into N shards (e.g., 20–24).
2. For each shard:
- Run block‑local hom‑ref imputation (A–C above).
3. **Union columns** across shards.
4. Run cohort‑wide row annotations (AC/AN/AF, histograms, hap/pop splits) once.
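The column split in step 1 reduces to simple chunking of the sample list; a sketch (the 25k shard size is illustrative and matches the default in `shard_mt_by_samples.py` below):

```python
import math

def shard_boundaries(n_samples: int, shard_size: int) -> list[tuple[int, int]]:
    """Split n_samples into contiguous [start, stop) ranges of at most shard_size."""
    return [(i, min(i + shard_size, n_samples))
            for i in range(0, n_samples, shard_size)]

# At 535k samples and 25k samples per shard, this yields 22 shards,
# within the 20-24 range recommended above.
n_shards = math.ceil(535_000 / 25_000)
```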

---

## Why this scales to 535k samples
1. **No repeated full‑matrix entry scans**: each row/entry is processed once.
2. **Sparse‑preserving**: we never densify the MT; missing entries stay missing unless coverage supports hom‑ref.
3. **Reduced driver pressure**: block literals stay small and per‑shard.
4. **Parallelism**: sample shards scale horizontally; each shard’s workload is smaller and independent.
5. **Correctness preserved**: the hom‑ref logic is unchanged; only the execution plan is optimized.

---

## Before vs. After (behavioral differences)
| Aspect | Current code | Recommended code |
|---|---|---|
| Entry‑level evaluation | Full MT per block | Only rows in block |
| Shuffling | Global shuffle (`shuffle=True`) | Removed or minimized |
| `__cov` creation | Whole MT | Block‑local only |
| Recombination | Single MT updated per block | Union of block MTs (fan‑in) |
| Scaling | Work ~ (#blocks × all entries) | Work ~ (all entries once) |

---

## Final notes
- The hom‑ref logic is **identical** to current behavior.
- The plan avoids any “lazy per‑entry lookup” that would cause billions of random HDF5 reads.
- The strategy remains compatible with downstream `add_annotations` and other cohort‑wide statistics, which should be computed **after** unioning shards.

If needed, a follow‑on doc can include concrete code changes or a WDL wiring plan for shard/fan‑in orchestration.
@@ -0,0 +1,95 @@
#!/usr/bin/env python3
"""Shard a MatrixTable by columns (samples).

This reads a single MT and writes multiple MT shards, each containing a subset
of samples. Shards are written to --out-dir as shard_XXXXX.mt directories.
"""

from __future__ import annotations

import argparse
import logging
import math
import os
from typing import List

import hail as hl


def _parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Shard a MatrixTable by columns")
p.add_argument("--in-mt", required=True, help="Input MT path")
p.add_argument("--out-dir", required=True, help="Output directory for shard MTs")
p.add_argument("--temp-dir", required=True, help="Temp directory for Hail/Spark scratch")
p.add_argument("--shard-size", type=int, default=25000, help="Samples per shard")
p.add_argument(
"--n-final-partitions",
type=int,
default=256,
help="Target partitions per shard",
)
p.add_argument("--overwrite", action="store_true")
return p.parse_args()


def _chunk_list(items: List[str], chunk_size: int) -> List[List[str]]:
return [items[i : i + chunk_size] for i in range(0, len(items), chunk_size)]


def main(args: argparse.Namespace) -> None:
logging.basicConfig(
level=logging.INFO,
format="%(levelname)s: %(asctime)s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
logger = logging.getLogger("shard_mt_by_samples")

if args.shard_size <= 0:
raise ValueError("shard-size must be > 0")

os.makedirs(args.out_dir, exist_ok=True)

try:
hl.current_backend()
except Exception:
hl.init(tmp_dir=args.temp_dir)

logger.info("Reading MT: %s", args.in_mt)
mt = hl.read_matrix_table(args.in_mt)

mt = mt.add_col_index(name="__col_idx")
ht_cols = mt.cols()
col_rows = ht_cols.select(idx=ht_cols.__col_idx).collect()
col_rows.sort(key=lambda r: int(r.idx))
samples = [r.s for r in col_rows]

if not samples:
raise ValueError("MT has 0 columns; cannot shard")

n_shards = int(math.ceil(len(samples) / args.shard_size))
logger.info("Sharding %d samples into %d shards", len(samples), n_shards)

for shard_id, shard_samples in enumerate(_chunk_list(samples, args.shard_size)):
shard_name = f"shard_{shard_id:05d}.mt"
out_mt = os.path.join(args.out_dir, shard_name)

if hl.hadoop_exists(f"{out_mt}/_SUCCESS") and not args.overwrite:
logger.info("Shard exists, skipping: %s", out_mt)
continue

sample_set = hl.literal(set(shard_samples))
mt_shard = mt.filter_cols(sample_set.contains(mt.s)).drop("__col_idx")
mt_shard = mt_shard.naive_coalesce(args.n_final_partitions)

logger.info(
"Writing shard %d/%d (%d samples) -> %s",
shard_id + 1,
n_shards,
len(shard_samples),
out_mt,
)
mt_shard.checkpoint(out_mt, overwrite=args.overwrite)


if __name__ == "__main__":
main(_parse_args())
71 changes: 71 additions & 0 deletions all_of_us/mitochondria/mtSwirl_refactor/Terra/union_mt_shards.py
@@ -0,0 +1,71 @@
#!/usr/bin/env python3
"""Union MT shards by columns with row-key alignment.

Inputs
------
* --mt-list-tsv: TSV with header 'mt_path' and one MT path per line.

Output
------
* --out-mt: merged MT directory.
"""

from __future__ import annotations

import argparse
import logging

import hail as hl


def _parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Union MT shards by columns")
p.add_argument("--mt-list-tsv", required=True, help="TSV with header mt_path")
p.add_argument("--out-mt", required=True, help="Output MT path")
p.add_argument("--temp-dir", required=True, help="Temp directory for Hail/Spark scratch")
p.add_argument(
"--n-final-partitions",
type=int,
default=1000,
help="Target partitions for the merged MT",
)
p.add_argument("--overwrite", action="store_true")
return p.parse_args()


def main(args: argparse.Namespace) -> None:
logging.basicConfig(
level=logging.INFO,
format="%(levelname)s: %(asctime)s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
logger = logging.getLogger("union_mt_shards")

hl.init(tmp_dir=args.temp_dir)

if hl.hadoop_exists(f"{args.out_mt}/_SUCCESS") and not args.overwrite:
logger.info("Output exists, reading: %s", args.out_mt)
_ = hl.read_matrix_table(args.out_mt)
return

ht = hl.import_table(args.mt_list_tsv, impute=False, types={"mt_path": hl.tstr})
paths = ht.aggregate(hl.agg.collect(ht.mt_path))
paths = [p for p in paths if p is not None and len(p) > 0]

if len(paths) == 0:
raise ValueError("mt-list-tsv contained 0 mt_path entries")

logger.info("Unioning %d MT shards", len(paths))
mts = [hl.read_matrix_table(p) for p in paths]

# Union via a small fan-in tree (pairwise) to keep the query plan shallow,
# rather than a long linear chain of union_cols calls.
while len(mts) > 1:
    mts = [
        mts[i].union_cols(mts[i + 1], row_join_type="outer")
        if i + 1 < len(mts) else mts[i]
        for i in range(0, len(mts), 2)
    ]
merged = mts[0]

logger.info("Checkpoint merged MT: %s", args.out_mt)
merged = merged.repartition(args.n_final_partitions)
merged.checkpoint(args.out_mt, overwrite=args.overwrite)


if __name__ == "__main__":
main(_parse_args())
19 changes: 12 additions & 7 deletions all_of_us/mitochondria/mtSwirl_refactor/add_annotations.py
@@ -9,8 +9,6 @@
import sys, os
sys.path.append('/home/jupyter/')

hl.init(log='annotations_logging.log')

from collections import Counter
from textwrap import dedent

@@ -57,11 +55,7 @@
logger = logging.getLogger("add annotations")
logger.setLevel(logging.INFO)

if int(hl.version().split('-')[0].split('.')[2]) >= 75: # only use this if using hail 0.2.75 or greater
logger.info("Setting hail flag to avoid array index out of bounds error...")
# Setting this flag isn't generally recommended, but is needed (since at least Hail version 0.2.75) to avoid an array index out of bounds error until changes are made in future versions of Hail
# TODO: reassess if this flag is still needed for future versions of Hail
hl._set_flags(no_whole_stage_codegen="1")
_HAIL_VERSION_MINOR = int(hl.version().split('-')[0].split('.')[2])


def add_genotype(mt_path: str, min_hom_threshold: float = 0.95) -> hl.MatrixTable:
@@ -2219,6 +2213,17 @@ def main(args):  # noqa: D103
mt_path = args.mt_path
output_dir = args.output_dir
temp_dir = args.temp_dir
hl.init(
log='annotations_logging.log',
tmp_dir=f"file://{os.path.abspath(temp_dir)}",
local_tmpdir=f"file://{os.path.abspath(temp_dir)}",
spark_conf={"spark.local.dir": os.path.abspath(temp_dir)},
)
if _HAIL_VERSION_MINOR >= 75: # only use this if using hail 0.2.75 or greater
logger.info("Setting hail flag to avoid array index out of bounds error...")
# Setting this flag isn't generally recommended, but is needed (since at least Hail version 0.2.75) to avoid an array index out of bounds error until changes are made in future versions of Hail
# TODO: reassess if this flag is still needed for future versions of Hail
hl._set_flags(no_whole_stage_codegen="1")
participant_data = args.participant_data
vep_results = args.vep_results
min_hom_threshold = args.min_hom_threshold