diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..e283ccacd --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,230 @@ +# gnomad_methods Project Reference + +## Project Overview + +Shared Hail utility library for gnomAD pipelines. Provides reusable functions for constraint analysis, variant QC, sample QC, VEP processing, resource management, and general genomics operations. Used as a dependency by `gnomad_constraint`, `gnomad_qc`, `gnomad_mnv`, and other gnomAD repos. + +## Package Structure + +| Directory | Purpose | +|-----------|---------| +| `gnomad/utils/constraint.py` | Constraint pipeline utilities (mutation rate, model building, o/e, pLI, z-scores) | +| `gnomad/utils/vep.py` | VEP annotation processing (consequence filtering, LOFTEE, MANE Select) | +| `gnomad/utils/annotations.py` | General variant annotations (age, frequency, quality) | +| `gnomad/utils/file_utils.py` | File I/O helpers, struct printing, array conversion | +| `gnomad/utils/filtering.py` | Frequency and variant filtering utilities | +| `gnomad/utils/gen_stats.py` | Statistical helper functions | +| `gnomad/utils/reference_genome.py` | Reference genome utilities | +| `gnomad/utils/release.py` | Release table formatting | +| `gnomad/utils/sparse_mt.py` | Sparse MatrixTable operations | +| `gnomad/resources/resource_utils.py` | Resource classes (`TableResource`, `VersionedTableResource`, etc.) | +| `gnomad/resources/grch38/gnomad.py` | GRCh38 resource paths and constants | +| `gnomad/resources/grch37/gnomad.py` | GRCh37 resource paths and constants | +| `gnomad/sample_qc/` | Sample QC (ancestry, relatedness, sex, platform) | +| `gnomad/variant_qc/` | Variant QC (random forest, evaluation, training) | +| `gnomad/assessment/` | Summary stats, validity checks | + +## Code Style + +### Formatting + +Code is formatted with **black** (preview mode, line length 88), **isort** (profile `"black"`), and **autopep8** (aggressive=1, ignoring E201/E202/E203/E731). 
Linting uses **pylint** and **pydocstyle** (PEP 257 convention, ignoring D105/D107). Config is in `pyproject.toml`, `.pydocstylerc`, and `.pylintrc`. + +Pre-commit hooks are installed. Run `pre-commit run --all-files` to check. + +```bash +# Manual formatting +black gnomad/ tests/ +isort --profile black --filter-files gnomad/ tests/ +``` + +### Docstrings + +Use **Sphinx-style** (`:param:`, `:return:`) docstrings: + +- **Summary line**: Concise one-line description, followed by blank line. +- **Body**: Extended description if needed. Use `.. note::` for caveats. +- **Params**: `:param name: Description.` — include default values and when params are conditionally required. +- **Return**: `:return: Description.` — describe the structure, not just the type. +- **Code references**: Use double backticks (````field_name````). +- **Constants**: Document with a docstring on the line immediately after the assignment. + +```python +COVERAGE_CUTOFF = 30 +"""Minimum median exome coverage differentiating high from low coverage sites.""" + + +def my_function( + ht: hl.Table, + max_af: Optional[float] = None, +) -> hl.Table: + """Short summary of what the function does. + + Extended description with more detail about behavior, edge cases, or + design decisions. + + .. note:: + + Any important caveats go here. + + :param ht: Input Table with ``freq`` and ``context`` fields. + :param max_af: Maximum allele frequency threshold. Default is None (no + cutoff). + :return: Table with ``expected_variants`` and ``observed_variants`` fields + added. + """ +``` + +### Type Annotations + +- **All functions** must have type annotations on parameters and return values. +- Use `typing.List`, `typing.Optional`, `typing.Union`, `typing.Tuple`, `typing.Dict` for generic types. 
+- For Hail expressions, use the `hl.expr.*` prefix: `hl.expr.StructExpression`, `hl.expr.BooleanExpression`, `hl.expr.Float64Expression`, `hl.expr.Int32Expression`, `hl.expr.ArrayExpression`, `hl.expr.NumericExpression`, etc. +- For Hail table/matrix types: `hl.Table`, `hl.MatrixTable`. +- Use `Tuple[str, ...]` for variable-length tuples (not `Tuple[str]` which means exactly one element). +- Never use mutable defaults (`List`, `Dict`) in function signatures — use `None` and assign inside the function body. + +### Function Design + +- **Expression-based where possible**: Prefer functions that take and return Hail expressions rather than Tables. This makes them composable and reusable in different contexts. Use `expr._indices.source` to recover the source Table when needed internally. +- **Single Table param named `ht`**: When a function takes one Table, name it `ht`. Use descriptive names only when taking multiple Tables (e.g., `mutation_ht`, `gencode_ht`). +- **Pure transformations**: Utils functions should be HTs in / HTs out. All file I/O (read, write, checkpoint) belongs in pipeline scripts, not utility functions. Exceptions exist for historical reasons but should not be added. +- **No lazy imports**: Always use top-level imports unless needed to resolve circular imports. +- **`Optional` for nullable params**: Always wrap with `Optional[...]` when the default is `None`. + +### Testing + +- **Policy**: Any new or modified function in a PR must have tests. +- **Format**: Class-based tests with `pytest`. One test class per function. +- **Test data**: Use `hl.Table.parallelize()` to create small inline test tables. +- **Fixtures**: Use `@pytest.fixture` for shared setup. Hail init is handled by a session-scoped fixture in `tests/conftest.py`. +- **Docstrings**: Both test classes and test methods should have docstrings. 
+ +```python +class TestMyFunction: + """Test the my_function function.""" + + def test_basic_case(self): + """Test that basic input produces expected output.""" + ht = hl.Table.parallelize( + [{"x": 1, "y": 2.0}], + hl.tstruct(x=hl.tint32, y=hl.tfloat64), + ) + result = ht.annotate(z=my_function(ht.x, ht.y)).collect()[0] + assert result.z == 3.0 +``` + +## Hail Best Practices + +### When to checkpoint vs cache + +- **`checkpoint(new_temp_file(...))`**: Use for intermediate results that feed into multiple downstream operations or that follow expensive computations (joins, aggregations). Materializes to disk and breaks the query plan, preventing Hail from re-executing the upstream DAG. +- **`.cache()`**: Use for small results that will be reused immediately. Keeps data in memory but doesn't break the query plan as reliably. +- **After checkpoint, `.count()` is free**: It reads materialized metadata rather than re-executing the query. Place `.count()` after a checkpoint when you need the count. + +### Avoid `.count()` for logging + +Never use `.count()` just to log how many rows a table has. On large tables this forces full materialization and can cause Spark shuffle failures. Only call `.count()` when the result is actually needed for computation. + +### `naive_coalesce()` after aggressive filters + +When filtering a large table down to a small subset (e.g., LOFTEE HC LoF from all variants), most partitions become empty. This causes shuffle skew in downstream `group_by` aggregations. Call `.naive_coalesce(N)` after the filter to rebalance. + +### `_localize=False` for expression results + +Use `ht.aggregate(expr, _localize=False)` when the result will be used as a Hail expression in downstream operations (e.g., annotating another table). This avoids collecting to Python and back. + +### Field existence checks + +Use `field_name in ht.row` to check if a field exists on a Table. Hail Tables do not have `.get()`. 
+ +### `hl.or_else`, `hl.or_missing`, `hl.is_defined` + +- **`hl.or_else(expr, default)`**: Substitute `default` when `expr` is missing. +- **`hl.or_missing(condition, expr)`**: Return `expr` when `condition` is True, missing otherwise. +- **`hl.is_defined(expr)`**: Returns True/False (never missing) — no need to wrap in `hl.or_else`. +- **`divide_null(num, denom)`**: Safe division returning null when denominator is 0. Import from `hail.utils.misc`. + +### `approx_quantiles` is approximate + +`hl.agg.approx_quantiles` uses the t-digest algorithm and returns approximate percentiles. Document this with a `.. note::` block when using it in functions. + +### Falsy value gotchas + +When checking optional numeric parameters, always use `is not None` instead of truthiness checks. `if max_af:` will skip `max_af=0.0`, which is a valid value. Same applies to any numeric parameter where 0 is meaningful. + +### Array schema uniformity + +All elements of a Hail array field must have identical struct schemas. You cannot annotate only `array[0]` with extra fields while leaving `array[1+]` unchanged — Hail will reject the mixed schema. Promote such metadata to the parent struct level instead. + +### Rank assignment with `order_by` + +`ht.order_by(expr)` destroys the key. To rejoin ranked results: +1. `ht.add_index("_rank_idx")` before ordering. +2. `rank_ht.key_by("_rank_idx")` after ordering. +3. Use `hl.scan.count()` to assign 0-based ascending ranks: `ht.order_by(val).annotate(rank=hl.scan.count())`. + +### Small table reconstruction + +`hl.Table.parallelize(hl.eval(ht.my_array_global), schema=...)` reconstructs a small Hail Table from a global array without re-running any jobs. 
+ +## Key Resource Classes + +```python +from gnomad.resources.resource_utils import ( + TableResource, # .ht() to read + MatrixTableResource, # .mt() to read + VersionedTableResource, # .ht() with version param + VariantDatasetResource, # .vds() to read +) +``` + +Usage: +```python +resource = TableResource(path="gs://gnomad/v4.1/constraint/metrics.ht") +ht = resource.ht() # Reads the table +``` + +## Key Constraint API + +```python +from gnomad.utils.constraint import ( + # Mutation rate + annotate_mutation_type, + annotate_with_mu, + calibration_model_group_expr, + # Consequence grouping + build_constraint_consequence_groups, + # Counting + single_variant_count_expr, + single_variant_observed_and_possible_expr, + count_observed_and_possible_by_group, + get_counts_agg_expr, + # Model building & application + build_models, + apply_models, + aggregate_expected_variants_expr, + # GERP + calculate_gerp_cutoffs, + # Constraint metrics + compute_pli, + oe_confidence_interval, + calculate_raw_z_score, + calculate_raw_z_score_sd, + get_constraint_flags, + # Ranking & binning + rank_and_assign_bins, + compute_percentile_thresholds, + annotate_bins_by_threshold, +) +``` + +## Maintaining CLAUDE.md + +When working on any gnomAD repo that has a CLAUDE.md file, proactively add useful information you discover during development — gotchas, non-obvious API behavior, schema quirks, resource path conventions, or anything else that would save a future developer (or Claude session) time. Keep additions concise and placed in the appropriate section. 
+ +## CI/CD + +- **Pre-commit hooks**: black, autopep8, pydocstyle, isort (run automatically on commit) +- **CI** (`.github/workflows/ci.yml`): Runs on push to main and PRs — black check, isort check, pydocstyle, autopep8 check, pylint, pytest +- **Publishing** (`.github/workflows/publish.yml`): Triggered by version tags (`v[0-9]+.[0-9]+.[0-9]+`) — validates version, runs tests, publishes to PyPI diff --git a/gnomad/__init__.py b/gnomad/__init__.py new file mode 100644 index 000000000..e85b5c7d3 --- /dev/null +++ b/gnomad/__init__.py @@ -0,0 +1 @@ +"""gnomAD utilities and resources package.""" diff --git a/gnomad/utils/constraint.py b/gnomad/utils/constraint.py index a1e405b6a..ca9b9c2c3 100644 --- a/gnomad/utils/constraint.py +++ b/gnomad/utils/constraint.py @@ -1,12 +1,15 @@ """Script containing generic constraint functions that may be used in the constraint pipeline.""" import copy +import functools import logging +import operator from typing import Any, Callable, Dict, List, Optional, Tuple, Union import hail as hl from hail.utils.misc import divide_null, new_temp_file +from gnomad.assessment.summary_stats import generate_filter_combinations from gnomad.utils.reference_genome import get_reference_genome from gnomad.utils.vep import ( add_most_severe_consequence_to_consequence, @@ -28,6 +31,34 @@ Low coverage sites require an extra calibration when computing the proportion of expected variation. 
""" +CLASSIC_LOF_ANNOTATIONS = ( + "stop_gained", + "splice_donor_variant", + "splice_acceptor_variant", +) +"""Classic loss-of-function VEP annotations.""" + +DEFAULT_FIELDS_TO_SUM = ( + "mu_snp", + "mu", + "observed_variants", + "possible_variants", + "predicted_proportion_observed", + "coverage_correction", + "expected_variants", +) +"""Default fields summed when aggregating expected variants by constraint group.""" + +DEFAULT_GENCODE_ANNOTATIONS = ( + "transcript_id_version", + "gene_id_version", + "level", + "transcript_type", + "start_position", + "end_position", +) +"""Default GENCODE annotations added to a Table by transcript id.""" + def get_mu_annotation_expr( ht: hl.Table, @@ -92,6 +123,441 @@ def annotate_with_mu( ) +def _resolve_annotation_expr( + t: Optional[Union[hl.Table, hl.MatrixTable]] = None, + annotation_name: Optional[str] = None, + expr: Optional[hl.expr.Expression] = None, + expr_param_name: Optional[str] = None, +) -> hl.expr.Expression: + """ + Get an annotation from a Table/MatrixTable, or return a provided expression. + + Provides a consistent pattern for functions that accept either an explicit + Hail expression or fall back to a well-known field on a Table. This avoids + duplicating "resolve expr or look it up on ht" logic across callers. + + If ``expr`` is provided it is returned directly. Otherwise, ``t`` and + ``annotation_name`` must both be supplied and the named field is looked up on + ``t``. + + Example usage inside a public function:: + + def my_function( + ht: hl.Table, + freq_expr: Optional[hl.expr.StructExpression] = None, + ) -> ...: + freq_expr = _resolve_annotation_expr( + t=ht, + annotation_name="freq", + expr=freq_expr, + expr_param_name="freq_expr", + ) + # freq_expr is now guaranteed to be a valid expression — + # either the caller's explicit value or ht.freq. + + :param t: Input Table or MatrixTable to look up ``annotation_name`` on. Required + when ``expr`` is None. 
+ :param annotation_name: Name of the field to retrieve from ``t``. Required when + ``expr`` is None. + :param expr: Expression to return directly. When provided, ``t`` and + ``annotation_name`` are ignored. + :param expr_param_name: Human-readable name for ``expr``, used in log/error + messages. Defaults to "expr". + :return: The resolved Hail expression. + """ + if expr is None and (t is None or annotation_name is None): + raise ValueError("Either t and annotation_name or expr must be provided.") + + expr_param_name = expr_param_name or "expr" + if expr is None and annotation_name in t.row: + logger.info( + "%s was not provided, using '%s'.", expr_param_name, annotation_name + ) + expr = t[annotation_name] + elif expr is None: + raise ValueError( + f"{expr_param_name} was not provided and '{annotation_name}' is " + f"not present in the input Table or MatrixTable." + ) + + return expr + + +def variant_observed_expr( + freq_expr: Optional[ + Union[hl.expr.StructExpression, hl.expr.ArrayExpression] + ] = None, + ht: Optional[hl.Table] = None, + singleton: bool = False, + max_af: Optional[float] = None, + count_missing: bool = False, +) -> hl.expr.Int32Expression: + """ + Return 1 if a variant meets frequency criteria, 0 otherwise. + + One of ``ht`` or ``freq_expr`` must be provided. If ``freq_expr`` is not provided, + ``ht.freq`` is used. When ``freq_expr`` is an ArrayExpression, the first + element is used. + + The returned count is 1 when: + + - ``singleton`` is True and ``freq_expr.AC == 1``. + - ``max_af`` is not None and ``freq_expr.AC > 0`` and + ``freq_expr.AF <= max_af``. + - Neither ``singleton`` nor ``max_af`` is set but ``freq_expr`` is + available: ``freq_expr.AC > 0``. + - Neither ``singleton`` nor ``max_af`` is set and no ``freq_expr`` is + available: unconditionally 1. + + :param ht: Input Table. Used to look up ``freq`` when ``freq_expr`` is None. + :param freq_expr: StructExpression (or ArrayExpression of structs) with + ``AC`` and ``AF`` fields. 
When an array, ``freq_expr[0]`` is used. + :param singleton: Count only singletons (AC == 1). Default is False. + :param max_af: Maximum allele frequency threshold. Default is None (no + cutoff). + :param count_missing: Value to substitute when the count expression is + missing (e.g. frequency is None). Default is False (0). + :return: Int32Expression equal to 0 or 1. + """ + if ht is None and freq_expr is None: + raise ValueError("Either ht or freq_expr must be provided.") + + if max_af is not None or singleton: + freq_expr = _resolve_annotation_expr(ht, "freq", freq_expr, "freq_expr") + if isinstance(freq_expr, hl.expr.ArrayExpression): + freq_expr = freq_expr[0] + + if singleton: + count_var = freq_expr.AC == 1 + elif max_af is not None: + count_var = (freq_expr.AC > 0) & (freq_expr.AF <= max_af) + elif freq_expr is not None: + count_var = freq_expr.AC > 0 + else: + count_var = True + + return hl.int(hl.or_else(count_var, count_missing)) + + +def variant_observed_and_possible_expr( + freq_expr: hl.ArrayExpression, + max_af: Optional[float] = None, + use_possible_adj: bool = True, +) -> hl.expr.StructExpression: + """ + Return per-variant observed and possible count expressions. + + For each element of the frequency array, ``observed_variants`` is 1 when the + variant meets the frequency criteria (AC > 0, optionally AF <= ``max_af``) and + 0 otherwise. ``possible_variants`` uses the same criteria but substitutes 1 + for missing frequencies (i.e., the variant site is considered possible even + when frequency data is absent). + + When ``use_possible_adj`` is True (default), the possible count is a scalar + derived from the first (adj) element of the frequency array. When False, it + is an array with one entry per downsampling, matching the shape of + ``observed_variants``. + + :param freq_expr: Array of frequency structs with ``AC`` and ``AF`` fields. + :param max_af: Maximum allele frequency threshold. Default is None (no + cutoff). 
+ :param use_possible_adj: If True, compute possible count from only the adj + (first) frequency element. If False, compute per-downsampling. Default + is True. + :return: Struct with ``observed_variants`` (array of int) and + ``possible_variants`` (int if ``use_possible_adj``, else array of int). + """ + pos_expr = hl.array([freq_expr[0]]) if use_possible_adj else freq_expr + pos_expr = pos_expr.map( + lambda x: variant_observed_expr(freq_expr=x, max_af=max_af, count_missing=True) + ) + pos_expr = pos_expr[0] if use_possible_adj else pos_expr + return hl.struct( + observed_variants=freq_expr.map( + lambda x: variant_observed_expr(freq_expr=x, max_af=max_af) + ), + possible_variants=pos_expr, + ) + + +def weighted_sum_agg_expr( + expr: Union[hl.expr.ArrayNumericExpression, hl.expr.NumericExpression], + weight_expr: Union[hl.expr.ArrayNumericExpression, hl.expr.NumericExpression], +) -> Union[hl.expr.ArrayExpression, hl.expr.NumericExpression]: + """ + Return the weighted aggregate sum of ``expr`` weighted by ``weight_expr``. + + Both parameters may be scalar or array numeric expressions: + + - Both arrays: elements are multiplied pairwise and summed per-element + across rows (``hl.agg.array_sum``). + - Both scalars: standard ``hl.agg.sum(expr * weight_expr)``. + - Mixed: the scalar is broadcast across the array elements and summed + per-element across rows. + + :param expr: Numeric expression (scalar or array) to be weighted. + :param weight_expr: Numeric expression (scalar or array) to weight by. + :return: Weighted aggregate sum expression. 
+ """ + expr_is_array = isinstance(expr, hl.expr.ArrayNumericExpression) + weight_is_array = isinstance(weight_expr, hl.expr.ArrayNumericExpression) + if expr_is_array and weight_is_array: + return hl.agg.array_sum(hl.zip(expr, weight_expr).map(lambda x: x[0] * x[1])) + elif not expr_is_array and not weight_is_array: + return hl.agg.sum(expr * weight_expr) + else: + return hl.agg.array_sum(expr * weight_expr) + + +def counts_agg_expr( + freq_expr: Optional[ + Union[hl.expr.ArrayExpression, hl.expr.StructExpression] + ] = None, + ht: Optional[hl.Table] = None, + count_singletons: bool = False, + max_af: Optional[float] = None, + count_missing: bool = False, +) -> hl.expr.StructExpression: + """ + Return an aggregation expression for variant and singleton counts. + + Aggregates per-variant counts (via ``variant_observed_expr``) across + rows. Each variant contributes 0 or 1 to the count based on its frequency + metadata (AC, AF). The result is a struct with ``variant_count`` and, when + ``count_singletons`` is True, ``singleton_count``. + + One of ``freq_expr`` or ``ht`` must be provided. If ``freq_expr`` is not + provided, ``ht.freq`` is used as the fallback. + + The shape of ``freq_expr`` controls whether the output counts are scalars or + arrays: + + - **StructExpression** (single frequency entry with ``AC`` / ``AF``): + returns scalar ``variant_count`` and ``singleton_count``. + - **ArrayExpression** (array of frequency structs, e.g. one per + downsampling): returns array-valued counts where each element corresponds + to a position in the input array. Internally the array is mapped through + ``variant_observed_expr`` and summed element-wise with + ``hl.agg.array_sum``. + + When ``max_af`` is set, only variants with ``AF <= max_af`` and ``AC > 0`` + are counted. Singleton counting (``count_singletons=True``) counts only + variants with ``AC == 1``, independent of the ``max_af`` filter. 
+ + :param freq_expr: Frequency expression — an ArrayExpression of structs or a + single StructExpression with ``AC`` and ``AF`` fields. If None, falls + back to ``ht.freq``. + :param ht: Input Table. Used to look up ``freq`` when ``freq_expr`` is None. + :param count_singletons: If True, include a ``singleton_count`` field in + the result. Default is False. + :param max_af: Maximum allele frequency threshold. Variants with + ``AF > max_af`` are excluded from ``variant_count``. Does not affect + ``singleton_count``. Default is None (no cutoff). + :param count_missing: Value to substitute when frequency is missing. + Default is False (0). + :return: Aggregation StructExpression with ``variant_count`` (and + optionally ``singleton_count``). Values are scalars when ``freq_expr`` is + a StructExpression, or arrays when it is an ArrayExpression. + """ + if ht is None and freq_expr is None: + raise ValueError("Either ht or freq_expr must be provided.") + + freq_expr = _resolve_annotation_expr(ht, "freq", freq_expr, "freq_expr") + + params = {"variant_count": {"singleton": False, "max_af": max_af}} + if count_singletons: + params["singleton_count"] = {"singleton": True} + + is_struct = isinstance(freq_expr, hl.expr.StructExpression) + if is_struct: + freq_expr = hl.array([freq_expr]) + + count_expr = { + k: freq_expr.map( + lambda x: variant_observed_expr( + freq_expr=x, **p, count_missing=count_missing + ) + ) + for k, p in params.items() + } + + agg_expr = {k: hl.agg.array_sum(v) for k, v in count_expr.items()} + agg_expr = {k: agg_expr[k][0] for k in agg_expr} if is_struct else agg_expr + + return hl.struct(**agg_expr) + + +def build_constraint_consequence_groups( + csq_expr: hl.expr.StringExpression, + lof_modifier_expr: hl.expr.StringExpression, + classic_lof_annotations: Tuple[str, ...] 
= CLASSIC_LOF_ANNOTATIONS, + additional_groupings: Optional[ + Dict[str, Dict[str, hl.expr.BooleanExpression]] + ] = None, + additional_grouping_combinations: Optional[List[List[str]]] = None, +) -> Tuple[List[hl.expr.BooleanExpression], List[Dict[str, str]]]: + """ + Build constraint consequence groups. + + Builds constraint groups based on the consequence expression and LoF + modifier expression. By default, the following groups are built: + + - csq_set: synonymous_variant, missense_variant + - lof: classic, hc_lc, hc + + The resulting meta and corresponding constraint group filters are: + + - {"csq_set": "syn"}: synonymous_variant + - {"csq_set": "mis"}: missense_variant + - {"lof": "classic"}: classic LoF annotations (stop_gained, + splice_donor_variant, splice_acceptor_variant) + - {"lof": "hc_lc"}: LOFTEE HC or LC modifier + - {"lof": "hc"}: LOFTEE HC modifier only + + Additional groupings can be added via ``additional_groupings``, and + grouping combinations via ``additional_grouping_combinations``. + + :param csq_expr: VEP most severe consequence expression (e.g., + ``ht.most_severe_consequence``). + :param lof_modifier_expr: LOFTEE modifier expression (e.g., ``ht.modifier``). + :param classic_lof_annotations: Classic LoF annotations. Default is + ``("stop_gained", "splice_donor_variant", "splice_acceptor_variant")``. + :param additional_groupings: Additional groupings to add to the constraint + groups. Default is None. + :param additional_grouping_combinations: Additional grouping combinations to + add to the constraint groups. Default is None. + :return: Tuple of (constraint group filter expressions, meta dicts). 
+ """ + lof_classic_expr = hl.literal(set(classic_lof_annotations)).contains(csq_expr) + lof_hc_expr = lof_modifier_expr == "HC" + lof_hc_lc_expr = lof_hc_expr | (lof_modifier_expr == "LC") + annotation_dict = { + "csq_set": { + "syn": csq_expr == "synonymous_variant", + "mis": csq_expr == "missense_variant", + }, + "lof": { + "classic": lof_classic_expr, + "hc_lc": lof_hc_lc_expr, + "hc": lof_hc_expr, + }, + } + + annotation_dict.update(additional_groupings or {}) + + grouping_combinations = [["csq_set"], ["lof"]] + grouping_combinations.extend(additional_grouping_combinations or []) + + meta = generate_filter_combinations( + grouping_combinations, + {k: list(v.keys()) for k, v in annotation_dict.items()}, + ) + constraint_group_filters = [ + functools.reduce(operator.ior, [annotation_dict[k][v] for k, v in m.items()]) + for m in meta + ] + + return constraint_group_filters, meta + + +def count_observed_and_possible_by_group( + ht: hl.Table, + possible_expr: hl.expr.Int32Expression, + observed_expr: hl.expr.ArrayExpression, + additional_grouping: Union[List[str], Tuple[str, ...]] = ("methylation_level",), + partition_hint: int = 100, + weight_exprs: Optional[ + Union[ + List[str], + Dict[str, Union[hl.expr.ArrayExpression, hl.expr.NumericExpression]], + ] + ] = None, + additional_agg_sum_exprs: Optional[ + Union[ + List[str], + Dict[str, Union[hl.expr.ArrayExpression, hl.expr.NumericExpression]], + ] + ] = None, +) -> hl.Table: + """ + Aggregate observed and possible variant counts by substitution context group. + + Groups rows by ``context``, ``ref``, ``alt``, and any fields in + ``additional_grouping``, then sums observed and possible counts within each + group. + + :param ht: Input Table with ``context``, ``ref``, ``alt`` fields and any + fields named in ``additional_grouping``. + :param possible_expr: Per-variant possible count (scalar). + :param observed_expr: Per-variant observed count (array, one element per + downsampling). 
+ :param additional_grouping: Field names to append to the base + ``(context, ref, alt)`` grouping. Default is ``("methylation_level",)``. + :param partition_hint: Target number of partitions for the ``group_by``. + Default is 100. + :param weight_exprs: Weighted sums of ``possible_expr`` to include. Pass + field names (looked up on ``ht``) or a dict mapping output names to + weight expressions. Each produces + ``weighted_sum_agg_expr(possible_expr, weight)``. + :param additional_agg_sum_exprs: Extra fields to sum alongside + observed/possible. Pass field names (looked up on ``ht``) or a dict + mapping output names to expressions. Arrays use + ``hl.agg.array_sum``; scalars use ``hl.agg.sum``. + :return: Grouped Table with ``observed_variants``, ``possible_variants``, + and any weighted/additional sum fields. + """ + if isinstance(weight_exprs, list): + weight_exprs = {k: ht[k] for k in weight_exprs} + if isinstance(additional_agg_sum_exprs, list): + additional_agg_sum_exprs = {k: ht[k] for k in additional_agg_sum_exprs} + + weight_exprs = weight_exprs or {} + additional_agg_sum_exprs = additional_agg_sum_exprs or {} + + # Build the grouping struct for the variant count aggregation. + grouping = hl.struct(context=ht.context, ref=ht.ref, alt=ht.alt) + grouping = grouping.annotate( + **{g: ht[g] for g in additional_grouping if g not in grouping} + ) + logger.info( + "The following annotations will be used to group the input Table rows when" + " counting variants: %s.", + ", ".join(grouping.keys()), + ) + + agg_expr = { + "observed_variants": hl.agg.array_sum(observed_expr), + "possible_variants": hl.agg.sum(possible_expr), + } + + # Update the possible variant count aggregation expression to include weighted sums + # of possible variant counts. + agg_expr.update( + {k: weighted_sum_agg_expr(possible_expr, v) for k, v in weight_exprs.items()} + ) + + # Get sum aggregation expressions for requested fields. 
+ agg_expr.update( + { + k: ( + hl.agg.array_sum(v) + if isinstance(v, hl.ArrayExpression) + else hl.agg.sum(v) + ) + for k, v in additional_agg_sum_exprs.items() + } + ) + + # Apply each variant count aggregation in `agg_expr` to get counts for all + # combinations of `grouping`. + ht = ht.group_by(**grouping).partition_hint(partition_hint).aggregate(**agg_expr) + + return ht + + +# TODO: I think we should consider removing this or at least completely changing it +# To remove pop and downsampling support, since that should just be handled as an +# array, where the same thing is done. def count_variants_by_group( ht: hl.Table, freq_expr: Optional[hl.expr.ArrayExpression] = None, @@ -742,6 +1208,7 @@ def assemble_constraint_context_ht( # Trim heptamer context to create trimer context. ht = trimer_from_heptamer(ht) + ht = ht.filter(ht.context.matches(f"[ATCG]{{{3}}}")) # Annotate mutation type (such as "CpG", "non-CpG transition", "transversion") and # collapse strands to deduplicate the context. @@ -761,6 +1228,12 @@ def assemble_constraint_context_ht( "canonical", "lof", "lof_flags", + "sift_score", + "polyphen_score", + "domains", + "uniprot_isoform", + "amino_acids", + "codons", ] vep_csq_fields = [x for x in vep_csq_fields if x in csqs.dtype.element_type] ht = ht.annotate( @@ -845,29 +1318,167 @@ def assemble_constraint_context_ht( return ht +def calculate_gerp_cutoffs( + ht: hl.Table, + gerp_expr: Optional[hl.expr.Float64Expression] = None, + lower_percentile: float = 0.05, + upper_percentile: float = 0.95, +) -> Tuple[float, float]: + """ + Find GERP cutoffs at the given percentile thresholds. + + .. note:: + + Uses ``hl.agg.approx_quantiles``, so results are approximate. + + :param ht: Input Table. + :param gerp_expr: GERP score expression. Default is ``ht.gerp``. + :param lower_percentile: Lower percentile threshold (0-1). Default is 0.05. + :param upper_percentile: Upper percentile threshold (0-1). Default is 0.95. 
+ :return: Tuple of (lower cutoff, upper cutoff) GERP scores. + """ + if gerp_expr is None: + gerp_expr = ht.gerp + + cutoffs = ht.aggregate( + hl.agg.approx_quantiles(gerp_expr, [lower_percentile, upper_percentile]) + ) + return cutoffs[0], cutoffs[1] + + +def calibration_model_group_expr( + exomes_coverage_expr: hl.expr.Int32Expression, + cpg_expr: hl.expr.BooleanExpression, + low_cov_cutoff: Optional[int] = None, + high_cov_cutoff: int = COVERAGE_CUTOFF, + upper_cov_cutoff: Optional[int] = None, + skip_coverage_model: bool = False, + additional_grouping_exprs: Optional[Dict[str, hl.expr.StringExpression]] = None, + cpg_in_high_only: bool = False, +) -> hl.expr.StructExpression: + """ + Get the calibration model grouping annotation for a variant. + + The calibration model expression is a struct with the following fields: + + - genomic_region: The genomic region of the variant ("autosome_or_par", + "chrx_nonpar", or "chry_nonpar"). + - high_or_low_coverage: Whether the variant belongs to the high or low coverage + calibration model. The variant is assigned to the high coverage model if the + exome coverage is greater than or equal to 'high_cov_cutoff' and less than or + equal to 'upper_cov_cutoff' (if provided). The variant is assigned to the low + coverage model if `skip_coverage_model` is False and the exome coverage is + greater than 'low_cov_cutoff' (if provided) and less than 'high_cov_cutoff'. + - cpg: Whether the variant is a CpG (`cpg_expr`). + + The global parameters for the calibration model are the values of the function + parameters: `low_cov_cutoff`, `high_cov_cutoff`, `upper_cov_cutoff`, and + `skip_coverage_model`. + + :param exomes_coverage_expr: Exome coverage expression. + :param cpg_expr: CpG expression. + :param low_cov_cutoff: Low coverage cutoff. Default is None. + :param high_cov_cutoff: High coverage cutoff. Default is COVERAGE_CUTOFF. + :param upper_cov_cutoff: Upper coverage cutoff. Default is None. 
+    :param skip_coverage_model: Whether to skip the coverage model. Default is False.
+    :param additional_grouping_exprs: Extra grouping expressions. Default is None.
+    :param cpg_in_high_only: If True, keep ``cpg`` only for the high coverage model.
+    :return: Struct of ``high_or_low_coverage`` and ``model_group``, or missing.
+    """
+    high_cov_expr = exomes_coverage_expr >= high_cov_cutoff
+    if upper_cov_cutoff is not None:
+        high_cov_expr &= exomes_coverage_expr <= upper_cov_cutoff
+
+    low_cov_expr = hl.bool(False) if skip_coverage_model else hl.bool(True)
+    if low_cov_cutoff is not None:
+        low_cov_expr &= exomes_coverage_expr > low_cov_cutoff
+
+    # Define whether the variant should be included in the high or low coverage model.
+    model_expr = (
+        hl.case().when(high_cov_expr, "high").when(low_cov_expr, "low").or_missing()
+    )
+
+    cpg_expr = (
+        hl.or_missing(model_expr == "high", cpg_expr) if cpg_in_high_only else cpg_expr
+    )
+    return hl.or_missing(
+        hl.is_defined(model_expr),
+        hl.struct(
+            high_or_low_coverage=model_expr,
+            model_group=hl.struct(
+                cpg=cpg_expr,
+                **(additional_grouping_exprs or {}),
+            ),
+        ),
+    )
+
+
+def _build_sum_agg_struct(
+    fields_to_sum: Optional[List[str]] = None,
+    exprs_to_sum: Optional[
+        Union[hl.expr.StructExpression, Dict[str, hl.expr.NumericExpression]]
+    ] = None,
+    t: Optional[Union[hl.Table, hl.expr.StructExpression]] = None,
+) -> hl.expr.StructExpression:
+    """
+    Return an aggregation expression to sum fields or expressions in a Table/StructExpression.
+
+    The aggregation expression is a struct with the sum or array_sum of the fields or
+    expressions provided in `fields_to_sum` or `exprs_to_sum`.
+
+    :param fields_to_sum: List of fields to sum. Default is None.
+    :param exprs_to_sum: Dictionary of expressions to sum. Default is None.
+    :param t: Optional Table or StructExpression to get `fields_to_sum` from. Default
+        is None.
+    :return: Aggregation expression to sum fields or expressions in the Table.
+ """ + if fields_to_sum is None and exprs_to_sum is None: + raise ValueError("Either 'fields_to_sum' or 'exprs_to_sum' must be provided.") + if fields_to_sum is not None and t is None: + raise ValueError("t must be provided if 'fields_to_sum' is provided.") + + exprs_to_sum = exprs_to_sum or {} + exprs_to_sum = hl.struct(**exprs_to_sum, **{f: t[f] for f in fields_to_sum or []}) + + return hl.struct( + **{ + k: ( + hl.agg.array_sum(v) + if isinstance(v, hl.ArrayExpression) + else hl.agg.sum(v) + ) + for k, v in exprs_to_sum.items() + } + ) + + +# TODO: I have changed this so that it doesn't split up the pops anymore. I don't think +# it is necessary to do this, and it makes the code more complicated. We should just +# keep things the way we do for freq with a freq_meta. So we expect the +# observed_variants and plateau_models_expr to be arrays of the same length. def build_models( - coverage_ht: hl.Table, + ht: hl.Table, coverage_expr: hl.expr.Int32Expression, weighted: bool = False, - gen_ancs: Tuple[str] = (), - keys: Tuple[str] = ( + keys: Tuple[str, ...] = ( "context", "ref", "alt", "methylation_level", - "mu_snp", ), + model_group_expr: Optional[hl.expr.StructExpression] = None, high_cov_definition: int = COVERAGE_CUTOFF, upper_cov_cutoff: Optional[int] = None, skip_coverage_model: bool = False, log10_coverage: bool = True, -) -> Tuple[Optional[Tuple[float, float]], hl.expr.StructExpression]: + additional_grouping: Tuple[str, ...] = (), +) -> Tuple[Optional[hl.expr.StructExpression], hl.expr.DictExpression]: """ Build coverage and plateau models. This function builds models (plateau_models) using linear regression to calibrate mutation rate estimates against the proportion observed of each substitution, - context, and methylation level in `coverage_ht`. + context, and methylation level in `ht`. Two plateau models are fit, one for CpG transitions, and one for the remainder of sites (transversions and non CpG transitions). 
@@ -877,9 +1488,10 @@ def build_models( Plateau model: adjusts proportion of expected variation based on location in the genome and CpG status. + The x and y of the plateau models: - - x: `mu_snp` - mutation rate - - y: proportion observed ('observed_variants' or 'observed_{gen_anc}' / 'possible_variants') + - x: `mu_snp` - mutation rate + - y: proportion observed ('observed_variants' / 'possible_variants') This function also builds models (coverage models) to calibrate the proportion of expected variation at low coverage sites (sites below `high_cov_definition`). @@ -892,223 +1504,244 @@ def build_models( Low coverage sites are defined as sites with median coverage < `high_cov_definition`. The x and y of the coverage model: - - x: log10 groupings of exome coverage at low coverage sites - - y: sum('observed_variants')/ (`high_coverage_scale_factor` * sum('possible_variants' * 'mu_snp') at low coverage sites - `high_coverage_scale_factor` = sum('observed_variants') / - sum('possible_variants' * 'mu_snp') at high coverage sites + - x: groupings of exome coverage at low coverage sites (log10 transformed if + requested) + - y: sum('observed_variants') / (``high_coverage_scale_factor`` * + sum('possible_variants' * 'mu_snp')) at low coverage sites + + ``high_coverage_scale_factor`` = sum('observed_variants') / + sum('possible_variants' * 'mu_snp') at high coverage sites .. note:: - This function expects that the input Table(`coverage_ht`) was created using - `get_proportion_observed_by_coverage`, which means that `coverage_ht` should - contain only high quality synonymous variants below 0.1% frequency. - - This function also expects that the following fields are present in - `coverage_ht`: - - context - trinucleotide genomic context - - ref - the reference allele - - alt - the alternate allele - - methylation_level - methylation level - - cpg - whether the site is CpG site - - observed_variants - the number of observed variants in the dataset for each - variant. 
Note that the term "variant" here refers to a specific substitution, - context, methylation level, and coverage combination - - downsampling_counts_{gen_anc} (optional) - array of observed variant counts per - genetic ancestry group after downsampling. Used only when `gen_ancs` is specified. - - mu_snp - mutation rate - - possible_variants - the number of possible variants in the dataset for each - variant - - :param coverage_ht: Input coverage Table. + This function expects that the input Table (`ht`) contains only high quality + synonymous variants below 0.1% frequency. + + The following fields are expected in `ht`: + + - context: trinucleotide genomic context. + - ref: the reference allele. + - alt: the alternate allele. + - methylation_level: methylation level. + - mu_snp: mutation rate. + - cpg: whether the variant is a CpG. + - observed_variants: the number of observed variants in the dataset for each + variant. Note that the term "variant" here refers to a specific substitution, + context, methylation level, and coverage combination. + - possible_variants: the number of possible variants in the dataset for each + variant. + + :param ht: Input Table. :param coverage_expr: Expression that defines the coverage metric. :param weighted: Whether to weight the plateau models (a linear regression model) by 'possible_variants'. Default is False. - :param gen_ancs: List of genetic ancestry groups used to build plateau models. - Default is (). :param keys: Annotations used to group observed and possible variant counts. - Default is ("context", "ref", "alt", "methylation_level", "mu_snp"). - :param high_cov_definition: Lower median coverage cutoff. Sites with coverage above this cutoff - are considered well covered. Default is `COVERAGE_CUTOFF`. - :param upper_cov_cutoff: Upper median coverage cutoff. Sites with coverage above this cutoff - are excluded from the high coverage Table. Default is None. 
- :param skip_coverage_model: Whether to skip generating the coverage model. If set to True, - None is returned instead of the coverage model. Default is False. + Default is ("context", "ref", "alt", "methylation_level"). + :param model_group_expr: Expression with ``high_or_low_coverage`` annotation + to group variants into high or low coverage models. If not provided, the + ``calibration_model_group_expr`` function is used to define the grouping. + :param high_cov_definition: Lower coverage cutoff. Sites with coverage above this + cutoff are considered well covered. Default is ``COVERAGE_CUTOFF``. + :param upper_cov_cutoff: Upper coverage cutoff. Sites with coverage above this + cutoff are excluded from the high coverage Table. Default is None. + :param skip_coverage_model: Whether to skip generating the coverage model. If set + to True, None is returned instead of the coverage model. Default is False. :param log10_coverage: Whether to convert coverage sites with log10 when building the coverage model. Default is True. - :return: Coverage model and plateau models. + :param additional_grouping: Additional annotations to group by before + counting the observed and possible variants. Default is (). + :return: Tuple of (coverage model, plateau models). Coverage model is None + when ``skip_coverage_model`` is True. """ - # Annotate coverage_ht with coverage_expr set as a temporary annotation - # '_coverage_metric' before modifying the coverage_ht. - coverage_ht = coverage_ht.annotate(_coverage_metric=coverage_expr) + if model_group_expr is None: + # Define whether the variant should be included in the high or low coverage + # model. + model_group_expr = calibration_model_group_expr( + coverage_expr, + ht.cpg, + low_cov_cutoff=0, + high_cov_cutoff=high_cov_definition, + upper_cov_cutoff=upper_cov_cutoff, + skip_coverage_model=skip_coverage_model, + cpg_in_high_only=True, + ) - # Filter to sites with coverage_expr equal to or above `high_cov_definition`. 
- high_cov_ht = coverage_ht.filter( - coverage_ht._coverage_metric >= high_cov_definition - ) + grouping = keys + additional_grouping + mu_type_fields = ("cpg", "transition", "mutation_type", "mutation_type_model") + # all() accepts a generator expression directly (no list needed) and + # short-circuits on the first False. + has_mu_type = all(x in ht.row for x in mu_type_fields) + grouping += mu_type_fields if has_mu_type else () - # Filter to sites with coverage_expr equal to or below `upper_cov_cutoff` if - # specified. - if upper_cov_cutoff is not None: - high_cov_ht = high_cov_ht.filter( - high_cov_ht._coverage_metric <= upper_cov_cutoff + grouping_exprs = {"build_model": model_group_expr} + if not skip_coverage_model: + grouping_exprs["exomes_coverage"] = hl.or_missing( + model_group_expr.high_or_low_coverage == "low", coverage_expr ) - agg_expr = { - "observed_variants": hl.agg.sum(high_cov_ht.observed_variants), - "possible_variants": hl.agg.sum(high_cov_ht.possible_variants), - } - for gen_anc in gen_ancs: - agg_expr[f"observed_{gen_anc}"] = hl.agg.array_sum( - high_cov_ht[f"downsampling_counts_{gen_anc}"] + ht = ( + ht.group_by(*grouping, **grouping_exprs) + .aggregate( + mu_snp=hl.agg.take(ht.mu_snp, 1)[0], + **_build_sum_agg_struct( + fields_to_sum=["observed_variants", "possible_variants"], t=ht + ), ) + .key_by(*keys) + ) - # Generate a Table with all necessary annotations (x and y listed above) - # for the plateau models. - high_cov_group_ht = high_cov_ht.group_by(*keys).aggregate(**agg_expr) - high_cov_group_ht = annotate_mutation_type(high_cov_group_ht) + if not has_mu_type: + ht = annotate_mutation_type(ht) # Build plateau models. 
- plateau_models_agg_expr = build_plateau_models( - cpg_expr=high_cov_group_ht.cpg, - mu_snp_expr=high_cov_group_ht.mu_snp, - observed_variants_expr=high_cov_group_ht.observed_variants, - possible_variants_expr=high_cov_group_ht.possible_variants, - gen_ancs_observed_variants_array_expr=[ - high_cov_group_ht[f"observed_{gen_anc}"] for gen_anc in gen_ancs - ], - weighted=weighted, - ) - if gen_ancs: - # Map the models to their corresponding genetic ancestry groups if - # gen_ancs is specified. - _plateau_models = dict( - high_cov_group_ht.aggregate(hl.struct(**plateau_models_agg_expr)) - ) - gen_anc_models = _plateau_models["gen_anc"] - plateau_models = { - gen_anc: hl.literal(gen_anc_models[idx]) - for idx, gen_anc in enumerate(gen_ancs) - } - plateau_models["total"] = _plateau_models["total"] - plateau_models = hl.struct(**plateau_models) - else: - plateau_models = high_cov_group_ht.aggregate( - hl.struct(**plateau_models_agg_expr) + is_high_expr = ht.build_model.high_or_low_coverage == "high" + agg_expr = { + "plateau": hl.agg.filter( + is_high_expr, + build_plateau_models( + ht.mu_snp, + ht.observed_variants, + ht.possible_variants, + model_group_expr=ht.build_model.model_group, + weighted=weighted, + ), ) + } if not skip_coverage_model: - # Filter to sites with coverage below `high_cov_definition` and larger than 0. - low_cov_ht = coverage_ht.filter( - (coverage_ht._coverage_metric < high_cov_definition) - & (coverage_ht._coverage_metric > 0) - ) + # The coverage model is only built using the full dataset observed variants + # so use the first element if the observed_variants is an array. + obs_is_array = isinstance(ht.observed_variants, hl.expr.ArrayExpression) + obs_expr = ht.observed_variants[0] if obs_is_array else ht.observed_variants # Create a metric that represents the relative mutability of the exome calculated # on high coverage sites and will be used as scaling factor when building the # coverage model. 
- high_coverage_scale_factor = high_cov_ht.aggregate( - hl.agg.sum(high_cov_ht.observed_variants) - / hl.agg.sum(high_cov_ht.possible_variants * high_cov_ht.mu_snp) + autosome_or_par_expr = ( + ht.build_model.model_group.genomic_region == "autosome_or_par" + ) + agg_expr["high_coverage_scale_factor"] = hl.agg.filter( + is_high_expr & autosome_or_par_expr, + hl.agg.sum(obs_expr) / hl.agg.sum(ht.possible_variants * ht.mu_snp), ) - # Generate a Table with all necessary annotations (x and y listed above) - # for the coverage model. - if log10_coverage: - logger.info("Converting coverage sites by log10.") - cov_value = hl.log10(low_cov_ht._coverage_metric) - else: - cov_value = low_cov_ht._coverage_metric + # Get the observed variant count and mu_snp for low coverage sites. + agg_expr["coverage"] = hl.agg.filter( + (ht.build_model.high_or_low_coverage == "low") & autosome_or_par_expr, + hl.agg.group_by( + ht.exomes_coverage, + hl.struct( + obs=hl.agg.sum(obs_expr), + mu_snp=hl.agg.sum(ht.possible_variants * ht.mu_snp), + ), + ), + ) + + models = ht.aggregate(hl.struct(**agg_expr), _localize=False) - low_cov_group_ht = low_cov_ht.group_by(cov_value=cov_value).aggregate( - low_coverage_oe=hl.agg.sum(low_cov_ht.observed_variants) - / ( - high_coverage_scale_factor - * hl.agg.sum(low_cov_ht.possible_variants * low_cov_ht.mu_snp) + # Build coverage model. + coverage_model = None + if not skip_coverage_model: + coverage_model = models.coverage.map_values( + lambda x: x.annotate( + low_coverage_oe=x.obs / (models.high_coverage_scale_factor * x.mu_snp) ) ) - # Build the coverage model. # TODO: consider weighting here as well. 
- coverage_model_expr = build_coverage_model( - low_coverage_oe_expr=low_cov_group_ht.low_coverage_oe, - coverage_expr=low_cov_group_ht.cov_value, + coverage_model = ( + coverage_model.items() + .aggregate( + lambda x: build_coverage_model( + x[1].low_coverage_oe, x[0], log10_coverage=log10_coverage + ) + ) + .beta ) - coverage_model = tuple(low_cov_group_ht.aggregate(coverage_model_expr).beta) - else: - coverage_model = None - return coverage_model, plateau_models + return coverage_model, models.plateau def build_plateau_models( - cpg_expr: hl.expr.BooleanExpression, mu_snp_expr: hl.expr.Float64Expression, - observed_variants_expr: hl.expr.Int64Expression, - possible_variants_expr: hl.expr.Int64Expression, - gen_ancs_observed_variants_array_expr: List[hl.expr.ArrayExpression] = [], + observed_variants_expr: Union[hl.expr.ArrayExpression, hl.expr.Int64Expression], + possible_variants_expr: Union[hl.expr.ArrayExpression, hl.expr.Int64Expression], weighted: bool = False, -) -> Dict[str, Union[Dict[bool, hl.expr.ArrayExpression], hl.ArrayExpression]]: + cpg_expr: Optional[hl.expr.BooleanExpression] = None, + model_group_expr: Optional[hl.expr.StructExpression] = None, +) -> Union[hl.expr.DictExpression, hl.expr.ArrayExpression, hl.expr.StructExpression]: + """ + Build plateau models to calibrate mutation rate against proportion observed. + + Fits a linear regression of ``observed_variants_expr / possible_variants_expr`` + on ``mu_snp_expr``. When either observed or possible expressions are arrays + (e.g., one model per downsampling), the regression is applied element-wise + via ``hl.agg.array_agg``. + + When ``model_group_expr`` or ``cpg_expr`` is provided, the result is a + ``DictExpression`` keyed by the grouping struct. Otherwise the result is the + regression beta directly. + + :param mu_snp_expr: Mutation rate expression. + :param observed_variants_expr: Observed variant counts (scalar or array). 
+ :param possible_variants_expr: Possible variant counts (scalar or array). + :param weighted: If True, use weighted least squares with + ``possible_variants_expr`` as weights. Default is False. + :param cpg_expr: Boolean expression indicating CpG sites. When provided, + adds a ``cpg`` field to the grouping struct. + :param model_group_expr: Struct expression to group by in the aggregation. + :return: Regression betas, optionally grouped by ``model_group_expr`` + (and/or ``cpg_expr``). """ - Build plateau models to calibrate mutation rate to compute predicted proportion observed value. + obs_is_array = isinstance(observed_variants_expr, hl.expr.ArrayExpression) + pos_is_array = isinstance(possible_variants_expr, hl.expr.ArrayExpression) - The x and y of the plateau models: - - x: `mu_snp_expr` - - y: `observed_variants_expr` / `possible_variants_expr` - or `gen_ancs_observed_variants_array_expr`[index] / `possible_variants_expr` - if `gen_ancs` is specified + def _linreg( + o: hl.expr.NumericExpression, + p: hl.expr.NumericExpression, + ) -> hl.expr.StructExpression: + """ + Run linear regression of observed/possible on mutation rate. - :param cpg_expr: BooleanExpression noting whether a site is a CPG site. - :param mu_snp_expr: Float64Expression of the mutation rate. - :param observed_variants_expr: Int64Expression of the observed variant counts. - :param possible_variants_expr: Int64Expression of the possible variant counts. - :param gen_ancs_observed_variants_array_expr: Nested ArrayExpression with all observed - variant counts ArrayNumericExpressions for specified genetic ancestry groups. e.g., `[[1,1, - 1],[1,1,1]]`. Default is None. - :param weighted: Whether to generalize the model to weighted least squares using - 'possible_variants'. Default is False. - :return: A dictionary of intercepts and slopes of plateau models. The keys are - 'total' (for all sites) and 'gen_anc' (optional; for genetic ancestry groups). 
The values for - 'total' is a dictionary (e.g., >>), and the value for 'gen_anc' is a nested list of dictionaries (e. - g., >>>>). The - key of the dictionary in the nested list is CpG status (BooleanExpression), and - the value is an ArrayExpression containing intercept and slope values. - """ - # Build plateau models for all sites - plateau_models_agg_expr = { - "total": hl.agg.group_by( - cpg_expr, - hl.agg.linreg( - observed_variants_expr / possible_variants_expr, - [1, mu_snp_expr], - weight=possible_variants_expr if weighted else None, - ).beta, + :param o: Observed variant count expression. + :param p: Possible variant count expression. + :return: Regression beta coefficients. + """ + return hl.agg.linreg( + o / p, [1, mu_snp_expr], weight=p if weighted else None + ).beta + + if obs_is_array and pos_is_array: + agg_expr = hl.agg.array_agg( + lambda x: _linreg(*x), + hl.zip(observed_variants_expr, possible_variants_expr), ) - } - if gen_ancs_observed_variants_array_expr: - # Build plateau models using sites in genetic ancestry group downsamplings if - # genetic ancestry group is specified. 
- plateau_models_agg_expr["gen_anc"] = hl.agg.array_agg( - lambda gen_anc_obs_var_array_expr: hl.agg.array_agg( - lambda gen_anc_observed_variants: hl.agg.group_by( - cpg_expr, - hl.agg.linreg( - gen_anc_observed_variants / possible_variants_expr, - [1, mu_snp_expr], - weight=possible_variants_expr, - ).beta, - ), - gen_anc_obs_var_array_expr, - ), - gen_ancs_observed_variants_array_expr, + elif obs_is_array: + agg_expr = hl.agg.array_agg( + lambda x: _linreg(x, possible_variants_expr), observed_variants_expr + ) + elif pos_is_array: + agg_expr = hl.agg.array_agg( + lambda x: _linreg(observed_variants_expr, x), possible_variants_expr ) - return plateau_models_agg_expr + else: + agg_expr = _linreg(observed_variants_expr, possible_variants_expr) + + if model_group_expr is None and cpg_expr is None: + return agg_expr + + model_group_expr = model_group_expr or hl.struct() + if cpg_expr is not None: + model_group_expr = model_group_expr.annotate(cpg=cpg_expr) + + return hl.agg.group_by(model_group_expr, agg_expr) def build_coverage_model( low_coverage_oe_expr: hl.expr.Float64Expression, coverage_expr: hl.expr.Float64Expression, + log10_coverage: bool = False, ) -> hl.expr.StructExpression: """ Build coverage model. @@ -1117,14 +1750,23 @@ def build_coverage_model( proportion of expected variation at low coverage sites. The x and y of the coverage model: - - x: `coverage_expr` - - y: `low_coverage_oe_expr` + + - x: `coverage_expr` + - y: `low_coverage_oe_expr` :param low_coverage_oe_expr: The Float64Expression of observed:expected ratio for a given coverage level. :param coverage_expr: The Float64Expression of the coverage expression. + :param log10_coverage: Whether to convert coverage sites by log10 when building the + coverage model. Default is False. :return: StructExpression with intercept and slope of the model. """ + if log10_coverage: + logger.info( + "Converting coverage sites by log10 when building the coverage model." 
+ ) + coverage_expr = hl.log10(coverage_expr) + return hl.agg.linreg(low_coverage_oe_expr, [1, coverage_expr]) @@ -1176,55 +1818,62 @@ def get_all_gen_anc_lengths( def get_constraint_grouping_expr( - vep_annotation_expr: hl.StructExpression, - coverage_expr: Optional[hl.Int32Expression] = None, + vep_annotation_expr: hl.expr.StructExpression, + coverage_expr: Optional[hl.expr.Int32Expression] = None, include_transcript_group: bool = True, include_canonical_group: bool = True, include_mane_select_group: bool = False, -) -> Dict[str, Union[hl.StringExpression, hl.Int32Expression, hl.BooleanExpression]]: +) -> Dict[ + str, + Union[hl.expr.StringExpression, hl.expr.Int32Expression, hl.expr.BooleanExpression], +]: """ Collect annotations used for constraint groupings. Function collects the following annotations: - - annotation - 'most_severe_consequence' annotation in `vep_annotation_expr` - - modifier - classic lof annotation from 'lof' annotation in - `vep_annotation_expr`, LOFTEE annotation from 'lof' annotation in - `vep_annotation_expr`, PolyPhen annotation from 'polyphen_prediction' in - `vep_annotation_expr`, or "None" if neither is defined - - gene - 'gene_symbol' annotation inside `vep_annotation_expr` - - coverage - exome coverage if `coverage_expr` is specified - - transcript - id from 'transcript_id' in `vep_annotation_expr` (added when - `include_transcript_group` is True) - - canonical from `vep_annotation_expr` (added when `include_canonical_group` is - True) - - mane_select from `vep_annotation_expr` (added when `include_mane_select_group` is - True) + + - annotation - most_severe_consequence from ``vep_annotation_expr`` + - modifier - first non-missing of lof or polyphen_prediction + from ``vep_annotation_expr``, or the literal "None" + - gene - gene_symbol from ``vep_annotation_expr`` + - gene_id - gene_id from ``vep_annotation_expr`` + - coverage - exome coverage if ``coverage_expr`` is specified + - transcript - transcript_id from 
``vep_annotation_expr`` (added when + ``include_transcript_group`` is True) + - canonical - from ``vep_annotation_expr`` (added when + ``include_canonical_group`` is True) + - mane_select - from ``vep_annotation_expr`` (added when + ``include_mane_select_group`` is True) .. note:: + This function expects that the following fields are present in - `vep_annotation_expr`: - - lof - - polyphen_prediction - - most_severe_consequence - - gene_symbol - - transcript_id (if `include_transcript_group` is True) - - canonical (if `include_canonical_group` is True) - - mane_select (if `include_mane_select_group` is True) + ``vep_annotation_expr``: + + - lof + - most_severe_consequence + - gene_symbol + - gene_id + - polyphen_prediction (optional; missing used if absent) + - transcript_id (if ``include_transcript_group`` is True) + - canonical (if ``include_canonical_group`` is True) + - mane_select (if ``include_mane_select_group`` is True) :param vep_annotation_expr: StructExpression of VEP annotation. - :param coverage_expr: Optional Int32Expression of exome coverage. Default is None. + :param coverage_expr: Int32Expression of exome coverage. Default is None. :param include_transcript_group: Whether to include the transcript annotation in the groupings. Default is True. :param include_canonical_group: Whether to include canonical annotation in the groupings. Default is True. :param include_mane_select_group: Whether to include mane_select annotation in the groupings. Default is False. - - :return: A dictionary with keys as annotation names and values as actual - annotations. + :return: Dict mapping annotation names to Hail expressions. """ lof_expr = vep_annotation_expr.lof - polyphen_prediction_expr = vep_annotation_expr.polyphen_prediction + if "polyphen_prediction" in vep_annotation_expr: + polyphen_prediction_expr = vep_annotation_expr.polyphen_prediction + else: + polyphen_prediction_expr = hl.missing(hl.tstr) # Create constraint annotations to be used for groupings. 
groupings = { @@ -1242,57 +1891,60 @@ def get_constraint_grouping_expr( if include_canonical_group: groupings["canonical"] = hl.or_else(vep_annotation_expr.canonical == 1, False) if include_mane_select_group: - groupings["mane_select"] = hl.or_else( - hl.is_defined(vep_annotation_expr.mane_select), False - ) + groupings["mane_select"] = hl.is_defined(vep_annotation_expr.mane_select) return groupings def annotate_exploded_vep_for_constraint_groupings( ht: hl.Table, - coverage_expr: hl.expr.Int32Expression, + coverage_expr: Optional[hl.expr.Int32Expression] = None, vep_annotation: str = "transcript_consequences", include_canonical_group: bool = True, include_mane_select_group: bool = False, -) -> Tuple[Union[hl.Table, hl.MatrixTable], Tuple[str]]: - """ - Annotate Table with annotations used for constraint groupings. - - Function explodes the specified VEP annotation (`vep_annotation`) and adds the following annotations: - - annotation -'most_severe_consequence' annotation in `vep_annotation` - - modifier - classic lof annotation from 'lof' annotation in - `vep_annotation`, LOFTEE annotation from 'lof' annotation in - `vep_annotation`, PolyPhen annotation from 'polyphen_prediction' in - `vep_annotation`, or "None" if neither is defined - - gene - 'gene_symbol' annotation inside `vep_annotation` - - transcript - id from 'transcript_id' in `vep_annotation` (added when - `include_transcript_group` is True) - - canonical from `vep_annotation` (added when `include_canonical_group` is - True) - - mane_select from `vep_annotation` (added when `include_mane_select_group` is - True) +) -> Tuple[hl.Table, Tuple[str, ...]]: + """ + Explode a VEP annotation and add constraint grouping fields. 
+ + Explodes the specified VEP annotation (``vep_annotation``) and adds the + following annotations via ``get_constraint_grouping_expr``: + + - annotation - most_severe_consequence from ``vep_annotation`` + - modifier - first non-missing of lof or polyphen_prediction + from ``vep_annotation``, or the literal "None" + - gene - gene_symbol from ``vep_annotation`` + - gene_id - gene_id from ``vep_annotation`` + - coverage - exome coverage if ``coverage_expr`` is specified + - transcript - transcript_id from ``vep_annotation`` (added when + ``vep_annotation`` is "transcript_consequences") + - canonical - from ``vep_annotation`` (added when + ``include_canonical_group`` is True) + - mane_select - from ``vep_annotation`` (added when + ``include_mane_select_group`` is True) .. note:: - This function expects that the following annotations are present in `ht`: - - vep - - exome_coverage - :param ht: Input Table or MatrixTable. - :param coverage_expr: Expression that defines the coverage metric. - :param vep_annotation: Name of annotation in 'vep' annotation (one of - "transcript_consequences" and "worst_csq_by_gene") that will be used for - obtaining constraint annotations. Default is "transcript_consequences". - :param include_canonical_group: Whether to include 'canonical' annotation in the - groupings. Default is True. Ignored unless `vep_annotation` is "transcript_consequences". - :param include_mane_select_group: Whether to include 'mane_select' annotation in the - groupings. Default is False. Ignored unless `vep_annotation` is "transcript_consequences". - :return: A tuple of input Table or MatrixTable with grouping annotations added and - the names of added annotations. + This function expects that a ``vep`` annotation is present in ``ht``. + + :param ht: Input Table. + :param coverage_expr: Expression that defines the coverage metric. Default + is None. 
+ :param vep_annotation: Name of annotation in vep (one of + ``"transcript_consequences"`` and ``"worst_csq_by_gene"``) that will be + used for obtaining constraint annotations. Default is + ``"transcript_consequences"``. + :param include_canonical_group: Whether to include canonical annotation + in the groupings. Default is True. Ignored unless ``vep_annotation`` is + ``"transcript_consequences"``. + :param include_mane_select_group: Whether to include mane_select + annotation in the groupings. Default is False. Ignored unless + ``vep_annotation`` is ``"transcript_consequences"``. + :return: Tuple of (annotated Table, names of added grouping fields). """ # Annotate ht with coverage_expr set as a temporary annotation '_coverage_metric' # before modifying ht. - ht = ht.annotate(_coverage_metric=coverage_expr) + if coverage_expr is not None: + ht = ht.annotate(_coverage_metric=coverage_expr) if vep_annotation == "transcript_consequences": if not include_canonical_group and not include_mane_select_group: @@ -1321,15 +1973,16 @@ def annotate_exploded_vep_for_constraint_groupings( # Collect the annotations used for groupings. groupings = get_constraint_grouping_expr( ht[vep_annotation], - coverage_expr=ht._coverage_metric, + coverage_expr=None if coverage_expr is None else ht._coverage_metric, include_transcript_group=include_transcript_group, include_canonical_group=include_canonical_group, include_mane_select_group=include_mane_select_group, ) - return ht.annotate(**groupings), tuple(groupings.keys()) + return ht.transmute(**groupings), tuple(groupings.keys()) +# TODO: Not totally sure this is needed anymore... def compute_expected_variants( ht: hl.Table, plateau_models_expr: hl.StructExpression, @@ -1391,6 +2044,7 @@ def compute_expected_variants( return agg_expr +# TODO: Can probably be modified some given my other changes. 
def apply_plateau_models(
    mu_expr: hl.Float64Expression,
    plateau_models_expr: hl.ArrayExpression,
) -> Union[hl.ArrayExpression, hl.Float64Expression]:
    """
    Compute the predicted proportion observed for a mutation rate.

    The prediction is a linear adjustment of the mutation rate by a plateau
    model: ``intercept + slope * mu``.

    :param mu_expr: Mutation rate expression.
    :param plateau_models_expr: Either a single plateau model — an array whose
        first element is the intercept and second element is the slope — or an
        array of such plateau models.
    :return: Predicted proportion observed expression (an array of predictions
        when an array of models is supplied).
    """

    def _predict(model: hl.ArrayExpression) -> hl.Float64Expression:
        """Apply one ``[intercept, slope]`` plateau model to ``mu_expr``."""
        return mu_expr * model[1] + model[0]

    # An array of models has element type array<float64>; a single model's
    # element type is float64.
    if plateau_models_expr.dtype.element_type == hl.tarray(hl.tfloat64):
        return plateau_models_expr.map(_predict)

    return _predict(plateau_models_expr)


def coverage_correction_expr(
    coverage_expr: hl.Float64Expression,
    coverage_model: Tuple[float, float],
    low_coverage_expr: Optional[hl.BooleanExpression] = None,
    coverage_cutoff: Optional[int] = None,
    log10_coverage: bool = False,
) -> hl.Float64Expression:
    """
    Compute the coverage correction expression.

    .. note::

        One and only one of `low_coverage_expr` or `coverage_cutoff` must be
        specified.

    The correction is:

        - 0 when the coverage is 0.
        - the coverage model applied to the (optionally log10-transformed)
          coverage when the site is low coverage, i.e. ``low_coverage_expr``
          is True or the coverage is below ``coverage_cutoff``.
        - 1 otherwise.

    :param coverage_expr: Float64Expression of the coverage.
    :param coverage_model: Tuple of the intercept and slope of the coverage
        model.
    :param low_coverage_expr: Optional BooleanExpression indicating whether the
        site is a low coverage site where the coverage model should be applied.
        Default is None.
    :param coverage_cutoff: Optional coverage cutoff. If specified, the
        coverage model is applied to sites with coverage below this cutoff.
        Default is None.
    :param log10_coverage: Whether to log10-transform the coverage before
        applying the coverage model. Default is False.
    :return: Float64Expression of the coverage correction.
    """
    if low_coverage_expr is None and coverage_cutoff is None:
        raise ValueError(
            "Either 'low_coverage_expr' or 'coverage_cutoff' must be specified!"
        )
    if low_coverage_expr is not None and coverage_cutoff is not None:
        raise ValueError(
            "Only one of 'low_coverage_expr' or 'coverage_cutoff' can be specified!"
        )

    # Exactly one of the two options is set, so derive the low-coverage
    # indicator from the cutoff when no explicit expression was supplied.
    if low_coverage_expr is None:
        low_coverage_expr = coverage_expr < coverage_cutoff

    intercept, slope = coverage_model[0], coverage_model[1]
    model_input_expr = hl.log10(coverage_expr) if log10_coverage else coverage_expr

    return (
        hl.case()
        .when(coverage_expr == 0, 0)
        .when(low_coverage_expr, slope * model_input_expr + intercept)
        .default(1)
    )


def apply_models(
    mu_expr: hl.expr.Float64Expression,
    plateau_models_expr: hl.expr.ArrayExpression,
    possible_variants_expr: hl.expr.Int64Expression,
    coverage_model: Optional[Tuple[float, float]] = None,
    coverage_expr: Optional[hl.expr.Int32Expression] = None,
    cpg_expr: Optional[hl.expr.BooleanExpression] = None,
    model_group_expr: Optional[hl.expr.StructExpression] = None,
    high_cov_definition: int = COVERAGE_CUTOFF,
    log10_coverage: bool = True,
) -> hl.expr.StructExpression:
    """
    Apply calibration models to compute expected variant counts.

    Applies plateau and (optionally) coverage models to produce a struct with
    ``mu``, ``predicted_proportion_observed``, ``expected_variants``, and
    (when a coverage model is provided) ``coverage_correction``.

    :param mu_expr: Mutation rate expression.
    :param plateau_models_expr: Single plateau model (array of [intercept,
        slope]) or an array of plateau models.
    :param possible_variants_expr: Possible variant counts to multiply the
        predicted proportion observed by.
    :param coverage_model: Tuple of (intercept, slope) of the coverage model.
        Default is None.
    :param coverage_expr: Int32Expression of the coverage. Required when
        ``coverage_model`` is provided or ``model_group_expr`` is None.
    :param cpg_expr: BooleanExpression indicating CpG sites. Required when
        ``model_group_expr`` is None.
    :param model_group_expr: Expression with ``high_or_low_coverage``
        annotation to group variants into high or low coverage models. If not
        provided, ``calibration_model_group_expr`` is used. Default is None.
    :param high_cov_definition: Coverage threshold for high/low classification.
        Default is ``COVERAGE_CUTOFF``.
    :param log10_coverage: Whether to log10-transform coverage when applying
        the coverage model. Default is True.
    :return: StructExpression with "mu", "predicted_proportion_observed",
        "expected_variants", and optionally "coverage_correction".
    """
    if coverage_model is not None and coverage_expr is None:
        raise ValueError(
            "If 'coverage_model' is specified, 'coverage_expr' must also be specified!"
        )
    if model_group_expr is None and (coverage_expr is None or cpg_expr is None):
        raise ValueError(
            "If 'model_group_expr' is not specified, 'coverage_expr' and 'cpg_expr' must"
            " be specified!"
        )

    # Derive the annotations needed to apply the calibration models when no
    # pre-computed grouping was supplied.
    if model_group_expr is None:
        model_group_expr = calibration_model_group_expr(
            coverage_expr,
            cpg_expr,
            high_cov_cutoff=high_cov_definition,
            skip_coverage_model=coverage_model is None,
        )

    # Plateau models turn the mutation rate into a predicted proportion
    # observed; scaling by possible variant counts gives expected counts.
    predicted_expr = apply_plateau_models(mu_expr, plateau_models_expr)
    result_expr = hl.struct(
        mu=mu_expr * possible_variants_expr,
        predicted_proportion_observed=predicted_expr,
        expected_variants=predicted_expr * possible_variants_expr,
    )

    # When a coverage model is supplied, scale mu and expected counts by the
    # coverage correction for low-coverage sites.
    if coverage_model is not None:
        correction_expr = coverage_correction_expr(
            coverage_expr,
            coverage_model,
            low_coverage_expr=model_group_expr.high_or_low_coverage == "low",
            log10_coverage=log10_coverage,
        )
        result_expr = result_expr.annotate(
            mu=result_expr.mu * correction_expr,
            expected_variants=result_expr.expected_variants * correction_expr,
            coverage_correction=correction_expr,
        )

    return result_expr


def aggregate_constraint_metrics_expr(
    t: Union[hl.Table, hl.StructExpression],
    fields_to_sum: Union[List[str], Tuple[str, ...]] = DEFAULT_FIELDS_TO_SUM,
    additional_exprs_to_sum: Optional[Dict[str, hl.expr.Expression]] = None,
) -> hl.expr.StructExpression:
    """
    Get an aggregation expression for the sum of expected variants and other fields.

    An aggregate sum or array sum is created for each field in
    ``fields_to_sum``.

    :param t: Input Table or StructExpression.
    :param fields_to_sum: Fields in ``t`` to get an aggregate sum expression
        for. Default is ``DEFAULT_FIELDS_TO_SUM``.
    :param additional_exprs_to_sum: Dictionary of additional expressions to get
        an aggregate sum expression for. Field names are the keys and
        expressions are the values. Default is None.
    :return: StructExpression with the sum of expected variants and other
        fields.
    """
    # Delegate to the shared helper; copy fields_to_sum so tuples are accepted.
    return _build_sum_agg_struct(
        t=t,
        fields_to_sum=list(fields_to_sum),
        exprs_to_sum=additional_exprs_to_sum,
    )
- - For a given pair of observed (`obs_expr`) and expected (`exp_expr`) values, the - function computes the density of the Poisson distribution (performed using Hail's - `dpois` module) with fixed k (`x` in `dpois` is set to the observed number of - variants) over a range of lambda (`lamb` in `dpois`) values, which are given by the - expected number of variants times a varying parameter ranging between 0 and 2 (the - observed:expected ratio is typically between 0 and 1, so we want to extend the - upper bound of the confidence interval to capture this). The cumulative density - function of the Poisson distribution density is computed and the value of the - varying parameter is extracted at points corresponding to `alpha` (defaults to 5%) - and 1-`alpha` (defaults to 95%) to indicate the lower and upper bounds of the - confidence interval. - - The following annotations are in the output StructExpression: - - lower - the lower bound of confidence interval - - upper - the upper bound of confidence interval - - :param obs_expr: Expression for the observed variant counts of pLoF, missense, or - synonymous variants in `ht`. - :param exp_expr: Expression for the expected variant counts of pLoF, missense, or - synonymous variants in `ht`. - :param alpha: The significance level used to compute the confidence interval. - Default is 0.05. - :return: StructExpression for the confidence interval lower and upper bounds. + Compute OE confidence interval via discretized Poisson CDF. + + Sweeps the OE ratio parameter over [0, 2) in steps of 0.001, evaluates + ``dpois(obs, exp * x)`` at each point, normalises the cumulative sum, and + reads off the bounds at ``alpha`` and ``1 - alpha``. + + :param obs_expr: Observed variant count expression. + :param exp_expr: Expected variant count expression. + :param alpha: Significance level. Default is 0.05. + :return: Struct with ``lower`` and ``upper`` bounds. """ # Set up range between 0 and 2. 
def _oe_ci_gamma(
    obs_expr: hl.expr.Int64Expression,
    exp_expr: hl.expr.Float64Expression,
    alpha: float = 0.05,
) -> hl.expr.StructExpression:
    """
    Compute OE confidence interval using the Gamma distribution.

    Models the observed count as Poisson with a flat prior, so the posterior
    of the rate is Gamma with shape ``obs + 1``; a scale of ``1 / exp``
    converts the rate into an observed/expected ratio. Quantiles are read off
    with Hail's ``hl.qgamma``.

    :param obs_expr: Observed variant count expression.
    :param exp_expr: Expected variant count expression.
    :param alpha: Significance level. Default is 0.05.
    :raises RuntimeError: If the installed Hail version does not provide
        ``hl.qgamma``.
    :return: Struct with ``lower`` and ``upper`` bounds.
    """
    # Check availability with getattr instead of try/except around a plain
    # attribute access, and fail fast with an actionable message.
    qgamma = getattr(hl, "qgamma", None)
    if qgamma is None:
        raise RuntimeError(
            "_oe_ci_gamma requires hl.qgamma, available in Hail >= 0.2.137. "
            "Use method='poisson' or upgrade Hail."
        )
    shape = obs_expr + hl.literal(1.0)
    # divide_null propagates missingness when the expected count is missing.
    scale = divide_null(hl.literal(1.0), exp_expr)
    return hl.struct(
        lower=qgamma(hl.literal(alpha), shape, scale),
        upper=qgamma(hl.literal(1.0 - alpha), shape, scale),
    )


def oe_confidence_interval(
    obs_expr: hl.expr.Int64Expression,
    exp_expr: hl.expr.Float64Expression,
    alpha: float = 0.05,
    method: str = "gamma",
) -> hl.expr.StructExpression:
    """
    Compute a confidence interval around the observed/expected ratio.

    Two methods are available:

    - ``"gamma"`` (default): uses ``hl.qgamma`` to compute exact quantiles of
      the Gamma posterior. Fast and precise.
    - ``"poisson"``: sweeps a discretized Poisson likelihood over the OE
      parameter space [0, 2). Retained for backwards compatibility.

    :param obs_expr: Observed variant count expression.
    :param exp_expr: Expected variant count expression.
    :param alpha: Significance level for the confidence interval. Default is
        0.05 (90% CI).
    :param method: CI method — ``"gamma"`` or ``"poisson"``. Default is
        ``"gamma"``.
    :raises ValueError: If ``method`` is not ``"gamma"`` or ``"poisson"``.
    :return: Struct with ``lower`` and ``upper`` bounds.
    """
    if method == "gamma":
        return _oe_ci_gamma(obs_expr, exp_expr, alpha)
    if method == "poisson":
        return _oe_ci_discretized_poisson(obs_expr, exp_expr, alpha)
    raise ValueError(f"Unknown CI method: {method!r}. Use 'gamma' or 'poisson'.")


def calculate_raw_z_score(
    obs_expr: hl.expr.Int64Expression,
    exp_expr: hl.expr.Float64Expression,
) -> hl.expr.Float64Expression:
    """
    Compute the signed raw z-score using observed and expected variant counts.

    The raw z-score is the square root of the chi-squared statistic
    ``(obs - exp)^2 / exp``, negated when more variants are observed than
    expected (an excess of variation) and positive otherwise (a deficit,
    i.e. constraint).

    :param obs_expr: Observed variant count expression.
    :param exp_expr: Expected variant count expression.
    :return: Raw z-score expression.
    """
    # divide_null propagates missingness when the expected count is missing.
    chisq_expr = divide_null((obs_expr - exp_expr) ** 2, exp_expr)
    return hl.sqrt(chisq_expr) * hl.if_else(obs_expr > exp_expr, -1, 1)
List[str] = [ - "transcript_id_version", - "gene_id_version", - "level", - "transcript_type", - "start_position", - "end_position", - ], + annotations: Union[List[str], Tuple[str, ...]] = DEFAULT_GENCODE_ANNOTATIONS, remove_y_par: bool = True, ) -> hl.Table: """ Add GENCODE annotations to Table based on transcript id. - .. note:: - - Added annotations by default are: - - level - - transcript_type - - start_position (start of the transcript) - - end_position (end of the transcript) + In addition to the annotations specified by ``annotations``, the following + computed annotations are always added: - Computed annotations are: - chromosome - cds_length - num_coding_exons :param ht: Input Table. :param gencode_ht: Table with GENCODE annotations. - :param annotations: List of GENCODE annotations to add. Default is - ["transcript_id_version", "gene_id_version", "level", "transcript_type", - "start_position", "end_position"]. + :param annotations: GENCODE annotations to add. Default is + ``DEFAULT_GENCODE_ANNOTATIONS``. :param remove_y_par: Whether to remove features for the Y chromosome PAR regions. Default is True because the Y chromosome PAR regions are typically not included in the constraint calculations and both chrX and chrY will have the same 'transcript_id' field for these regions. This parameter can only be True if - `gencode_ht` includes a 'transcript_id_version' field because Y_PAR is included - in the version of the transcript, which has been stripped from the + ``gencode_ht`` includes a 'transcript_id_version' field because Y_PAR is + included in the version of the transcript, which has been stripped from the 'transcript_id' field. :return: Table with transcript annotations from GENCODE added. 
""" + annotations = list(annotations) + if remove_y_par and "transcript_id_version" not in gencode_ht.row: raise ValueError( "remove_y_par is True but 'transcript_id_version' is not in gencode_ht" @@ -1758,6 +2652,7 @@ def add_gencode_transcript_annotations( if remove_y_par: gencode_ht = gencode_ht.filter( ~gencode_ht.transcript_id_version.endswith("Y_PAR") + & ~gencode_ht.transcript_id_version.endswith("PAR_Y") ) gencode_ht = gencode_ht.annotate( @@ -1805,3 +2700,286 @@ def add_gencode_transcript_annotations( ht = ht.annotate(**gencode_transcript_ht[ht.transcript]) return ht + + +def rank_and_assign_bins( + value_expr: hl.expr.Float64Expression, + bin_granularities: Optional[Dict[str, int]] = None, + prefix: str = "", +) -> hl.StructExpression: + """Rank rows by a numeric expression and assign bin labels. + + **Rank-based binning**: every row receives a unique position in the sorted + order, and bins are derived from that position. This differs from + threshold-based binning (see :func:`annotate_bins_by_threshold`), where + bins are assigned by comparing values against pre-computed boundary values. + + Rows are ordered ascending by ``value_expr``. Each row is assigned a + 0-based ``{prefix}rank`` and a ``{prefix}bin_{name}`` field for every + entry in ``bin_granularities``, computed as + ``hl.int(rank * multiplier / n_rows)``. + + Used by :func:`rank_array_element_metrics` to rank metrics within array + elements. + + :param value_expr: Numeric expression to rank by (ascending). + :param bin_granularities: Mapping of bin name to multiplier. Each entry + produces a ``{prefix}bin_{name}`` field. Default is + ``{"percentile": 100, "decile": 10, "sextile": 6}``. + :param prefix: String prepended to ``rank`` and ``bin_{name}`` field + names. Default is ``""`` (no prefix). + :return: Struct with ``{prefix}rank`` and ``{prefix}bin_{name}`` fields + for each entry in ``bin_granularities``. 
+ """ + if bin_granularities is None: + bin_granularities = {"percentile": 100, "decile": 10, "sextile": 6} + + ht = value_expr._indices.source + source_key = list(ht.key) + n_rows = ht.count() + ranked_ht = ht.select(_=value_expr).order_by("_").add_index("rank") + ranked_ht = ranked_ht.select( + *source_key, + **{f"{prefix}rank": ranked_ht.rank}, + **{ + f"{prefix}bin_{name}": hl.int(ranked_ht.rank * multiplier / n_rows) + for name, multiplier in bin_granularities.items() + }, + ).cache() + + return ranked_ht.key_by(*source_key)[ht.key] + + +def compute_percentile_thresholds( + ht: hl.Table, + metric_expr: hl.expr.Float64Expression, + outlier_expr: Optional[hl.expr.BooleanExpression] = None, + transcript_filter_expr: Optional[hl.expr.BooleanExpression] = None, + percentiles: Tuple[float, ...] = (1, 5, 10, 15, 25, 50, 75), + quantile_k: int = 1000, +) -> Dict[float, float]: + """Compute approximate percentile thresholds for a metric expression. + + **Threshold-based binning, step 1**: computes the boundary values that + define bin edges. The returned dict is passed to + :func:`annotate_bins_by_threshold` (step 2) to assign each row to a bin. + + This two-step approach differs from rank-based binning (see + :func:`rank_and_assign_bins`), where every row receives a unique position + in the sorted order. Threshold-based binning allows thresholds to be + computed on a filtered subset (e.g., representative transcripts) and then + applied to all rows, so multiple rows can share the same bin. + + Optionally filters to a subset of rows and excludes outliers, then + computes approximate quantile thresholds at the requested percentiles in + a single aggregation pass. + + .. note:: + + Uses ``hl.agg.approx_quantiles``, so results are approximate. Increase + ``quantile_k`` for higher accuracy. + + :param ht: Input Table. + :param metric_expr: Float expression to compute thresholds for. Must be + defined on ``ht``. 
+ :param outlier_expr: Optional boolean expression that is ``True`` for rows + to exclude. When None (default), no outlier filtering is applied. + :param transcript_filter_expr: Optional boolean expression that is ``True`` + for rows to include. When None (default), all rows are included. + :param percentiles: Percentile values (0-100) at which to compute + thresholds. Default is (1, 5, 10, 15, 25, 50, 75). + :param quantile_k: Accuracy parameter for + :func:`hail.expr.aggregators.approx_quantiles`. Default is 1000. + :return: Dict mapping each percentile to its threshold value. + """ + qs = [p / 100.0 for p in percentiles] + filter_expr = hl.is_defined(metric_expr) + if outlier_expr is not None: + filter_expr = filter_expr & ~outlier_expr + if transcript_filter_expr is not None: + filter_expr = filter_expr & transcript_filter_expr + + result = ht.aggregate( + hl.struct( + thresholds=hl.agg.filter( + filter_expr, hl.agg.approx_quantiles(metric_expr, qs, k=quantile_k) + ), + n=hl.agg.count_where(filter_expr), + ) + ) + logger.info( + "Computed percentile thresholds on %d transcripts.", + result.n, + ) + + return dict(zip(percentiles, result.thresholds)) + + +def annotate_bins_by_threshold( + ht: hl.Table, + metric_exprs: Dict[str, hl.expr.Float64Expression], + thresholds: Dict[Tuple[str, str], List[float]], + granularities: Union[List[str], Tuple[str, ...]], + field_name: str = "constraint_bins", +) -> hl.Table: + """ + Annotate rows with bin assignments using pre-computed thresholds. + + **Threshold-based binning, step 2**: assigns each row to a bin by + comparing its metric value against boundary values produced by + :func:`compute_percentile_thresholds` (step 1). For each + ``(granularity, metric)`` pair, the bin equals the number of threshold + boundaries the value exceeds. Bin 0 is the most constrained (below all + thresholds); bin N means the value exceeds all N boundaries. 
+ + This differs from rank-based binning (see :func:`rank_and_assign_bins`), + where every row gets a unique position. Here, multiple rows can share a + bin, and the thresholds may have been derived from a different subset of + rows than those being annotated. + + :param ht: Input Table. + :param metric_exprs: Mapping of metric name to the Float64Expression to + bin (e.g. ``{"lof": ht.lof_oe_upper, "mis": ht.mis_oe_upper}``). + :param thresholds: Mapping of ``(granularity, metric)`` to an ordered list + of threshold values, as produced by + :func:`compute_percentile_thresholds`. + :param granularities: Granularity names to iterate over (e.g. + ``["decile", "ventile"]``). Each must appear as the first element of + at least one key in ``thresholds``. + :param field_name: Name of the struct field to annotate on ``ht``. + Default is ``"constraint_bins"``. + :return: Table with an added struct field containing per-granularity, + per-metric bin assignments. + """ + miss = hl.missing(hl.tint32) + + def _bin_expr( + value_expr: hl.expr.Float64Expression, + threshold_list: List[float], + ) -> hl.expr.Int32Expression: + """Count how many thresholds a value exceeds. + + :param value_expr: Metric value to bin. + :param threshold_list: Ordered list of boundary values. + :return: Number of boundaries exceeded (0 = below all thresholds). 
+ """ + arr = hl.literal(threshold_list) + return hl.sum(arr.map(lambda t: hl.int(value_expr >= t))) + + granularities_expr = { + gran: hl.struct( + **{ + metric: hl.if_else( + hl.is_defined(metric_exprs[metric]), + _bin_expr(metric_exprs[metric], thresholds[(gran, metric)]), + miss, + ) + for metric in metric_exprs + } + ) + for gran in granularities + } + + return ht.annotate(**{field_name: hl.struct(**granularities_expr)}) + + +def rank_array_element_metrics( + ht: hl.Table, + array_field: str, + element_value_fn: Callable[ + [hl.expr.StructExpression], Dict[str, hl.expr.Float64Expression] + ], + filter_fn: Optional[Callable[[hl.Table], hl.expr.BooleanExpression]] = None, + bin_granularities: Optional[Dict[str, int]] = None, + rank_field_prefix: str = "", +) -> hl.Table: + """ + Rank metrics within array elements and annotate rank structs back. + + **Rank-based binning for array fields**: applies + :func:`rank_and_assign_bins` independently to each element of an array + field. For each element, ``element_value_fn`` extracts named metric + values, which are ranked across rows (optionally on a filtered subset + via ``filter_fn``). Each array element is then annotated with + ``{metric_name}_rank`` structs containing rank and bin fields. + + Rows not matching ``filter_fn`` get missing rank annotations. + + This differs from threshold-based binning (see + :func:`compute_percentile_thresholds` and + :func:`annotate_bins_by_threshold`), where bins are assigned by + comparing values against pre-computed boundaries rather than sorted + position. + + :param ht: Input Table. + :param array_field: Name of the array field on ``ht``. + :param element_value_fn: Function that takes an array element + (StructExpression) and returns a dict mapping metric names to + Float64Expressions to rank. Applied identically to every element. + :param filter_fn: Optional function that takes a Table and returns a + BooleanExpression to filter rows before ranking. 
When None, all + rows are ranked. + :param bin_granularities: Bin granularities passed to + :func:`rank_and_assign_bins`. + :param rank_field_prefix: Prefix for the ``rank`` and ``bin_{name}`` + sub-fields within each rank struct. Passed through to + :func:`rank_and_assign_bins`. Default is ``""`` (no prefix). + :return: Table with ``{metric_name}_rank`` structs added to each array + element. The table is returned with its original key restored. + """ + original_key = list(ht.key) + ht = ht.add_index("_rank_idx").key_by("_rank_idx").cache() + + subset_ht = ht.filter(filter_fn(ht)) if filter_fn is not None else ht + + # Extract values to rank using element_value_fn applied via .map(). + subset_ht = subset_ht.select( + _rank_values=subset_ht[array_field].map( + lambda elem: hl.struct(**element_value_fn(elem)) + ) + ).cache() + + # Determine element count and metric names from a sample row. + sample = subset_ht.take(1)[0]._rank_values + n_elements = len(sample) + metric_names = list(sample[0]) + + # Rank each metric within each array element. + subset_ht = subset_ht.annotate( + _rank_values=[ + hl.struct( + **{ + name: rank_and_assign_bins( + subset_ht._rank_values[i][name], + bin_granularities, + prefix=rank_field_prefix, + ) + for name in metric_names + } + ) + for i in range(n_elements) + ] + ).cache() + + # Join ranks back to the original table. Use or_missing so that + # unranked rows get correctly-typed missing rank annotations without + # needing a manually-constructed missing struct for if_else. 
+ ht = ht.annotate(_ranks=subset_ht[ht._rank_idx]._rank_values) + ht = ht.annotate( + **{ + array_field: [ + ht[array_field][i].annotate( + **{ + f"{name}_rank": hl.or_missing( + hl.is_defined(ht._ranks), + ht._ranks[i][name], + ) + for name in metric_names + } + ) + for i in range(n_elements) + ] + } + ).drop("_ranks") + + return ht.key_by(*original_key).drop("_rank_idx") diff --git a/gnomad/utils/file_utils.py b/gnomad/utils/file_utils.py index e44833cc4..5a9653855 100644 --- a/gnomad/utils/file_utils.py +++ b/gnomad/utils/file_utils.py @@ -290,3 +290,59 @@ def create_vds( combiner.run() vds = hl.vds.read_vds(output_path) return vds + + +def print_global_struct(t: Union[hl.Table, hl.Struct, hl.StructExpression]) -> None: + """ + Pretty-print a Hail global struct with nested indentation. + + Accepts a Table (uses its globals), a StructExpression (evaluates it), + or an already-evaluated Struct. + + :param t: Table, StructExpression, or Struct to print. + """ + if isinstance(t, hl.Table): + t = t.globals + if isinstance(t, hl.StructExpression): + t = hl.eval(t) + + def _format_struct(s: hl.Struct, level: int = 1) -> str: + indent = " " * level + lines = [] + for k, v in s.items(): + if isinstance(v, hl.Struct): + v = f"\n{_format_struct(v, level + 1)}" + lines.append(f"{indent}{k}: {v}") + return "\n".join(lines) + + logger.info("\nGlobal struct:\n%s", _format_struct(t)) + + +def convert_multi_array_to_array_of_structs( + t: Union[hl.Table, hl.expr.StructExpression], + array_fields_to_combine: List[str], + new_array_field: str, +) -> Union[hl.Table, hl.expr.StructExpression]: + """ + Zip parallel array fields into a single array of structs. + + For example, given fields ``a = [1, 2]`` and ``b = [3, 4]``, produces + ``ab = [{a: 1, b: 3}, {a: 2, b: 4}]``. The original array fields are + dropped. + + .. note:: + + All arrays in ``array_fields_to_combine`` must have the same length. + + :param t: Table or StructExpression containing the array fields. 
def mane_select_over_canonical_filter_expr(
    transcript_expr: hl.expr.StringExpression,
    mane_select_expr: hl.expr.BooleanExpression,
    canonical_expr: hl.expr.BooleanExpression,
    gene_id_expr: hl.expr.StringExpression,
) -> hl.expr.BooleanExpression:
    """
    Return a boolean expression selecting MANE Select transcripts with canonical fallback.

    For each gene, selects the MANE Select transcript if one exists; otherwise
    falls back to the canonical transcript. Only Ensembl transcripts (ENST
    prefix) are considered.

    .. note::

        In VEP 105 (used for gnomAD v4), all MANE Select transcripts are also
        annotated as canonical. As a result, this function produces the same set
        of transcripts as a simple canonical filter for v4 data.

    :param transcript_expr: Transcript ID expression (e.g., ``ht.transcript``).
    :param mane_select_expr: Boolean expression indicating MANE Select status.
    :param canonical_expr: Boolean expression indicating canonical status.
    :param gene_id_expr: Gene ID expression used to group transcripts per gene.
    :return: Boolean expression that is ``True`` for the selected transcripts.
    """
    source_ht = transcript_expr._indices.source

    # Summarize, per gene, whether any MANE Select or canonical transcript
    # exists; 'only_canonical' marks genes that must use the fallback.
    per_gene_ht = source_ht.group_by(gene_id=gene_id_expr).aggregate(
        mane_present=hl.agg.any(mane_select_expr),
        canonical_present=hl.agg.any(canonical_expr),
    )
    per_gene_ht = per_gene_ht.annotate(
        only_canonical=~per_gene_ht.mane_present & per_gene_ht.canonical_present
    )
    gene_info = per_gene_ht[gene_id_expr]

    is_ensembl = transcript_expr.startswith("ENST")
    keep_mane = gene_info.mane_present & mane_select_expr
    keep_canonical_fallback = gene_info.only_canonical & canonical_expr

    return is_ensembl & (keep_mane | keep_canonical_fallback)


class TestOeConfidenceInterval:
    """Test the oe_confidence_interval function."""

    @staticmethod
    def _single_row(obs: int, exp: float) -> hl.Table:
        """Build a one-row Table with int64 ``obs`` and float64 ``exp``."""
        return hl.Table.parallelize(
            [{"obs": obs, "exp": exp}],
            hl.tstruct(obs=hl.tint64, exp=hl.tfloat64),
        )

    @pytest.mark.skipif(
        not hasattr(hl, "qgamma"), reason="hl.qgamma requires Hail >= 0.2.137"
    )
    def test_gamma_returns_lower_and_upper(self):
        """Test that gamma method returns a struct with lower < upper."""
        ht = self._single_row(10, 20.0)
        ht = ht.annotate(ci=oe_confidence_interval(ht.obs, ht.exp, method="gamma"))
        row = ht.collect()[0]

        assert row.ci.lower < row.ci.upper

    def test_poisson_returns_lower_and_upper(self):
        """Test that poisson method returns a struct with lower < upper."""
        ht = self._single_row(10, 20.0)
        ht = ht.annotate(ci=oe_confidence_interval(ht.obs, ht.exp, method="poisson"))
        row = ht.collect()[0]

        assert row.ci.lower < row.ci.upper

    def test_poisson_obs_zero_lower_is_zero(self):
        """Test that poisson method returns lower=0 when obs=0."""
        ht = self._single_row(0, 10.0)
        ht = ht.annotate(ci=oe_confidence_interval(ht.obs, ht.exp, method="poisson"))
        row = ht.collect()[0]

        assert row.ci.lower == 0

    def test_invalid_method_raises(self):
        """Test that an invalid method raises ValueError."""
        ht = self._single_row(5, 10.0)
        with pytest.raises(ValueError, match="Unknown CI method"):
            ht.annotate(ci=oe_confidence_interval(ht.obs, ht.exp, method="invalid"))

    @pytest.mark.skipif(
        not hasattr(hl, "qgamma"), reason="hl.qgamma requires Hail >= 0.2.137"
    )
    def test_gamma_and_poisson_give_similar_results(self):
        """Test that gamma and poisson methods give roughly similar CIs."""
        ht = self._single_row(15, 20.0)
        ht = ht.annotate(
            ci_gamma=oe_confidence_interval(ht.obs, ht.exp, method="gamma"),
            ci_poisson=oe_confidence_interval(ht.obs, ht.exp, method="poisson"),
        )
        row = ht.collect()[0]

        # Both should be in a similar ballpark (within 0.2 of each other).
        assert abs(row.ci_gamma.lower - row.ci_poisson.lower) < 0.2
        assert abs(row.ci_gamma.upper - row.ci_poisson.upper) < 0.2
+ assert abs(result.ci_gamma.lower - result.ci_poisson.lower) < 0.2 + assert abs(result.ci_gamma.upper - result.ci_poisson.upper) < 0.2 + + +class TestCalculateRawZScore: + """Test the calculate_raw_z_score function.""" + + def test_obs_less_than_exp_positive_z(self): + """Test that fewer observed than expected gives a positive z-score.""" + ht = hl.Table.parallelize( + [{"obs": 5, "exp": 20.0}], + hl.tstruct(obs=hl.tint64, exp=hl.tfloat64), + ) + ht = ht.annotate(z=calculate_raw_z_score(ht.obs, ht.exp)) + result = ht.collect()[0] + + assert result.z > 0 + + def test_obs_greater_than_exp_negative_z(self): + """Test that more observed than expected gives a negative z-score.""" + ht = hl.Table.parallelize( + [{"obs": 30, "exp": 10.0}], + hl.tstruct(obs=hl.tint64, exp=hl.tfloat64), + ) + ht = ht.annotate(z=calculate_raw_z_score(ht.obs, ht.exp)) + result = ht.collect()[0] + + assert result.z < 0 + + def test_obs_equals_exp_zero_z(self): + """Test that obs == exp gives z-score of 0.""" + ht = hl.Table.parallelize( + [{"obs": 10, "exp": 10.0}], + hl.tstruct(obs=hl.tint64, exp=hl.tfloat64), + ) + ht = ht.annotate(z=calculate_raw_z_score(ht.obs, ht.exp)) + result = ht.collect()[0] + + assert result.z == 0.0 + + +class TestCalculateGerpCutoffs: + """Test the calculate_gerp_cutoffs function.""" + + def test_cutoffs_within_expected_range(self): + """Test that GERP cutoffs fall within the data range.""" + ht = hl.Table.parallelize( + [{"gerp": float(i)} for i in range(100)], + hl.tstruct(gerp=hl.tfloat64), + ) + lower, upper = calculate_gerp_cutoffs(ht) + + # 5th percentile should be around 5, 95th around 95. 
+ assert 0 <= lower <= 10 + assert 90 <= upper <= 99 + + def test_custom_percentiles(self): + """Test with custom percentile thresholds.""" + ht = hl.Table.parallelize( + [{"gerp": float(i)} for i in range(100)], + hl.tstruct(gerp=hl.tfloat64), + ) + lower, upper = calculate_gerp_cutoffs( + ht, lower_percentile=0.25, upper_percentile=0.75 + ) + + assert 20 <= lower <= 30 + assert 70 <= upper <= 80 + + def test_custom_gerp_expr(self): + """Test with a custom GERP expression.""" + ht = hl.Table.parallelize( + [{"my_gerp": float(i)} for i in range(100)], + hl.tstruct(my_gerp=hl.tfloat64), + ) + lower, upper = calculate_gerp_cutoffs(ht, gerp_expr=ht.my_gerp) + + assert 0 <= lower <= 10 + assert 90 <= upper <= 99 + + +class TestBuildConstraintConsequenceGroups: + """Test the build_constraint_consequence_groups function.""" + + @pytest.fixture + def sample_table(self): + """Fixture to create a table with consequence and modifier fields.""" + return hl.Table.parallelize( + [ + {"csq": "synonymous_variant", "modifier": "NA"}, + {"csq": "missense_variant", "modifier": "NA"}, + {"csq": "stop_gained", "modifier": "HC"}, + {"csq": "splice_donor_variant", "modifier": "LC"}, + {"csq": "missense_variant", "modifier": "NA"}, + ], + hl.tstruct(csq=hl.tstr, modifier=hl.tstr), + ) + + def test_meta_contains_expected_keys(self, sample_table): + """Test that returned meta dicts contain expected keys.""" + _, meta = build_constraint_consequence_groups( + sample_table.csq, sample_table.modifier + ) + + meta_keys = [set(m.keys()) for m in meta] + assert {"csq_set"} in meta_keys + assert {"lof"} in meta_keys + + def test_default_groups_count(self, sample_table): + """Test that the default number of groups is 5 (syn, mis, classic, hc_lc, hc).""" + filters, meta = build_constraint_consequence_groups( + sample_table.csq, sample_table.modifier + ) + + assert len(meta) == 5 + assert len(filters) == 5 + + def test_syn_filter_selects_correct_rows(self, sample_table): + """Test that the synonymous 
filter selects only synonymous variants.""" + filters, meta = build_constraint_consequence_groups( + sample_table.csq, sample_table.modifier + ) + + # Find the syn filter. + syn_idx = next(i for i, m in enumerate(meta) if m.get("csq_set") == "syn") + ht = sample_table.annotate(is_syn=filters[syn_idx]) + results = ht.collect() + + assert results[0].is_syn is True # synonymous_variant + assert results[1].is_syn is False # missense_variant + assert results[2].is_syn is False # stop_gained + + def test_hc_lof_filter_selects_correct_rows(self, sample_table): + """Test that the HC LoF filter selects only HC modifier rows.""" + filters, meta = build_constraint_consequence_groups( + sample_table.csq, sample_table.modifier + ) + + hc_idx = next(i for i, m in enumerate(meta) if m.get("lof") == "hc") + ht = sample_table.annotate(is_hc=filters[hc_idx]) + results = ht.collect() + + assert results[2].is_hc is True # stop_gained, HC + assert results[3].is_hc is False # splice_donor_variant, LC + + def test_meta_values(self, sample_table): + """Test that meta contains the expected value combinations.""" + _, meta = build_constraint_consequence_groups( + sample_table.csq, sample_table.modifier + ) + + meta_values = [tuple(sorted(m.items())) for m in meta] + assert (("csq_set", "syn"),) in meta_values + assert (("csq_set", "mis"),) in meta_values + assert (("lof", "classic"),) in meta_values + assert (("lof", "hc_lc"),) in meta_values + assert (("lof", "hc"),) in meta_values + + +class TestRankAndAssignBins: + """Test the rank_and_assign_bins function.""" + + @pytest.fixture + def sample_table(self): + """Fixture to create a keyed table with float values.""" + return hl.Table.parallelize( + [{"id": i, "val": float(i)} for i in range(10)], + hl.tstruct(id=hl.tint32, val=hl.tfloat64), + key="id", + ) + + def test_rank_ordering_ascending(self, sample_table): + """Test that ranks are assigned in ascending order of value.""" + result = 
sample_table.annotate(bins=rank_and_assign_bins(sample_table.val)) + rows = result.order_by("id").collect() + + # Values 0..9 should get ranks 0..9. + for i, r in enumerate(rows): + assert r.bins.rank == i + + def test_default_bin_fields(self, sample_table): + """Test that default bin granularities produce expected fields.""" + result = sample_table.annotate(bins=rank_and_assign_bins(sample_table.val)) + row = result.collect()[0] + + assert hasattr(row.bins, "bin_percentile") + assert hasattr(row.bins, "bin_decile") + assert hasattr(row.bins, "bin_sextile") + + def test_custom_bin_granularities(self, sample_table): + """Test with custom bin granularities.""" + result = sample_table.annotate( + bins=rank_and_assign_bins( + sample_table.val, bin_granularities={"quintile": 5} + ) + ) + rows = result.order_by("id").collect() + + # 10 rows, quintile bins: 0,0, 1,1, 2,2, 3,3, 4,4 + bins = [r.bins.bin_quintile for r in rows] + assert bins == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4] + + def test_decile_bins_range(self, sample_table): + """Test that decile bins are in the range [0, 9].""" + result = sample_table.annotate(bins=rank_and_assign_bins(sample_table.val)) + rows = result.collect() + deciles = [r.bins.bin_decile for r in rows] + + assert min(deciles) == 0 + assert max(deciles) == 9 + + def test_single_row(self): + """Test ranking a single-row table.""" + ht = hl.Table.parallelize( + [{"id": 1, "val": 5.0}], + hl.tstruct(id=hl.tint32, val=hl.tfloat64), + key="id", + ) + result = ht.annotate(bins=rank_and_assign_bins(ht.val)) + row = result.collect()[0] + + assert row.bins.rank == 0 + assert row.bins.bin_decile == 0 + assert row.bins.bin_percentile == 0 + + def test_tied_values(self): + """Test that tied values all receive the same bin.""" + ht = hl.Table.parallelize( + [{"id": i, "val": 1.0} for i in range(10)], + hl.tstruct(id=hl.tint32, val=hl.tfloat64), + key="id", + ) + result = ht.annotate( + bins=rank_and_assign_bins(ht.val, bin_granularities={"half": 2}) + ) + rows = 
result.collect() + bins = {r.bins.bin_half for r in rows} + + # All same value — ranks are arbitrary but bins should be contiguous. + assert bins.issubset({0, 1}) + + def test_descending_input_still_ranks_ascending(self): + """Test that input order does not affect ascending rank assignment.""" + ht = hl.Table.parallelize( + [{"id": i, "val": float(9 - i)} for i in range(10)], + hl.tstruct(id=hl.tint32, val=hl.tfloat64), + key="id", + ) + result = ht.annotate(bins=rank_and_assign_bins(ht.val)) + rows = result.order_by("id").collect() + + # id=0 has val=9 (largest) -> rank 9; id=9 has val=0 (smallest) -> rank 0. + assert rows[0].bins.rank == 9 + assert rows[9].bins.rank == 0 + + def test_negative_values(self): + """Test ranking with negative values.""" + ht = hl.Table.parallelize( + [{"id": i, "val": float(i - 5)} for i in range(10)], + hl.tstruct(id=hl.tint32, val=hl.tfloat64), + key="id", + ) + result = ht.annotate(bins=rank_and_assign_bins(ht.val)) + rows = result.order_by("id").collect() + + # val ranges from -5 to 4; id=0 (val=-5) should get rank 0. 
+ assert rows[0].bins.rank == 0 + assert rows[9].bins.rank == 9 + + +class TestComputeOeUpperPercentileThresholds: + """Test the compute_percentile_thresholds function.""" + + @pytest.fixture + def sample_table(self): + """Fixture to create a table with known metric values.""" + return hl.Table.parallelize( + [{"metric": float(i) / 100.0} for i in range(101)], + hl.tstruct(metric=hl.tfloat64), + ) + + def test_returns_dict_with_correct_keys(self, sample_table): + """Test that the returned dict has keys matching requested percentiles.""" + result = compute_percentile_thresholds( + sample_table, sample_table.metric, percentiles=(10, 50, 90) + ) + + assert set(result.keys()) == {10, 50, 90} + + def test_thresholds_in_ascending_order(self, sample_table): + """Test that thresholds are in ascending order for ascending percentiles.""" + result = compute_percentile_thresholds( + sample_table, sample_table.metric, percentiles=(10, 25, 50, 75) + ) + + values = [result[p] for p in (10, 25, 50, 75)] + assert values == sorted(values) + + def test_thresholds_within_data_range(self, sample_table): + """Test that thresholds fall within the data range.""" + result = compute_percentile_thresholds( + sample_table, sample_table.metric, percentiles=(10, 50, 90) + ) + + for v in result.values(): + assert 0.0 <= v <= 1.0 + + def test_outlier_filtering(self): + """Test that outlier rows are excluded from threshold computation.""" + ht = hl.Table.parallelize( + [{"metric": float(i) / 10.0, "is_outlier": i > 8} for i in range(10)], + hl.tstruct(metric=hl.tfloat64, is_outlier=hl.tbool), + ) + + result = compute_percentile_thresholds( + ht, ht.metric, outlier_expr=ht.is_outlier, percentiles=(50,) + ) + + # With outliers excluded, median should be around 0.4 (values 0..8 / 10). 
+ assert result[50] < 0.5 + + def test_transcript_filter(self): + """Test that transcript_filter_expr restricts which rows are used.""" + ht = hl.Table.parallelize( + [ + {"metric": 0.1, "include": True}, + {"metric": 0.2, "include": True}, + {"metric": 0.9, "include": False}, + {"metric": 1.0, "include": False}, + ], + hl.tstruct(metric=hl.tfloat64, include=hl.tbool), + ) + + result = compute_percentile_thresholds( + ht, + ht.metric, + transcript_filter_expr=ht.include, + percentiles=(50,), + ) + + # Only 0.1 and 0.2 included, so median should be ~0.15. + assert result[50] < 0.3 + + def test_default_percentiles(self): + """Test that default percentiles (1,5,10,15,25,50,75) are used.""" + ht = hl.Table.parallelize( + [{"metric": float(i) / 100.0} for i in range(101)], + hl.tstruct(metric=hl.tfloat64), + ) + + result = compute_percentile_thresholds(ht, ht.metric) + assert set(result.keys()) == {1, 5, 10, 15, 25, 50, 75} + + def test_missing_metric_excluded(self): + """Test that rows with missing metric values are excluded.""" + ht = hl.Table.parallelize( + [ + {"metric": 0.1}, + {"metric": 0.2}, + {"metric": None}, + {"metric": None}, + ], + hl.tstruct(metric=hl.tfloat64), + ) + + result = compute_percentile_thresholds(ht, ht.metric, percentiles=(50,)) + + # Only 0.1 and 0.2 are non-missing, median ~ 0.15. + assert result[50] < 0.3 + + def test_no_filters_includes_all_rows(self): + """Test that omitting outlier and transcript filters includes all rows.""" + ht = hl.Table.parallelize( + [{"metric": float(i)} for i in range(10)], + hl.tstruct(metric=hl.tfloat64), + ) + + result_all = compute_percentile_thresholds(ht, ht.metric, percentiles=(50,)) + result_filtered = compute_percentile_thresholds( + ht, + ht.metric, + outlier_expr=hl.literal(False), + transcript_filter_expr=hl.literal(True), + percentiles=(50,), + ) + + # Should be the same when filters are effectively no-ops. 
+ assert abs(result_all[50] - result_filtered[50]) < 0.01 + + def test_combined_outlier_and_transcript_filter(self): + """Test that outlier and transcript filters are combined correctly.""" + ht = hl.Table.parallelize( + [ + {"metric": 0.1, "include": True, "is_outlier": False}, + {"metric": 0.5, "include": True, "is_outlier": True}, + {"metric": 0.9, "include": False, "is_outlier": False}, + ], + hl.tstruct(metric=hl.tfloat64, include=hl.tbool, is_outlier=hl.tbool), + ) + + result = compute_percentile_thresholds( + ht, + ht.metric, + outlier_expr=ht.is_outlier, + transcript_filter_expr=ht.include, + percentiles=(50,), + ) + + # Only first row passes both filters, so median is ~0.1. + assert abs(result[50] - 0.1) < 0.05 + + +class TestRankVsThresholdBinning: + """Compare rank_and_assign_bins and compute_percentile_thresholds. + + rank_and_assign_bins assigns bins by exact row rank (deterministic for + distinct values, arbitrary tie-breaking for duplicates). + compute_percentile_thresholds computes approximate quantile + thresholds — binning by value comparison puts all tied rows in the same bin. + + These should agree when all values are distinct but can diverge when many + values are tied. + """ + + def test_distinct_values_bins_agree(self): + """Test that both methods agree when all values are distinct.""" + n = 100 + ht = hl.Table.parallelize( + [{"id": i, "val": float(i) / n} for i in range(n)], + hl.tstruct(id=hl.tint32, val=hl.tfloat64), + key="id", + ) + + # Rank-based decile bins. + ranked = ht.annotate( + bins=rank_and_assign_bins(ht.val, bin_granularities={"decile": 10}) + ) + rank_rows = ranked.order_by("id").collect() + rank_bins = {r.id: r.bins.bin_decile for r in rank_rows} + + # Threshold-based decile bins (boundaries at 10th, 20th, ..., 90th). 
+ thresholds = compute_percentile_thresholds( + ht, ht.val, percentiles=tuple(range(10, 100, 10)) + ) + threshold_list = [thresholds[p] for p in sorted(thresholds)] + + rows = ht.order_by("id").collect() + for r in rows: + rank_bin = rank_bins[r.id] + # Threshold bin: count how many thresholds the value exceeds. + thresh_bin = sum(1 for t in threshold_list if r.val >= t) + assert ( + rank_bin == thresh_bin + ), f"id={r.id}, val={r.val}: rank_bin={rank_bin}, thresh_bin={thresh_bin}" + + def test_tied_values_threshold_bins_consistent(self): + """Test that threshold-based bins put all tied values in the same bin. + + rank_and_assign_bins may split tied values across bins because it + assigns unique ranks to each row regardless of ties. + """ + # 50 rows at val=0.1 and 50 rows at val=0.9. + ht = hl.Table.parallelize( + [{"id": i, "val": 0.1 if i < 50 else 0.9} for i in range(100)], + hl.tstruct(id=hl.tint32, val=hl.tfloat64), + key="id", + ) + + # Threshold-based: all 0.1 rows must be in the same bin, all 0.9 in another. + thresholds = compute_percentile_thresholds(ht, ht.val, percentiles=(50,)) + rows = ht.collect() + low_bins = set() + high_bins = set() + for r in rows: + bin_val = sum(1 for t in thresholds.values() if r.val >= t) + if r.val < 0.5: + low_bins.add(bin_val) + else: + high_bins.add(bin_val) + + # All 0.1 rows get the same bin, all 0.9 rows get the same bin. + assert len(low_bins) == 1 + assert len(high_bins) == 1 + assert low_bins != high_bins + + def test_tied_values_rank_bins_may_split(self): + """Test that rank_and_assign_bins splits tied values across bins. + + With 100 identical values and 10 decile bins, rank-based binning will + arbitrarily assign ~10 rows per bin even though the values are the same. 
+ """ + ht = hl.Table.parallelize( + [{"id": i, "val": 1.0} for i in range(100)], + hl.tstruct(id=hl.tint32, val=hl.tfloat64), + key="id", + ) + + ranked = ht.annotate( + bins=rank_and_assign_bins(ht.val, bin_granularities={"decile": 10}) + ) + rows = ranked.collect() + deciles = {r.bins.bin_decile for r in rows} + + # All values are identical, but rank bins still spread across deciles. + assert len(deciles) > 1 + + def test_mostly_tied_with_outlier_diverges(self): + """Test divergence when most values are tied with one outlier. + + Threshold-based: outlier alone in top bin, all tied values in bin 0. + Rank-based: tied values spread across multiple bins, outlier in top bin. + """ + ht = hl.Table.parallelize( + [{"id": i, "val": 0.5 if i < 99 else 10.0} for i in range(100)], + hl.tstruct(id=hl.tint32, val=hl.tfloat64), + key="id", + ) + + # Threshold at the 50th percentile. + thresholds = compute_percentile_thresholds(ht, ht.val, percentiles=(50,)) + # The 50th pct threshold should be ~0.5 (the tied value), so all 0.5 + # rows fall at or above it → threshold bin 1, and the outlier also bin 1. + # But with rank-based bins, about half the 0.5 rows are in bin 0 and + # half in bin 1. + ranked = ht.annotate( + bins=rank_and_assign_bins(ht.val, bin_granularities={"half": 2}) + ) + rank_rows = ranked.collect() + + # Rank-based: the 0.5-valued rows are split across bin 0 and bin 1. + tied_rank_bins = {r.bins.bin_half for r in rank_rows if r.val == 0.5} + assert len(tied_rank_bins) == 2, "Rank-based should split tied values" + + # Threshold-based: all 0.5-valued rows are in the same bin. 
+ tied_thresh_bins = set() + for r in rank_rows: + if r.val == 0.5: + tied_thresh_bins.add(sum(1 for t in thresholds.values() if r.val >= t)) + assert len(tied_thresh_bins) == 1, "Threshold-based should not split ties" + + +class TestSingleVariantCountExpr: + """Test the variant_observed_expr function.""" + + def test_ac_positive_counts_as_one(self): + """Test that a variant with AC > 0 counts as 1.""" + ht = hl.Table.parallelize( + [{"freq": hl.Struct(AC=5, AF=0.01)}], + hl.tstruct(freq=hl.tstruct(AC=hl.tint32, AF=hl.tfloat64)), + ) + ht = ht.annotate(count=variant_observed_expr(freq_expr=ht.freq)) + result = ht.collect()[0] + + assert result.count == 1 + + def test_ac_zero_counts_as_zero(self): + """Test that a variant with AC == 0 counts as 0.""" + ht = hl.Table.parallelize( + [{"freq": hl.Struct(AC=0, AF=0.0)}], + hl.tstruct(freq=hl.tstruct(AC=hl.tint32, AF=hl.tfloat64)), + ) + ht = ht.annotate(count=variant_observed_expr(freq_expr=ht.freq)) + result = ht.collect()[0] + + assert result.count == 0 + + def test_singleton_counts_only_ac_one(self): + """Test that singleton mode only counts variants with AC == 1.""" + ht = hl.Table.parallelize( + [ + {"freq": hl.Struct(AC=1, AF=0.001)}, + {"freq": hl.Struct(AC=5, AF=0.01)}, + ], + hl.tstruct(freq=hl.tstruct(AC=hl.tint32, AF=hl.tfloat64)), + ) + ht = ht.annotate(count=variant_observed_expr(freq_expr=ht.freq, singleton=True)) + results = ht.collect() + + assert results[0].count == 1 + assert results[1].count == 0 + + def test_max_af_filters_by_frequency(self): + """Test that max_af filters variants by allele frequency.""" + ht = hl.Table.parallelize( + [ + {"freq": hl.Struct(AC=5, AF=0.001)}, + {"freq": hl.Struct(AC=5, AF=0.1)}, + ], + hl.tstruct(freq=hl.tstruct(AC=hl.tint32, AF=hl.tfloat64)), + ) + ht = ht.annotate(count=variant_observed_expr(freq_expr=ht.freq, max_af=0.01)) + results = ht.collect() + + assert results[0].count == 1 + assert results[1].count == 0 + + def test_no_freq_counts_as_one(self): + """Test 
that when no freq_expr is provided and no filtering, count is 1.""" + ht = hl.Table.parallelize( + [{"x": 1}], + hl.tstruct(x=hl.tint32), + ) + ht = ht.annotate(count=variant_observed_expr(ht=ht)) + result = ht.collect()[0] + + assert result.count == 1 + + def test_raises_when_no_ht_or_freq(self): + """Test that ValueError is raised when neither ht nor freq_expr is given.""" + with pytest.raises(ValueError, match="Either ht or freq_expr"): + variant_observed_expr() + + def test_max_af_zero_filters_all(self): + """Test that max_af=0.0 filters out all variants (AF cannot be <= 0 with AC > 0).""" + ht = hl.Table.parallelize( + [ + {"freq": hl.Struct(AC=1, AF=0.001)}, + {"freq": hl.Struct(AC=5, AF=0.01)}, + ], + hl.tstruct(freq=hl.tstruct(AC=hl.tint32, AF=hl.tfloat64)), + ) + ht = ht.annotate(count=variant_observed_expr(freq_expr=ht.freq, max_af=0.0)) + results = ht.collect() + + assert results[0].count == 0 + assert results[1].count == 0 + + def test_ht_fallback_to_freq_field(self): + """Test that passing ht without freq_expr falls back to ht.freq.""" + ht = hl.Table.parallelize( + [ + {"freq": hl.Struct(AC=1, AF=0.001)}, + {"freq": hl.Struct(AC=0, AF=0.0)}, + ], + hl.tstruct(freq=hl.tstruct(AC=hl.tint32, AF=hl.tfloat64)), + ) + ht = ht.annotate(count=variant_observed_expr(ht=ht, max_af=0.01)) + results = ht.collect() + + assert results[0].count == 1 + assert results[1].count == 0 + + +class TestVariantObservedAndPossibleExpr: + """Test the variant_observed_and_possible_expr function.""" + + @pytest.fixture + def freq_table(self): + """Create a Table with a frequency array.""" + return hl.Table.parallelize( + [ + {"freq": [hl.Struct(AC=5, AF=0.001), hl.Struct(AC=3, AF=0.0005)]}, + {"freq": [hl.Struct(AC=0, AF=0.0), hl.Struct(AC=0, AF=0.0)]}, + ], + hl.tstruct(freq=hl.tarray(hl.tstruct(AC=hl.tint32, AF=hl.tfloat64))), + ) + + def test_observed_is_array(self, freq_table): + """Test that observed_variants is an array with one entry per freq element.""" + ht = 
freq_table.annotate(**variant_observed_and_possible_expr(freq_table.freq)) + result = ht.collect()[0] + assert len(result.observed_variants) == 2 + + def test_observed_variant_counted(self, freq_table): + """Test that a variant with AC > 0 has observed_variants of 1.""" + ht = freq_table.annotate(**variant_observed_and_possible_expr(freq_table.freq)) + result = ht.collect()[0] + assert result.observed_variants == [1, 1] + + def test_unobserved_variant(self, freq_table): + """Test that a variant with AC == 0 has observed_variants of 0.""" + ht = freq_table.annotate(**variant_observed_and_possible_expr(freq_table.freq)) + result = ht.collect()[1] + assert result.observed_variants == [0, 0] + + def test_possible_adj_is_scalar(self, freq_table): + """Test that possible_variants is a scalar when use_possible_adj is True.""" + ht = freq_table.annotate( + **variant_observed_and_possible_expr(freq_table.freq, use_possible_adj=True) + ) + result = ht.collect()[0] + assert isinstance(result.possible_variants, int) + + def test_possible_adj_counts_observed(self, freq_table): + """Test that possible_variants is 1 for an observed variant.""" + ht = freq_table.annotate(**variant_observed_and_possible_expr(freq_table.freq)) + result = ht.collect()[0] + assert result.possible_variants == 1 + + def test_possible_adj_counts_unobserved(self, freq_table): + """Test that possible_variants is 0 for an unobserved variant with defined freq.""" + ht = freq_table.annotate(**variant_observed_and_possible_expr(freq_table.freq)) + result = ht.collect()[1] + # AC == 0 with defined freq: not possible (count_missing only counts None). 
+ assert result.possible_variants == 0 + + def test_possible_no_adj_is_array(self, freq_table): + """Test that possible_variants is an array when use_possible_adj is False.""" + ht = freq_table.annotate( + **variant_observed_and_possible_expr( + freq_table.freq, use_possible_adj=False + ) + ) + result = ht.collect()[0] + assert isinstance(result.possible_variants, list) + assert len(result.possible_variants) == 2 + + def test_possible_counts_missing_freq(self): + """Test that possible_variants is 1 when frequency is missing.""" + ht = hl.Table.parallelize( + [{"freq": [None, hl.Struct(AC=3, AF=0.0005)]}], + hl.tstruct(freq=hl.tarray(hl.tstruct(AC=hl.tint32, AF=hl.tfloat64))), + ) + ht = ht.annotate( + **variant_observed_and_possible_expr(ht.freq, use_possible_adj=False) + ) + result = ht.collect()[0] + # Missing freq element should still count as possible. + assert result.possible_variants[0] == 1 + + def test_max_af_filters_observed(self): + """Test that max_af filters out variants with AF above threshold.""" + ht = hl.Table.parallelize( + [{"freq": [hl.Struct(AC=5, AF=0.01), hl.Struct(AC=3, AF=0.0005)]}], + hl.tstruct(freq=hl.tarray(hl.tstruct(AC=hl.tint32, AF=hl.tfloat64))), + ) + ht = ht.annotate(**variant_observed_and_possible_expr(ht.freq, max_af=0.001)) + result = ht.collect()[0] + # First element AF=0.01 > max_af, second AF=0.0005 <= max_af. 
+ assert result.observed_variants == [0, 1] + + +class TestGetCountsAggExpr: + """Test the counts_agg_expr function.""" + + @pytest.fixture + def sample_table(self): + """Fixture to create a table with frequency data.""" + return hl.Table.parallelize( + [ + {"freq": hl.Struct(AC=1, AF=0.001)}, + {"freq": hl.Struct(AC=5, AF=0.01)}, + {"freq": hl.Struct(AC=0, AF=0.0)}, + {"freq": hl.Struct(AC=3, AF=0.05)}, + ], + hl.tstruct(freq=hl.tstruct(AC=hl.tint32, AF=hl.tfloat64)), + ) + + def test_variant_count_no_filter(self, sample_table): + """Test variant count with no filtering (AC > 0).""" + result = sample_table.aggregate(counts_agg_expr(freq_expr=sample_table.freq)) + + # 3 variants have AC > 0. + assert result.variant_count == 3 + + def test_variant_count_with_max_af(self, sample_table): + """Test variant count with max_af filter.""" + result = sample_table.aggregate( + counts_agg_expr(freq_expr=sample_table.freq, max_af=0.01) + ) + + # AC=1/AF=0.001 and AC=5/AF=0.01 pass the filter. + assert result.variant_count == 2 + + def test_singleton_count(self, sample_table): + """Test that singleton count is returned when requested.""" + result = sample_table.aggregate( + counts_agg_expr(freq_expr=sample_table.freq, count_singletons=True) + ) + + assert result.singleton_count == 1 + assert result.variant_count == 3 + + def test_no_singleton_key_by_default(self, sample_table): + """Test that singleton_count is not present when not requested.""" + result = sample_table.aggregate(counts_agg_expr(freq_expr=sample_table.freq)) + + assert not hasattr(result, "singleton_count") + + def test_raises_when_no_ht_or_freq(self): + """Test that ValueError is raised when neither ht nor freq_expr is given.""" + with pytest.raises(ValueError, match="Either ht or freq_expr"): + counts_agg_expr(freq_expr=None, ht=None) + + def test_ht_only_no_freq_expr(self, sample_table): + """Test that passing ht without freq_expr falls back to ht.freq.""" + result = 
sample_table.aggregate(counts_agg_expr(ht=sample_table)) + + # Same as test_variant_count_no_filter: 3 variants have AC > 0. + assert result.variant_count == 3 + + def test_max_af_zero_counts_none(self, sample_table): + """Test that max_af=0.0 counts no variants.""" + result = sample_table.aggregate( + counts_agg_expr(freq_expr=sample_table.freq, max_af=0.0) + ) + + assert result.variant_count == 0 + + +class TestWeightedAggSumExpr: + """Test the weighted_sum_agg_expr function.""" + + def test_scalar_scalar(self): + """Test weighted sum with two scalar expressions.""" + ht = hl.Table.parallelize( + [{"val": 2.0, "weight": 3.0}, {"val": 4.0, "weight": 5.0}], + hl.tstruct(val=hl.tfloat64, weight=hl.tfloat64), + ) + result = ht.aggregate(weighted_sum_agg_expr(ht.val, ht.weight)) + + # 2*3 + 4*5 = 26 + assert result == 26.0 + + def test_array_array(self): + """Test weighted sum with two array expressions (pairwise multiply).""" + ht = hl.Table.parallelize( + [ + {"val": [1.0, 2.0], "weight": [3.0, 4.0]}, + {"val": [5.0, 6.0], "weight": [7.0, 8.0]}, + ], + hl.tstruct(val=hl.tarray(hl.tfloat64), weight=hl.tarray(hl.tfloat64)), + ) + result = ht.aggregate(weighted_sum_agg_expr(ht.val, ht.weight)) + + # element 0: 1*3 + 5*7 = 38, element 1: 2*4 + 6*8 = 56 + assert result == [38.0, 56.0] + + def test_scalar_array_mixed(self): + """Test weighted sum with scalar expr and array weight (broadcast).""" + ht = hl.Table.parallelize( + [ + {"val": 2.0, "weight": [1.0, 10.0]}, + {"val": 3.0, "weight": [1.0, 10.0]}, + ], + hl.tstruct(val=hl.tfloat64, weight=hl.tarray(hl.tfloat64)), + ) + result = ht.aggregate(weighted_sum_agg_expr(ht.val, ht.weight)) + + # element 0: 2*1 + 3*1 = 5, element 1: 2*10 + 3*10 = 50 + assert result == [5.0, 50.0] + + def test_array_scalar_mixed(self): + """Test weighted sum with array expr and scalar weight (broadcast).""" + ht = hl.Table.parallelize( + [ + {"val": [1.0, 2.0], "weight": 3.0}, + {"val": [4.0, 5.0], "weight": 6.0}, + ], + 
hl.tstruct(val=hl.tarray(hl.tfloat64), weight=hl.tfloat64), + ) + result = ht.aggregate(weighted_sum_agg_expr(ht.val, ht.weight)) + + # element 0: 1*3 + 4*6 = 27, element 1: 2*3 + 5*6 = 36 + assert result == [27.0, 36.0] + + def test_single_row(self): + """Test weighted sum with a single row.""" + ht = hl.Table.parallelize( + [{"val": 5.0, "weight": 2.0}], + hl.tstruct(val=hl.tfloat64, weight=hl.tfloat64), + ) + result = ht.aggregate(weighted_sum_agg_expr(ht.val, ht.weight)) + + assert result == 10.0 + + +class TestCountObservedAndPossibleByGroup: + """Test the count_observed_and_possible_by_group function.""" + + @pytest.fixture + def sample_table(self): + """Fixture to create a table with context, ref, alt, and count fields.""" + return hl.Table.parallelize( + [ + { + "context": "ACG", + "ref": "C", + "alt": "T", + "methylation_level": 0, + "obs": [1, 0], + "poss": 1, + }, + { + "context": "ACG", + "ref": "C", + "alt": "T", + "methylation_level": 0, + "obs": [1, 1], + "poss": 1, + }, + { + "context": "TCG", + "ref": "C", + "alt": "A", + "methylation_level": 1, + "obs": [0, 0], + "poss": 1, + }, + ], + hl.tstruct( + context=hl.tstr, + ref=hl.tstr, + alt=hl.tstr, + methylation_level=hl.tint32, + obs=hl.tarray(hl.tint32), + poss=hl.tint32, + ), + ) + + def test_basic_grouping(self, sample_table): + """Test that rows are grouped by context, ref, alt, methylation_level.""" + result = count_observed_and_possible_by_group( + sample_table, + possible_expr=sample_table.poss, + observed_expr=sample_table.obs, + ) + rows = result.collect() + + # Two groups: (ACG, C, T, 0) and (TCG, C, A, 1). 
+ assert len(rows) == 2 + + def test_observed_summed(self, sample_table): + """Test that observed arrays are summed within groups.""" + result = count_observed_and_possible_by_group( + sample_table, + possible_expr=sample_table.poss, + observed_expr=sample_table.obs, + ) + rows = {r.context: r for r in result.collect()} + + # ACG group: [1,0] + [1,1] = [2,1] + assert rows["ACG"].observed_variants == [2, 1] + # TCG group: [0,0] + assert rows["TCG"].observed_variants == [0, 0] + + def test_possible_summed(self, sample_table): + """Test that possible counts are summed within groups.""" + result = count_observed_and_possible_by_group( + sample_table, + possible_expr=sample_table.poss, + observed_expr=sample_table.obs, + ) + rows = {r.context: r for r in result.collect()} + + assert rows["ACG"].possible_variants == 2 + assert rows["TCG"].possible_variants == 1 + + def test_no_additional_grouping(self, sample_table): + """Test grouping without methylation_level.""" + result = count_observed_and_possible_by_group( + sample_table, + possible_expr=sample_table.poss, + observed_expr=sample_table.obs, + additional_grouping=(), + ) + rows = result.collect() + + # Without methylation_level, still two groups because context/ref/alt differ. 
+ assert len(rows) == 2 + + def test_weight_exprs_dict(self, sample_table): + """Test that weight_exprs produces a weighted sum field.""" + sample_table = sample_table.annotate(mu=0.5) + result = count_observed_and_possible_by_group( + sample_table, + possible_expr=sample_table.poss, + observed_expr=sample_table.obs, + weight_exprs={"weighted_poss": sample_table.mu}, + ) + rows = {r.context: r for r in result.collect()} + + # ACG: poss=1*0.5 + poss=1*0.5 = 1.0 + assert abs(rows["ACG"].weighted_poss - 1.0) < 1e-6 + + def test_additional_agg_sum_exprs(self, sample_table): + """Test that additional_agg_sum_exprs sums extra fields.""" + sample_table = sample_table.annotate(extra=1) + result = count_observed_and_possible_by_group( + sample_table, + possible_expr=sample_table.poss, + observed_expr=sample_table.obs, + additional_agg_sum_exprs={"extra": sample_table.extra}, + ) + rows = {r.context: r for r in result.collect()} + + assert rows["ACG"].extra == 2 + assert rows["TCG"].extra == 1 + + +class TestGetConstraintGroupingExpr: + """Test the get_constraint_grouping_expr function.""" + + @pytest.fixture + def vep_table(self): + """Fixture to create a table with VEP annotation struct.""" + return hl.Table.parallelize( + [ + { + "vep": hl.Struct( + most_severe_consequence="missense_variant", + lof=None, + polyphen_prediction="probably_damaging", + gene_symbol="BRCA1", + gene_id="ENSG00000012048", + transcript_id="ENST00000357654", + canonical=1, + mane_select="NM_007294.4", + ), + }, + { + "vep": hl.Struct( + most_severe_consequence="stop_gained", + lof="HC", + polyphen_prediction=None, + gene_symbol="TP53", + gene_id="ENSG00000141510", + transcript_id="ENST00000269305", + canonical=0, + mane_select=None, + ), + }, + ], + hl.tstruct( + vep=hl.tstruct( + most_severe_consequence=hl.tstr, + lof=hl.tstr, + polyphen_prediction=hl.tstr, + gene_symbol=hl.tstr, + gene_id=hl.tstr, + transcript_id=hl.tstr, + canonical=hl.tint32, + mane_select=hl.tstr, + ), + ), + ) + + def 
test_default_fields(self, vep_table): + """Test that default groupings include annotation, modifier, gene, gene_id, transcript, canonical.""" + groupings = get_constraint_grouping_expr(vep_table.vep) + + assert set(groupings.keys()) == { + "annotation", + "modifier", + "gene", + "gene_id", + "transcript", + "canonical", + } + + def test_modifier_uses_lof_over_polyphen(self, vep_table): + """Test that modifier prefers lof when present, falls back to polyphen.""" + vep_table = vep_table.annotate(**get_constraint_grouping_expr(vep_table.vep)) + rows = vep_table.collect() + + # Row 0: lof is None, polyphen is "probably_damaging" + assert rows[0].modifier == "probably_damaging" + # Row 1: lof is "HC" + assert rows[1].modifier == "HC" + + def test_modifier_falls_back_to_none_string(self): + """Test that modifier is 'None' when both lof and polyphen are missing.""" + ht = hl.Table.parallelize( + [ + { + "vep": hl.Struct( + most_severe_consequence="synonymous_variant", + lof=None, + polyphen_prediction=None, + gene_symbol="GENE1", + gene_id="ENSG00000000001", + transcript_id="ENST00000000001", + canonical=1, + mane_select=None, + ), + }, + ], + hl.tstruct( + vep=hl.tstruct( + most_severe_consequence=hl.tstr, + lof=hl.tstr, + polyphen_prediction=hl.tstr, + gene_symbol=hl.tstr, + gene_id=hl.tstr, + transcript_id=hl.tstr, + canonical=hl.tint32, + mane_select=hl.tstr, + ), + ), + ) + ht = ht.annotate(**get_constraint_grouping_expr(ht.vep)) + result = ht.collect()[0] + + assert result.modifier == "None" + + def test_canonical_is_boolean(self, vep_table): + """Test that canonical is converted to a boolean.""" + vep_table = vep_table.annotate(**get_constraint_grouping_expr(vep_table.vep)) + rows = vep_table.collect() + + assert rows[0].canonical is True # canonical=1 + assert rows[1].canonical is False # canonical=0 + + def test_include_mane_select(self, vep_table): + """Test that mane_select is included when requested.""" + groupings = get_constraint_grouping_expr( + 
vep_table.vep, include_mane_select_group=True + ) + + assert "mane_select" in groupings + + vep_table = vep_table.annotate(**groupings) + rows = vep_table.collect() + + assert rows[0].mane_select is True # has mane_select value + assert rows[1].mane_select is False # mane_select is None + + def test_exclude_transcript_and_canonical(self, vep_table): + """Test that transcript and canonical can be excluded.""" + groupings = get_constraint_grouping_expr( + vep_table.vep, + include_transcript_group=False, + include_canonical_group=False, + ) + + assert "transcript" not in groupings + assert "canonical" not in groupings + assert set(groupings.keys()) == {"annotation", "modifier", "gene", "gene_id"} + + def test_coverage_expr_included(self, vep_table): + """Test that coverage is included when coverage_expr is provided.""" + vep_table = vep_table.annotate(cov=30) + groupings = get_constraint_grouping_expr( + vep_table.vep, coverage_expr=vep_table.cov + ) + + assert "coverage" in groupings + + def test_polyphen_missing_from_struct(self): + """Test that missing polyphen_prediction field is handled gracefully.""" + ht = hl.Table.parallelize( + [ + { + "vep": hl.Struct( + most_severe_consequence="stop_gained", + lof=None, + gene_symbol="GENE1", + gene_id="ENSG00000000001", + transcript_id="ENST00000000001", + canonical=1, + mane_select=None, + ), + }, + ], + hl.tstruct( + vep=hl.tstruct( + most_severe_consequence=hl.tstr, + lof=hl.tstr, + gene_symbol=hl.tstr, + gene_id=hl.tstr, + transcript_id=hl.tstr, + canonical=hl.tint32, + mane_select=hl.tstr, + ), + ), + ) + ht = ht.annotate(**get_constraint_grouping_expr(ht.vep)) + result = ht.collect()[0] + + # lof is None, polyphen is missing from struct → modifier should be "None" + assert result.modifier == "None" diff --git a/tests/utils/test_file_utils.py b/tests/utils/test_file_utils.py new file mode 100644 index 000000000..9ac049487 --- /dev/null +++ b/tests/utils/test_file_utils.py @@ -0,0 +1,151 @@ +"""Tests for the file_utils 
utility module.""" + +import logging + +import hail as hl + +from gnomad.utils.file_utils import ( + convert_multi_array_to_array_of_structs, + print_global_struct, +) + + +class TestPrintGlobalStruct: + """Test the print_global_struct function.""" + + def test_with_table(self, caplog): + """Test output format when passing a Table.""" + ht = hl.Table.parallelize( + [{"x": 1}], + hl.tstruct(x=hl.tint32), + ) + ht = ht.annotate_globals(foo="bar", count=42) + + with caplog.at_level(logging.INFO): + print_global_struct(ht) + + assert "foo: bar" in caplog.text + assert "count: 42" in caplog.text + + def test_with_struct(self, caplog): + """Test output format with an already-evaluated Struct.""" + s = hl.Struct(a=1, b="hello", nested=hl.Struct(c=3.0)) + + with caplog.at_level(logging.INFO): + print_global_struct(s) + + assert "a: 1" in caplog.text + assert "b: hello" in caplog.text + assert "c: 3.0" in caplog.text + + def test_with_struct_expression(self, caplog): + """Test output format when passing a StructExpression.""" + ht = hl.Table.parallelize( + [{"x": 1}], + hl.tstruct(x=hl.tint32), + ) + ht = ht.annotate_globals(foo="bar", count=42) + + with caplog.at_level(logging.INFO): + print_global_struct(ht.globals) + + assert "foo: bar" in caplog.text + assert "count: 42" in caplog.text + + def test_nested_indentation(self, caplog): + """Test that nested structs are indented deeper than top-level fields.""" + s = hl.Struct(top="val", nested=hl.Struct(inner="deep")) + + with caplog.at_level(logging.INFO): + print_global_struct(s) + + # Top-level fields get 4 spaces, nested get 8. 
+ assert " top: val" in caplog.text + assert " inner: deep" in caplog.text + + def test_multiple_nested_levels(self, caplog): + """Test formatting with multiple nesting levels.""" + s = hl.Struct(level1=hl.Struct(level2=hl.Struct(value=99))) + + with caplog.at_level(logging.INFO): + print_global_struct(s) + + assert " level1:" in caplog.text + assert " level2:" in caplog.text + assert " value: 99" in caplog.text + + +class TestConvertMultiArrayToArrayOfStructs: + """Test the convert_multi_array_to_array_of_structs function.""" + + def test_basic_conversion(self): + """Test converting two parallel arrays into an array of structs.""" + ht = hl.Table.parallelize( + [{"a": [1, 2, 3], "b": [4, 5, 6], "other": "keep"}], + hl.tstruct( + a=hl.tarray(hl.tint32), + b=hl.tarray(hl.tint32), + other=hl.tstr, + ), + ) + + result_ht = convert_multi_array_to_array_of_structs(ht, ["a", "b"], "combined") + result = result_ht.collect()[0] + + # Original fields should be dropped. + assert not hasattr(result, "a") + assert not hasattr(result, "b") + + # Non-combined field should remain. + assert result.other == "keep" + + # Check the combined array. 
+ assert len(result.combined) == 3 + assert result.combined[0].a == 1 + assert result.combined[0].b == 4 + assert result.combined[1].a == 2 + assert result.combined[1].b == 5 + assert result.combined[2].a == 3 + assert result.combined[2].b == 6 + + def test_three_arrays(self): + """Test converting three parallel arrays.""" + ht = hl.Table.parallelize( + [{"x": [1.0, 2.0], "y": [3.0, 4.0], "z": [5.0, 6.0]}], + hl.tstruct( + x=hl.tarray(hl.tfloat64), + y=hl.tarray(hl.tfloat64), + z=hl.tarray(hl.tfloat64), + ), + ) + + result_ht = convert_multi_array_to_array_of_structs( + ht, ["x", "y", "z"], "merged" + ) + result = result_ht.collect()[0] + + assert len(result.merged) == 2 + assert result.merged[0].x == 1.0 + assert result.merged[0].y == 3.0 + assert result.merged[0].z == 5.0 + + def test_with_struct_expression(self): + """Test conversion on a StructExpression (annotated within a table).""" + ht = hl.Table.parallelize( + [{"s": hl.Struct(a=[10, 20], b=[30, 40])}], + hl.tstruct( + s=hl.tstruct( + a=hl.tarray(hl.tint32), + b=hl.tarray(hl.tint32), + ) + ), + ) + + result_ht = ht.annotate( + s=convert_multi_array_to_array_of_structs(ht.s, ["a", "b"], "combined") + ) + result = result_ht.collect()[0] + + assert len(result.s.combined) == 2 + assert result.s.combined[0].a == 10 + assert result.s.combined[0].b == 30 diff --git a/tests/utils/test_vep.py b/tests/utils/test_vep.py index e4d12b75d..285f49ca8 100644 --- a/tests/utils/test_vep.py +++ b/tests/utils/test_vep.py @@ -5,6 +5,7 @@ from gnomad.utils.vep import ( get_loftee_end_trunc_filter_expr, + mane_select_over_canonical_filter_expr, update_loftee_end_trunc_filter, ) @@ -256,3 +257,103 @@ def test_empty_filter_handling(self): assert results[1].updated_csq.lof_filter == "END_TRUNC" assert results[1].updated_csq.lof == "LC" + + +class TestManeSelectOverCanonicalFilterExpr: + """Test the mane_select_over_canonical_filter_expr function.""" + + @pytest.fixture + def sample_table(self): + """Fixture to create a table with 
transcript/mane_select/canonical/gene_id fields.""" + return hl.Table.parallelize( + [ + # Gene A: has MANE Select transcript. + { + "transcript": "ENST00001", + "mane_select": True, + "canonical": True, + "gene_id": "ENSG_A", + }, + { + "transcript": "ENST00002", + "mane_select": False, + "canonical": True, + "gene_id": "ENSG_A", + }, + # Gene B: no MANE Select, only canonical. + { + "transcript": "ENST00003", + "mane_select": False, + "canonical": True, + "gene_id": "ENSG_B", + }, + { + "transcript": "ENST00004", + "mane_select": False, + "canonical": False, + "gene_id": "ENSG_B", + }, + # Gene C: non-ENST transcript should be excluded. + { + "transcript": "NM_00001", + "mane_select": True, + "canonical": True, + "gene_id": "ENSG_C", + }, + ], + hl.tstruct( + transcript=hl.tstr, + mane_select=hl.tbool, + canonical=hl.tbool, + gene_id=hl.tstr, + ), + ) + + def test_mane_select_preferred_over_canonical(self, sample_table): + """Test that MANE Select is chosen when available for a gene.""" + ht = sample_table.annotate( + selected=mane_select_over_canonical_filter_expr( + sample_table.transcript, + sample_table.mane_select, + sample_table.canonical, + sample_table.gene_id, + ) + ) + results = ht.collect() + + # Gene A: ENST00001 (MANE Select) should be selected, ENST00002 should not. + result_map = {r.transcript: r.selected for r in results} + assert result_map["ENST00001"] is True + assert result_map["ENST00002"] is False + + def test_canonical_fallback_when_no_mane(self, sample_table): + """Test that canonical is used as fallback when no MANE Select exists.""" + ht = sample_table.annotate( + selected=mane_select_over_canonical_filter_expr( + sample_table.transcript, + sample_table.mane_select, + sample_table.canonical, + sample_table.gene_id, + ) + ) + results = ht.collect() + + result_map = {r.transcript: r.selected for r in results} + # Gene B: ENST00003 (canonical) should be selected, ENST00004 should not. 
+ assert result_map["ENST00003"] is True + assert result_map["ENST00004"] is False + + def test_non_enst_excluded(self, sample_table): + """Test that non-ENST transcripts are excluded.""" + ht = sample_table.annotate( + selected=mane_select_over_canonical_filter_expr( + sample_table.transcript, + sample_table.mane_select, + sample_table.canonical, + sample_table.gene_id, + ) + ) + results = ht.collect() + + result_map = {r.transcript: r.selected for r in results} + assert result_map["NM_00001"] is False