diff --git a/all_of_us/mitochondria/mtSwirl_refactor/FINALIZE_COVDB_SCALING_PLAN.md b/all_of_us/mitochondria/mtSwirl_refactor/FINALIZE_COVDB_SCALING_PLAN.md new file mode 100644 index 0000000000..daeeedeac8 --- /dev/null +++ b/all_of_us/mitochondria/mtSwirl_refactor/FINALIZE_COVDB_SCALING_PLAN.md @@ -0,0 +1,89 @@ +# Finalize covdb scaling plan (535k samples) + +## Purpose +This document summarizes the current behavior of `determine_hom_refs_from_covdb`, the scaling bottlenecks observed at very large sample counts, and the recommended changes that make the step tractable for ~535k samples while preserving semantics. + +## Context +The finalize step updates a sparse mtDNA MatrixTable by using per‑sample coverage from `coverage.h5` to distinguish between **missing** and **hom‑ref** genotypes. The inputs are extremely sparse (observed missing fraction ~0.9993), which means any strategy that scans the entire entry space repeatedly or attempts per‑entry random lookups will not scale. + +--- + +## What the current code does (summary) +The current `determine_hom_refs_from_covdb` implementation: + +1. Maps each MT sample to a row index in `coverage.h5` and each MT position to a column index in `coverage.h5`. +2. Splits positions into blocks (`position_block_size`). +3. For each block: + - Reads coverage for **all samples** × **positions in the block** from HDF5. + - Builds a per‑block literal and annotates entries with `__cov` using: + - `mt = mt.annotate_entries(__cov=hl.if_else(mt.__block == block_id, cov_expr, mt.__cov))` +4. After all blocks, applies the hom‑ref logic: + - If `HL` missing and `DP` > threshold → set `HL=0`, `FT=PASS`, `DP=coverage`. + - Otherwise keep missing. + +### Why this becomes slow +The per‑block `annotate_entries` updates are evaluated **across all entries** each time, even though only one block should change. With ~16–20 blocks, this effectively multiplies entry‑level work ~16–20×. This behavior is the main scaling problem at 535k samples. 
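The recommended plan below recombines per-block results with a small fan-in tree rather than one long linear `union_rows` chain, which is what keeps the Hail IR shallow. Stripped of Hail specifics, the pattern is just a staged reduction; the `union_tree` helper here is a hypothetical illustration, not code from this repository:

```python
from functools import reduce
from typing import Callable, List, TypeVar

T = TypeVar("T")


def union_tree(items: List[T], combine: Callable[[T, T], T], fan_in: int = 8) -> T:
    """Combine items in stages of at most `fan_in`, so the resulting
    expression tree has depth ~log_fan_in(n) instead of n - 1."""
    if not items:
        raise ValueError("union_tree requires at least one item")
    stage: List[T] = list(items)
    while len(stage) > 1:
        # Each pass reduces every chunk of `fan_in` items to a single item,
        # preserving the original left-to-right order.
        stage = [
            reduce(combine, stage[i : i + fan_in])
            for i in range(0, len(stage), fan_in)
        ]
    return stage[0]
```

With ~16–20 position blocks and `fan_in=8`, this finishes in two stages instead of a ~19-deep linear chain; the same shape applied to `MatrixTable.union_rows` is what bounds the DAG size in the block-local refactor.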
+ +--- + +## Recommended plan (scalable approach) +### A) Keep the same semantics, but avoid repeated full‑MT entry scans +**Key change:** apply entry annotations only to the rows in the current block, then recombine. + +**High‑level pattern:** + +- Compute `__block` once per row. +- For each block: + - `mt_b = mt.filter_rows(mt.__block == block_id)` + - Compute coverage for that block and annotate **only** `mt_b` entries. + - Apply hom‑ref logic on `mt_b`. + - Checkpoint `mt_b`. +- Union blocks via a **small fan‑in tree** (not a long linear chain). + +This preserves sparsity and ensures each row/entry is processed **once**, not once per block. + +### B) Remove unnecessary shuffles +The current `mt.repartition(n_blocks, shuffle=True)` forces a global shuffle. In the block‑local pattern it is not needed and should be removed or replaced with a non‑shuffle repartition only when required for output sizing. + +### C) Avoid global `__cov` entry field +Only create `__cov` inside each block MT (`mt_b`). This prevents inflating the entry schema for the full MT and reduces IR size. + +### D) Combine with **sample sharding** +At 535k samples, even a perfect block‑local refactor can still be heavy. The strongest scaling improvement comes from **sharding by samples** and running finalize per shard, then `union_cols` the shards and compute cohort statistics afterwards. + +Recommended structure: + +1. **Split columns** into N shards (e.g., 20–24). +2. For each shard: + - Run block‑local hom‑ref imputation (A–C above). +3. **Union columns** across shards. +4. Run cohort‑wide row annotations (AC/AN/AF, histograms, hap/pop splits) once. + +--- + +## Why this scales to 535k samples +1. **No repeated full‑matrix entry scans**: each row/entry is processed once. +2. **Sparse‑preserving**: we never densify the MT; missing entries stay missing unless coverage supports hom‑ref. +3. **Reduced driver pressure**: block literals stay small and per‑shard. +4. 
**Parallelism**: sample shards scale horizontally; each shard’s workload is smaller and independent. +5. **Correctness preserved**: the hom‑ref logic is unchanged; only the execution plan is optimized. + +--- + +## Before vs. After (behavioral differences) +| Aspect | Current code | Recommended code | +|---|---|---| +| Entry‑level evaluation | Full MT per block | Only rows in block | +| Shuffling | Global shuffle (`shuffle=True`) | Removed or minimized | +| `__cov` creation | Whole MT | Block‑local only | +| Recombination | Single MT updated per block | Union of block MTs (fan‑in) | +| Scaling | Work ~ (#blocks × all entries) | Work ~ (all entries once) | + +--- + +## Final notes +- The hom‑ref logic is **identical** to current behavior. +- The plan avoids any “lazy per‑entry lookup” that would cause billions of random HDF5 reads. +- The strategy remains compatible with downstream `add_annotations` and other cohort‑wide statistics, which should be computed **after** unioning shards. + +If needed, a follow‑on doc can include concrete code changes or a WDL wiring plan for shard/fan‑in orchestration. diff --git a/all_of_us/mitochondria/mtSwirl_refactor/Terra/shard_mt_by_samples.py b/all_of_us/mitochondria/mtSwirl_refactor/Terra/shard_mt_by_samples.py new file mode 100644 index 0000000000..fe05f2dae4 --- /dev/null +++ b/all_of_us/mitochondria/mtSwirl_refactor/Terra/shard_mt_by_samples.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 +"""Shard a MatrixTable by columns (samples). + +This reads a single MT and writes multiple MT shards, each containing a subset +of samples. Shards are written to --out-dir as shard_XXXXX.mt directories. 
+""" + +from __future__ import annotations + +import argparse +import logging +import math +import os +from typing import List + +import hail as hl + + +def _parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description="Shard a MatrixTable by columns") + p.add_argument("--in-mt", required=True, help="Input MT path") + p.add_argument("--out-dir", required=True, help="Output directory for shard MTs") + p.add_argument("--temp-dir", required=True, help="Temp directory for Hail/Spark scratch") + p.add_argument("--shard-size", type=int, default=25000, help="Samples per shard") + p.add_argument( + "--n-final-partitions", + type=int, + default=256, + help="Target partitions per shard", + ) + p.add_argument("--overwrite", action="store_true") + return p.parse_args() + + +def _chunk_list(items: List[str], chunk_size: int) -> List[List[str]]: + return [items[i : i + chunk_size] for i in range(0, len(items), chunk_size)] + + +def main(args: argparse.Namespace) -> None: + logging.basicConfig( + level=logging.INFO, + format="%(levelname)s: %(asctime)s: %(message)s", + datefmt="%m/%d/%Y %I:%M:%S %p", + ) + logger = logging.getLogger("shard_mt_by_samples") + + if args.shard_size <= 0: + raise ValueError("shard-size must be > 0") + + os.makedirs(args.out_dir, exist_ok=True) + + try: + hl.current_backend() + except Exception: + hl.init(tmp_dir=args.temp_dir) + + logger.info("Reading MT: %s", args.in_mt) + mt = hl.read_matrix_table(args.in_mt) + + mt = mt.add_col_index(name="__col_idx") + ht_cols = mt.cols() + col_rows = ht_cols.select(idx=ht_cols.__col_idx).collect() + col_rows.sort(key=lambda r: int(r.idx)) + samples = [r.s for r in col_rows] + + if not samples: + raise ValueError("MT has 0 columns; cannot shard") + + n_shards = int(math.ceil(len(samples) / args.shard_size)) + logger.info("Sharding %d samples into %d shards", len(samples), n_shards) + + for shard_id, shard_samples in enumerate(_chunk_list(samples, args.shard_size)): + shard_name = 
f"shard_{shard_id:05d}.mt" + out_mt = os.path.join(args.out_dir, shard_name) + + if hl.hadoop_exists(f"{out_mt}/_SUCCESS") and not args.overwrite: + logger.info("Shard exists, skipping: %s", out_mt) + continue + + sample_set = hl.literal(set(shard_samples)) + mt_shard = mt.filter_cols(sample_set.contains(mt.s)).drop("__col_idx") + mt_shard = mt_shard.naive_coalesce(args.n_final_partitions) + + logger.info( + "Writing shard %d/%d (%d samples) -> %s", + shard_id + 1, + n_shards, + len(shard_samples), + out_mt, + ) + mt_shard.checkpoint(out_mt, overwrite=args.overwrite) + + +if __name__ == "__main__": + main(_parse_args()) diff --git a/all_of_us/mitochondria/mtSwirl_refactor/Terra/union_mt_shards.py b/all_of_us/mitochondria/mtSwirl_refactor/Terra/union_mt_shards.py new file mode 100644 index 0000000000..8a508e95ab --- /dev/null +++ b/all_of_us/mitochondria/mtSwirl_refactor/Terra/union_mt_shards.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +"""Union MT shards by columns with row-key alignment. + +Inputs +------ +* --mt-list-tsv: TSV with header 'mt_path' and one MT path per line. + +Output +------ +* --out-mt: merged MT directory. 
+""" + +from __future__ import annotations + +import argparse +import logging + +import hail as hl + + +def _parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description="Union MT shards by columns") + p.add_argument("--mt-list-tsv", required=True, help="TSV with header mt_path") + p.add_argument("--out-mt", required=True, help="Output MT path") + p.add_argument("--temp-dir", required=True, help="Temp directory for Hail/Spark scratch") + p.add_argument( + "--n-final-partitions", + type=int, + default=1000, + help="Target partitions for the merged MT", + ) + p.add_argument("--overwrite", action="store_true") + return p.parse_args() + + +def main(args: argparse.Namespace) -> None: + logging.basicConfig( + level=logging.INFO, + format="%(levelname)s: %(asctime)s: %(message)s", + datefmt="%m/%d/%Y %I:%M:%S %p", + ) + logger = logging.getLogger("union_mt_shards") + + hl.init(tmp_dir=args.temp_dir) + + if hl.hadoop_exists(f"{args.out_mt}/_SUCCESS") and not args.overwrite: + logger.info("Output exists, reading: %s", args.out_mt) + _ = hl.read_matrix_table(args.out_mt) + return + + ht = hl.import_table(args.mt_list_tsv, impute=False, types={"mt_path": hl.tstr}) + paths = ht.aggregate(hl.agg.collect(ht.mt_path)) + paths = [p for p in paths if p is not None and len(p) > 0] + + if len(paths) == 0: + raise ValueError("mt-list-tsv contained 0 mt_path entries") + + logger.info("Unioning %d MT shards", len(paths)) + mts = [hl.read_matrix_table(p) for p in paths] + + merged = mts[0] + for mt in mts[1:]: + merged = merged.union_cols(mt, row_join_type="outer") + + logger.info("Checkpoint merged MT: %s", args.out_mt) + merged = merged.repartition(args.n_final_partitions) + merged.checkpoint(args.out_mt, overwrite=args.overwrite) + + +if __name__ == "__main__": + main(_parse_args()) diff --git a/all_of_us/mitochondria/mtSwirl_refactor/add_annotations.py b/all_of_us/mitochondria/mtSwirl_refactor/add_annotations.py index 92e60dbfe2..74f48299ed 100644 --- 
a/all_of_us/mitochondria/mtSwirl_refactor/add_annotations.py +++ b/all_of_us/mitochondria/mtSwirl_refactor/add_annotations.py @@ -9,8 +9,6 @@ import sys, os sys.path.append('/home/jupyter/') -hl.init(log='annotations_logging.log') - from collections import Counter from textwrap import dedent @@ -57,11 +55,7 @@ logger = logging.getLogger("add annotations") logger.setLevel(logging.INFO) -if int(hl.version().split('-')[0].split('.')[2]) >= 75: # only use this if using hail 0.2.75 or greater - logger.info("Setting hail flag to avoid array index out of bounds error...") - # Setting this flag isn't generally recommended, but is needed (since at least Hail version 0.2.75) to avoid an array index out of bounds error until changes are made in future versions of Hail - # TODO: reassess if this flag is still needed for future versions of Hail - hl._set_flags(no_whole_stage_codegen="1") +_HAIL_VERSION_MINOR = int(hl.version().split('-')[0].split('.')[2]) def add_genotype(mt_path: str, min_hom_threshold: float = 0.95) -> hl.MatrixTable: @@ -2219,6 +2213,17 @@ def main(args): # noqa: D103 mt_path = args.mt_path output_dir = args.output_dir temp_dir = args.temp_dir + hl.init( + log='annotations_logging.log', + tmp_dir=f"file://{os.path.abspath(temp_dir)}", + local_tmpdir=f"file://{os.path.abspath(temp_dir)}", + spark_conf={"spark.local.dir": os.path.abspath(temp_dir)}, + ) + if _HAIL_VERSION_MINOR >= 75: # only use this if using hail 0.2.75 or greater + logger.info("Setting hail flag to avoid array index out of bounds error...") + # Setting this flag isn't generally recommended, but is needed (since at least Hail version 0.2.75) to avoid an array index out of bounds error until changes are made in future versions of Hail + # TODO: reassess if this flag is still needed for future versions of Hail + hl._set_flags(no_whole_stage_codegen="1") participant_data = args.participant_data vep_results = args.vep_results min_hom_threshold = args.min_hom_threshold diff --git 
a/all_of_us/mitochondria/mtSwirl_refactor/merging_utils.py b/all_of_us/mitochondria/mtSwirl_refactor/merging_utils.py index 8ebab19785..f6255dde38 100644 --- a/all_of_us/mitochondria/mtSwirl_refactor/merging_utils.py +++ b/all_of_us/mitochondria/mtSwirl_refactor/merging_utils.py @@ -115,6 +115,29 @@ def chunks(items, binsize): yield lst +def union_rows_tree(mts: list[hl.MatrixTable], chunk_size: int = 8) -> hl.MatrixTable: + """Union MatrixTables by rows using a fan-in tree. + + This avoids a long linear union chain and keeps DAG size manageable. + + :param mts: List of MatrixTables to union. + :param chunk_size: Fan-in size per union stage. + :return: MatrixTable with all rows unioned. + """ + if not mts: + raise ValueError("union_rows_tree requires at least one MatrixTable") + staging = list(mts) + while len(staging) > 1: + next_stage = [] + for group in chunks(staging, chunk_size): + merged = group[0] + for mt in group[1:]: + merged = merged.union_rows(mt) + next_stage.append(merged) + staging = next_stage + return staging[0] + + def coverage_merging(paths, num_merges, chunk_size, check_from_disk, temp_dir, n_read_partitions, n_final_partitions, keep_targets, logger, no_batch_mode=False): @@ -622,8 +645,8 @@ def determine_hom_refs_from_covdb( * Map each MT sample to its covdb row index. * Chunk MT positions into position blocks. * For each block, read coverage from HDF5 once and broadcast only that block. - * Update entries for rows in that block, leaving other rows untouched. - * Periodically checkpoint to avoid constructing an enormous IR from all blocks. + * Filter rows to the block, update entries only for those rows, and checkpoint. + * Union all block MTs using a fan-in tree to keep the DAG size manageable. 
""" from generate_mtdna_call_mt.covdb_utils import open_covdb_index, read_covdb_block import numpy as np @@ -688,9 +711,7 @@ def determine_hom_refs_from_covdb( pos_to_block_hl = hl.literal({int(np.int32(k)): int(np.int32(v)) for k, v in pos_to_block.items()}) mt = mt.annotate_rows(__pos=hl.int32(mt.locus.position)) mt = mt.annotate_rows(__block=hl.int32(pos_to_block_hl.get(mt.__pos))) - mt = mt.repartition(n_blocks, shuffle=True) - - mt = mt.annotate_entries(__cov=hl.missing(hl.tint32)) + block_mts: list[hl.MatrixTable] = [] for block_id, block in enumerate(block_positions): block_id_i32 = int(np.int32(block_id)) @@ -714,47 +735,48 @@ def determine_hom_refs_from_covdb( pos_to_offset_hl = hl.literal( {int(np.int32(p)): int(np.int32(i)) for i, p in enumerate(block)} ) - offset_expr = pos_to_offset_hl.get(mt.__pos) + + mt_b = mt.filter_rows(mt.__block == block_id_i32) + offset_expr = pos_to_offset_hl.get(mt_b.__pos) cov_expr = hl.if_else( hl.is_defined(offset_expr), - cov_block_hl[hl.int32(offset_expr)][hl.int32(mt.__col_idx)], + cov_block_hl[hl.int32(offset_expr)][hl.int32(mt_b.__col_idx)], hl.missing(hl.tint32), ) - mt = mt.annotate_entries( - __cov=hl.if_else(mt.__block == block_id_i32, cov_expr, mt.__cov) + mt_b = mt_b.annotate_entries(__cov=cov_expr) + mt_b = mt_b.annotate_entries( + DP=hl.if_else(hl.is_missing(mt_b.HL), mt_b.__cov, mt_b.DP) ) - if ( - checkpoint_interval_blocks > 0 - and (block_id + 1) % checkpoint_interval_blocks == 0 - and (block_id + 1) < n_blocks - ): + hom_ref_expr = hl.is_missing(mt_b.HL) & (mt_b.DP > minimum_homref_coverage) + mt_b = mt_b.annotate_entries( + HL=hl.if_else(hom_ref_expr, 0.0, mt_b.HL), + FT=hl.if_else(hom_ref_expr, {"PASS"}, mt_b.FT), + DP=hl.if_else( + hl.is_missing(mt_b.HL) & (mt_b.DP <= minimum_homref_coverage), + hl.null(hl.tint32), + mt_b.DP, + ), + ) + + mt_b = mt_b.drop("__cov") + + if checkpoint_interval_blocks > 0: tmp_path = hl.utils.new_temp_file("covdb_block", extension="mt") log.info( - "Checkpointing covdb 
blocks at %d/%d -> %s", + "Checkpointing covdb block %d/%d -> %s", block_id + 1, n_blocks, tmp_path, ) - mt = mt.checkpoint(tmp_path, overwrite=True) + mt_b = mt_b.checkpoint(tmp_path, overwrite=True) - mt = mt.annotate_entries(DP=hl.if_else(hl.is_missing(mt.HL), mt.__cov, mt.DP)) + block_mts.append(mt_b) - hom_ref_expr = hl.is_missing(mt.HL) & (mt.DP > minimum_homref_coverage) - mt = mt.annotate_entries( - HL=hl.if_else(hom_ref_expr, 0.0, mt.HL), - FT=hl.if_else(hom_ref_expr, {"PASS"}, mt.FT), - DP=hl.if_else( - hl.is_missing(mt.HL) & (mt.DP <= minimum_homref_coverage), - hl.null(hl.tint32), - mt.DP, - ), - ) - - mt = mt.drop("__cov") - mt = mt.drop("__covdb_sample_index", "__col_idx", "__pos", "__block") - return mt + mt_out = union_rows_tree(block_mts) + mt_out = mt_out.drop("__covdb_sample_index", "__col_idx", "__pos", "__block") + return mt_out def apply_mito_artifact_filter(mt: hl.MatrixTable, artifact_prone_sites_path: str, artifact_prone_sites_reference: str) -> hl.MatrixTable: diff --git a/all_of_us/mitochondria/mtSwirl_refactor/smoke_test_finalize_mt_with_covdb.py b/all_of_us/mitochondria/mtSwirl_refactor/smoke_test_finalize_mt_with_covdb.py index 22ecb4e8bd..47e1755ace 100644 --- a/all_of_us/mitochondria/mtSwirl_refactor/smoke_test_finalize_mt_with_covdb.py +++ b/all_of_us/mitochondria/mtSwirl_refactor/smoke_test_finalize_mt_with_covdb.py @@ -75,7 +75,7 @@ def main() -> None: mt, coverage_h5_path=str(covdb_path), minimum_homref_coverage=100, - position_block_size=2, + position_block_size=1, ) mt = apply_mito_artifact_filter( diff --git a/all_of_us/mitochondria/mtSwirl_refactor/smoke_test_finalize_mt_with_covdb_no_artifact.py b/all_of_us/mitochondria/mtSwirl_refactor/smoke_test_finalize_mt_with_covdb_no_artifact.py index 609637ac79..c4d043f7ca 100644 --- a/all_of_us/mitochondria/mtSwirl_refactor/smoke_test_finalize_mt_with_covdb_no_artifact.py +++ b/all_of_us/mitochondria/mtSwirl_refactor/smoke_test_finalize_mt_with_covdb_no_artifact.py @@ -64,7 +64,7 @@ 
def main() -> None: mt, coverage_h5_path=str(covdb_path), minimum_homref_coverage=100, - position_block_size=2, + position_block_size=1, ) mt = mt.checkpoint(str(tmp_path / "final.mt"), overwrite=True, stage_locally=True) diff --git a/all_of_us/mitochondria/mtSwirl_refactor/smoke_test_shard_mt_by_samples.py b/all_of_us/mitochondria/mtSwirl_refactor/smoke_test_shard_mt_by_samples.py new file mode 100644 index 0000000000..b0caf0f84d --- /dev/null +++ b/all_of_us/mitochondria/mtSwirl_refactor/smoke_test_shard_mt_by_samples.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +"""Smoke test for shard_mt_by_samples. + +Creates a small MT, shards it by samples, and validates shard sizes. +""" + +from __future__ import annotations + +import tempfile +from pathlib import Path + +import hail as hl + +import argparse + +from generate_mtdna_call_mt.Terra.shard_mt_by_samples import main as shard_main + + +def _build_mt() -> hl.MatrixTable: + mt = hl.utils.range_matrix_table(n_rows=2, n_cols=5) + mt = mt.key_rows_by( + locus=hl.locus("MT", mt.row_idx + 1, reference_genome="GRCh37"), + alleles=hl.array(["A", "C"]), + ) + mt = mt.key_cols_by(s=hl.str(mt.col_idx)) + mt = mt.annotate_entries(HL=hl.if_else(mt.col_idx == 0, 0.5, hl.missing(hl.tfloat64))) + return mt + + +def main() -> None: + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + hl.init(tmp_dir=str(tmp_path / "hail_tmp"), quiet=True) + + mt = _build_mt() + in_mt = tmp_path / "input.mt" + mt.checkpoint(str(in_mt), overwrite=True) + + out_dir = tmp_path / "shards" + args = argparse.Namespace( + in_mt=str(in_mt), + out_dir=str(out_dir), + temp_dir=str(tmp_path / "spark_tmp"), + shard_size=2, + n_final_partitions=2, + overwrite=True, + ) + + shard_main(args) + + shard_paths = sorted(out_dir.glob("shard_*.mt")) + assert len(shard_paths) == 3 + + counts = [hl.read_matrix_table(str(p)).count_cols() for p in shard_paths] + assert counts == [2, 2, 1] + + print("smoke_test_shard_mt_by_samples: PASS") + + +if 
__name__ == "__main__": + main() diff --git a/all_of_us/mitochondria/mt_coverage_merge.changelog.md b/all_of_us/mitochondria/mt_coverage_merge.changelog.md index 334374498b..1eac0a8b65 100644 --- a/all_of_us/mitochondria/mt_coverage_merge.changelog.md +++ b/all_of_us/mitochondria/mt_coverage_merge.changelog.md @@ -1,4 +1,10 @@ aou_9.0.0 +2026-03-23 (Date of Last Commit) + +* Version of the pipeline used to process AoU v9 data + + +aou_9_beta 2025-10-31 (Date of Last Commit) * Added support for optional subsetting of inputs using a Terra data table TSV diff --git a/all_of_us/mitochondria/mt_coverage_merge.wdl b/all_of_us/mitochondria/mt_coverage_merge.wdl index 111559955b..869aaf4ed7 100644 --- a/all_of_us/mitochondria/mt_coverage_merge.wdl +++ b/all_of_us/mitochondria/mt_coverage_merge.wdl @@ -31,6 +31,11 @@ workflow mt_coverage_merge { Int step3_shard_n_partitions = 192 String step3_output_bucket + # Finalize sharding controls + Int finalize_shard_size = 25000 + Int finalize_shard_n_partitions = 256 + Int finalize_union_n_partitions = 1000 + } if (defined(sample_list_tsv)) { @@ -148,35 +153,61 @@ workflow mt_coverage_merge { } } - # Finalize: apply covdb homref/DP + artifact filter once on the final merged MT - # Finalize on the deepest merge output that exists. + # Shard the merged MT, finalize each shard, then union columns back together. 
if (do_merge_round_2) { if (defined(merge_round_3.merged_mt_tar)) { - call finalize_mt_with_covdb as finalize_mt_with_covdb_round3 { + call shard_mt_by_samples as shard_mt_by_samples_round3 { input: in_mt_tar = select_first([merge_round_3.merged_mt_tar])[0], - coverage_db_tar = annotate_coverage.output_db, - file_name = combined_mt_name + shard_size = finalize_shard_size, + n_final_partitions = finalize_shard_n_partitions, + output_bucket = step3_output_bucket } } if (!defined(merge_round_3.merged_mt_tar)) { - call finalize_mt_with_covdb as finalize_mt_with_covdb_round2 { + call shard_mt_by_samples as shard_mt_by_samples_round2 { input: in_mt_tar = select_first([merge_round_2.merged_mt_tar])[0], - coverage_db_tar = annotate_coverage.output_db, - file_name = combined_mt_name + shard_size = finalize_shard_size, + n_final_partitions = finalize_shard_n_partitions, + output_bucket = step3_output_bucket } } } if (!do_merge_round_2) { - call finalize_mt_with_covdb as finalize_mt_with_covdb_round1 { + call shard_mt_by_samples as shard_mt_by_samples_round1 { input: in_mt_tar = merge_round_1.merged_mt_tar[0], + shard_size = finalize_shard_size, + n_final_partitions = finalize_shard_n_partitions, + output_bucket = step3_output_bucket + } + } + + Array[String] shard_mt_tars_for_finalize = select_first([ + shard_mt_by_samples_round3.shard_mt_tars, + shard_mt_by_samples_round2.shard_mt_tars, + shard_mt_by_samples_round1.shard_mt_tars + ]) + + scatter (shard_mt_tar in shard_mt_tars_for_finalize) { + call finalize_mt_with_covdb as finalize_mt_shard { + input: + in_mt_tar = shard_mt_tar, coverage_db_tar = annotate_coverage.output_db, - file_name = combined_mt_name + file_name = basename(shard_mt_tar, ".tar.gz") + "_finalized", + n_final_partitions = finalize_shard_n_partitions } } + + call union_mt_shards { + input: + mt_tars = finalize_mt_shard.results_tar, + out_mt_name = combined_mt_name, + n_final_partitions = finalize_union_n_partitions, + output_bucket = step3_output_bucket + 
} } if (!shard_step3) { @@ -190,9 +221,7 @@ workflow mt_coverage_merge { } File combined_mt_tar = select_first([ - finalize_mt_with_covdb_round3.results_tar, - finalize_mt_with_covdb_round2.results_tar, - finalize_mt_with_covdb_round1.results_tar, + union_mt_shards.results_tar, combine_vcfs_and_homref_from_covdb.results_tar ]) @@ -608,6 +637,267 @@ task merge_mt_shards { } } +task shard_mt_by_samples { + input { + String in_mt_tar + Int shard_size = 25000 + Int n_final_partitions = 256 + Boolean overwrite = false + String output_bucket + + # Runtime parameters + Int memory_gb = 256 + Int cpu = 32 + Int disk_gb = 2000 + String disk_type = "SSD" + } + + command <<< + set -euxo pipefail + + mkdir -p ./tmp + mkdir -p ./results + mkdir -p ./input_mt + + setup_spark() { + local mem_gb="$1" + export SPARK_LOCAL_DIRS="$PWD/tmp" + local driver_mem_gb=$((mem_gb - 8)) + if [ "$driver_mem_gb" -lt 4 ]; then driver_mem_gb=4; fi + export SPARK_DRIVER_MEMORY="${driver_mem_gb}g" + export PYSPARK_SUBMIT_ARGS="--driver-memory ${driver_mem_gb}g --executor-memory ${driver_mem_gb}g pyspark-shell" + export JAVA_OPTS="-Xms${driver_mem_gb}g -Xmx${driver_mem_gb}g" + } + + find_mt_dir() { + local search_dir="$1" + local max_depth="$2" + local label="$3" + local mt_dir + if [ -f "${search_dir}/metadata.json.gz" ]; then + echo "${search_dir}" + return + fi + mt_dir=$(find "${search_dir}" -maxdepth "${max_depth}" -type d -name "*.mt" ! 
-path "${search_dir}" | head -n 1) + if [ -z "${mt_dir}" ]; then + echo "ERROR: could not find .mt directory after extracting ${label}" >&2 + find "${search_dir}" -maxdepth "${max_depth}" -type d | head -100 >&2 + exit 1 + fi + echo "${mt_dir}" + } + + setup_spark ~{memory_gb} + + # Extract input merged MT tar (String path, typically gs://) + command -v gcloud + IN_TAR_PATH="~{in_mt_tar}" + LOCAL_TAR="./input_mt/input_mt.tar.gz" + if [[ "${IN_TAR_PATH}" == gs://* ]]; then + gcloud storage cp "${IN_TAR_PATH}" "${LOCAL_TAR}" + else + cp -f "${IN_TAR_PATH}" "${LOCAL_TAR}" + fi + tar -xzf "${LOCAL_TAR}" -C ./input_mt + IN_MT_DIR=$(find_mt_dir "./input_mt" 2 "in_mt_tar") + + python3 /opt/mtSwirl/generate_mtdna_call_mt/Terra/shard_mt_by_samples.py \ + --in-mt "$IN_MT_DIR" \ + --out-dir ./results \ + --temp-dir ./tmp \ + --shard-size ~{shard_size} \ + --n-final-partitions ~{n_final_partitions} \ + ~{if overwrite then "--overwrite" else ""} + + # Tar and upload shards + DEST_ROOT="~{output_bucket}" + DEST_ROOT="${DEST_ROOT%/}" + rm -f shard_mt_tars.list + + for mt_dir in ./results/shard_*.mt; do + shard_name=$(basename "${mt_dir}") + tar_name="${shard_name}.tar.gz" + tar -czf "${tar_name}" -C ./results "${shard_name}" + + DEST_PATH="${DEST_ROOT}/${tar_name}" + gcloud storage cp "${tar_name}" "${DEST_PATH}" + + export TAR_NAME="${tar_name}" + LOCAL_MD5_B64=$(python3 - <<'PY' + import base64 + import hashlib + import os + + path = os.environ["TAR_NAME"] + h = hashlib.md5() + with open(path, "rb") as handle: + for chunk in iter(lambda: handle.read(1024 * 1024), b""): + h.update(chunk) + print(base64.b64encode(h.digest()).decode("utf-8")) + PY + ) + REMOTE_MD5=$(gcloud storage objects describe "${DEST_PATH}" --format='value(md5Hash)') + if [ "${LOCAL_MD5_B64}" != "${REMOTE_MD5}" ]; then + echo "ERROR: MD5 mismatch after copy to ${DEST_PATH}" >&2 + echo "LOCAL_MD5_B64=${LOCAL_MD5_B64}" >&2 + echo "REMOTE_MD5=${REMOTE_MD5}" >&2 + exit 1 + fi + + echo "${DEST_PATH}" >> 
shard_mt_tars.list + done + >>> + + output { + Array[String] shard_mt_tars = read_lines("shard_mt_tars.list") + } + + runtime { + docker: "us.gcr.io/broad-gotc-prod/aou-mitochondrial-combine-vcfs-covdb:dev" + memory: memory_gb + " GB" + cpu: cpu + disks: "local-disk " + disk_gb + " " + disk_type + } +} + +task union_mt_shards { + input { + Array[File] mt_tars + String out_mt_name + Int n_final_partitions = 1000 + Boolean overwrite = false + String output_bucket + + # Runtime parameters + Int memory_gb = 256 + Int cpu = 32 + Int disk_gb = 3000 + String disk_type = "SSD" + } + + command <<< + set -euxo pipefail + + mkdir -p ./tmp + mkdir -p ./results + mkdir -p ./inputs + + setup_spark() { + local mem_gb="$1" + export SPARK_LOCAL_DIRS="$PWD/tmp" + local driver_mem_gb=$((mem_gb - 8)) + if [ "$driver_mem_gb" -lt 4 ]; then driver_mem_gb=4; fi + export SPARK_DRIVER_MEMORY="${driver_mem_gb}g" + export PYSPARK_SUBMIT_ARGS="--driver-memory ${driver_mem_gb}g --executor-memory ${driver_mem_gb}g pyspark-shell" + export JAVA_OPTS="-Xms${driver_mem_gb}g -Xmx${driver_mem_gb}g" + } + + find_mt_dir() { + local search_dir="$1" + local max_depth="$2" + local label="$3" + local mt_dir + if [ -f "${search_dir}/metadata.json.gz" ]; then + echo "${search_dir}" + return + fi + mt_dir=$(find "${search_dir}" -maxdepth "${max_depth}" -type d -name "*.mt" ! -path "${search_dir}" | head -n 1) + if [ -z "${mt_dir}" ]; then + echo "ERROR: could not find .mt directory after extracting ${label}" >&2 + find "${search_dir}" -maxdepth "${max_depth}" -type d | head -100 >&2 + exit 1 + fi + echo "${mt_dir}" + } + + setup_spark ~{memory_gb} + + # Serialize mt_tars into a newline-delimited file and iterate from file. + # This avoids subtle trailing-newline/read-loop edge cases. 
+ cat > ./inputs/mt_tars.list < ./inputs/mt_paths.tsv + + i=0 + while IFS= read -r mt_tar || [ -n "${mt_tar}" ]; do + if [ -z "${mt_tar}" ]; then + continue + fi + + printf -v local_tar "./inputs/mt_%05d.tar.gz" "${i}" + printf -v dest_dir "./inputs/mt_%05d.extract" "${i}" + mkdir -p "${dest_dir}" + + if [[ "${mt_tar}" == gs://* ]]; then + gcloud storage cp "${mt_tar}" "${local_tar}" + else + cp -f "${mt_tar}" "${local_tar}" + fi + + tar -xzf "${local_tar}" -C "${dest_dir}" + mt_dir=$(find_mt_dir "${dest_dir}" 2 "${local_tar}") + printf "%s\n" "${mt_dir}" >> ./inputs/mt_paths.tsv + + i=$((i+1)) + done < ./inputs/mt_tars.list + + loaded_count=$(tail -n +2 ./inputs/mt_paths.tsv | grep -cve '^\s*$' || true) + echo "Loaded shard MT count: ${loaded_count}" + if [ "${loaded_count}" -ne "${expected_count}" ]; then + echo "ERROR: union_mt_shards loaded ${loaded_count} shard MTs but expected ${expected_count}" >&2 + echo "Input list:" >&2 + cat ./inputs/mt_tars.list >&2 + echo "Resolved mt_paths.tsv:" >&2 + cat ./inputs/mt_paths.tsv >&2 + exit 1 + fi + + python3 /opt/mtSwirl/generate_mtdna_call_mt/Terra/union_mt_shards.py \ + --mt-list-tsv ./inputs/mt_paths.tsv \ + --out-mt ./results/~{out_mt_name}.mt \ + --temp-dir ./tmp \ + --n-final-partitions ~{n_final_partitions} \ + ~{if overwrite then "--overwrite" else ""} + + tar -czf "~{out_mt_name}.tar.gz" -C ./results "~{out_mt_name}.mt" + + command -v gcloud + gcloud config set storage/parallel_composite_upload_enabled False + DEST_ROOT="~{output_bucket}" + DEST_ROOT="${DEST_ROOT%/}" + DEST_PATH="${DEST_ROOT}/~{out_mt_name}.tar.gz" + gcloud storage cp "~{out_mt_name}.tar.gz" "${DEST_PATH}" + + LOCAL_MD5_B64=$(openssl md5 -binary "~{out_mt_name}.tar.gz" | base64) + REMOTE_MD5=$(gcloud storage objects describe "${DEST_PATH}" --format='value(md5Hash)') + if [ "${LOCAL_MD5_B64}" != "${REMOTE_MD5}" ]; then + echo "ERROR: MD5 mismatch after copy to ${DEST_PATH}" >&2 + echo "LOCAL_MD5_B64=${LOCAL_MD5_B64}" >&2 + echo 
"REMOTE_MD5=${REMOTE_MD5}" >&2 + exit 1 + fi + + echo "${DEST_PATH}" > results_tar_path.txt + >>> + + output { + File results_tar = "~{out_mt_name}.tar.gz" + } + + runtime { + docker: "us.gcr.io/broad-gotc-prod/aou-mitochondrial-combine-vcfs-covdb:dev" + memory: memory_gb + " GB" + cpu: cpu + disks: "local-disk " + disk_gb + " " + disk_type + } +} + task finalize_mt_with_covdb { input { String in_mt_tar @@ -621,11 +911,10 @@ task finalize_mt_with_covdb { Boolean overwrite = false # Runtime parameters - Int memory_gb = 768 - Int cpu = 96 - Int disk_gb = 4000 + Int memory_gb = 128 + Int cpu = 32 + Int disk_gb = 500 String disk_type = "SSD" - String machine_type = "n2d-highmem-96" } command <<< @@ -702,11 +991,10 @@ task finalize_mt_with_covdb { } runtime { - docker: "us.gcr.io/broad-gotc-prod/aou-mitochondrial-combine-vcfs-covdb:1.0.1" + docker: "us.gcr.io/broad-gotc-prod/aou-mitochondrial-combine-vcfs-covdb:dev" memory: memory_gb + " GB" cpu: cpu disks: "local-disk " + disk_gb + " " + disk_type - predefinedMachineType: machine_type } } @@ -1017,7 +1305,7 @@ task combine_vcfs_and_homref_from_covdb { runtime { # NOTE: This must be a Hail-capable image with mtSwirl code baked in at /opt/mtSwirl. 
- docker: "us.gcr.io/broad-gotc-prod/aou-mitochondrial-combine-vcfs-covdb:1.0.1" + docker: "us.gcr.io/broad-gotc-prod/aou-mitochondrial-combine-vcfs-covdb:dev" memory: memory_gb + " GB" cpu: cpu disks: "local-disk " + disk_gb + " " + disk_type @@ -1037,11 +1325,21 @@ task add_annotations { Int cpu = 32 Int disk_gb = 1000 String disk_type = "SSD" + Int monitor_interval_seconds = 60 + Boolean enable_monitoring = true + Boolean enable_disk_monitor_upload = false + String disk_monitor_gcs_dir = "" } command <<< set -euxo pipefail + echo "BEGINNING SETUP" + #echo "Contents of ./tmp:" + #ls -lh ./tmp + echo "Contents of /tmp:" + ls -lh /tmp + WORK_DIR=$(pwd) setup_spark() { @@ -1074,6 +1372,29 @@ task add_annotations { setup_spark ~{memory_gb} + touch disk_monitor.log + if ~{enable_monitoring}; then + (while true; do + echo "===== disk_monitor $(date -Iseconds) =====" + df -h -T / || true + df -i -T / || true + df -h -T /tmp /var/tmp || true + df -i -T /tmp /var/tmp || true + df -h -T /mnt/disks/cromwell_root || true + df -i -T /mnt/disks/cromwell_root || true + df -h -T "${WORK_DIR}" || true + df -i -T "${WORK_DIR}" || true + sleep ~{monitor_interval_seconds} + done) > disk_monitor.log 2>&1 & + fi + if ~{enable_disk_monitor_upload} && [ -n "~{disk_monitor_gcs_dir}" ]; then + DISK_MONITOR_GCS_DIR="~{disk_monitor_gcs_dir}" + (while true; do + gsutil -q cp disk_monitor.log "${DISK_MONITOR_GCS_DIR%/}/disk_monitor.log" || true + sleep ~{monitor_interval_seconds} + done) > disk_monitor_upload.log 2>&1 & + fi + # Unzip VCF MatrixTable tarball mkdir -p ./unzipped_vcf.mt tar -xzf ~{vcf_mt} -C ./unzipped_vcf.mt @@ -1099,6 +1420,12 @@ task add_annotations { -d ./~{output_name} \ --temp-dir ./tmp + echo "DONE WITH ANNOTATION" + echo "Contents of ./tmp:" + ls -lh ./tmp + echo "Contents of /tmp:" + ls -lh /tmp + # Compress the annotated output directory tar -czf $WORK_DIR/annotated_output.tar.gz ~{output_name} >>> @@ -1108,7 +1435,7 @@ task add_annotations { } runtime { - docker: 
"us.gcr.io/broad-gotc-prod/aou-mitochondrial-combine-vcfs-covdb:1.0.1" + docker: "us.gcr.io/broad-gotc-prod/aou-mitochondrial-combine-vcfs-covdb:1.0.3" memory: memory_gb + " GB" cpu: cpu disks: "local-disk " + disk_gb + " " + disk_type