
add Gram matrix fast path for ZnS computation #25

Open
kevinkorfmann wants to merge 4 commits into kr-colab:main from kevinkorfmann:feat/fast-zns

Conversation

@kevinkorfmann
Contributor

Summary

  • Add Gram matrix fast path for ZnS (Kelly's ZnS) computation
  • Instead of O(n*m^2) tiled pairwise r^2, computes K = S @ S.T where K is n x n, then derives sum(r^2) from
    ||K||_F^2
  • Falls back to tiled approach when missing data is present or projection mode is used
  • 70-1400x speedup depending on dataset size, machine-precision accuracy
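The identity behind the fast path can be checked in a few lines. This is a toy sketch (NumPy standing in for CuPy, not the PR's actual implementation): since ||S S^T||_F^2 = ||S^T S||_F^2 (both equal the sum of fourth powers of S's singular values), the sum of all pairwise r^2 values falls out of the small n x n Gram matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 50, 200  # haplotypes x sites (toy sizes)
hap = rng.integers(0, 2, size=(n, m)).astype(np.float64)

# keep segregating sites only
p = hap.mean(axis=0)
seg = (p > 0) & (p < 1)
hap, p = hap[:, seg], p[seg]
m = hap.shape[1]

# standardize each column: S[:, i] = (h_i - p_i) / sqrt(p_i q_i)
S = (hap - p) / np.sqrt(p * (1 - p))

# naive route: r(i, j) = S_i . S_j / n, then sum all squared correlations
R = (S.T @ S) / n                      # m x m
sum_r2_naive = (R ** 2).sum()

# Gram route: K = S @ S.T is only n x n, yet ||K||_F^2 == ||S^T S||_F^2
K = S @ S.T
sum_r2_gram = (K ** 2).sum() / n ** 2

assert np.isclose(sum_r2_naive, sum_r2_gram)

# ZnS averages r^2 over distinct site pairs; diag(R) is exactly 1,
# so subtracting m removes the i == j terms
zns = (sum_r2_gram - m) / (m * (m - 1))
assert 0.0 <= zns <= 1.0
```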

Benchmarks (NVIDIA A100, 50k variants)

| n_haps | n_snps | Before (ms) | After (ms) | Speedup |
|--------|--------|-------------|------------|---------|
|    100 | 50,000 |       1,733 |        1.2 |  1,408x |
|    200 | 50,000 |       1,755 |        2.0 |    881x |

Also speeds up zx, which calls zns three times internally.

Test plan

  • pixi run pytest tests/ -k 'zns or ld or omega or windowed' (68 passed)
  • pixi run python debug/bench_zns.py for speed + correctness comparison

Member

@andrewkern andrewkern left a comment


I tested this on full Ag1000G chromosome 3L data (2,940 haplotypes x 8,248,442 variants) and it OOMs:

cupy.cuda.memory.OutOfMemoryError: Out of memory allocating 97,001,678,336 bytes
(allocated so far: 48,533,833,216 bytes)

The issue is that _prepare_segregating creates the full haplotype matrix as float64, and then the standardization step S = (hap_clean - p) * inv_sqrt_pq creates another full-size (n_hap x m) float64 array. At 2940 x ~8M segregating sites, that's ~185GB for S alone.

The Gram trick itself seems cool -- K = S @ S.T is only 2940 x 2940 -- but you need to accumulate K in chunks over columns (sites) so that only a (n_hap x chunk_size) slice of S is ever in memory at once. Something like:

K = cp.zeros((n_hap, n_hap), dtype=cp.float64)
for col_start in range(0, m, chunk_size):
    col_end = min(col_start + chunk_size, m)
    S_chunk = (hap_clean[:, col_start:col_end] - p[col_start:col_end]) * inv_sqrt_pq[col_start:col_end]
    K += S_chunk @ S_chunk.T

The existing _memutil.py chunking infrastructure may be useful here.

Also: the missing data fallback to the tiled path is correct, but it means the fast path won't activate on datasets with any missingness. This isn't ideal -- users with real-world data (which commonly has missing genotypes) won't see the speedup unless they exclude sites, and it means we have to keep a fallback code path in the project -- more code to maintain.

Please test on full-scale data (e.g., 1000 Genomes or Ag1000G full chromosome arms) before resubmitting. The unit tests all pass at small scale but the memory issue only shows up with real genomic data. Also please deal with the missing data issue so that this could replace the existing code fully.

The Gram-trick ZnS computation OOMed on Ag1000G 3L (2940 haps x
8,248,442 variants):

  cupy.cuda.memory.OutOfMemoryError: Out of memory allocating
  97,001,678,336 bytes

Root cause: _prepare_segregating cast the full haplotype matrix to
float64, then the standardization step S = (hap - p) * inv_sqrt_pq
created another full-size (n_hap x m) float64 array.  At 2940 x ~8M
segregating sites that is ~185GB for S alone.

## What changed

1. Segregating-site filtering now uses chunked_dac_and_n (keeps int8)
   instead of _prepare_segregating (which cast to float64 upfront).

2. S is never fully materialized.  The loop chunks over columns
   (sites), builds an (n_hap x chunk_size) float64 S_chunk, and
   accumulates K += S_chunk @ S_chunk.T.  chunk_size is auto-sized
   by estimate_variant_chunk_size to fit ~40% of free GPU memory.
   K itself is (n_hap x n_hap) -- only 69MB for N=2940.

3. Missing data is handled via mean imputation: for missing entries,
   S_ki = 0 (equivalent to imputing the site mean p_i, which
   contributes nothing to covariance).  A MCAR correction factor
   of (n^2 * E[1/n_i^2])^2 compensates for dividing by n (total
   samples) instead of n_both(i,j) (valid at both sites) per pair.

4. For small m with missing data, the exact tiled O(m^2) path is
   used automatically (see path selection below).

5. A UserWarning is emitted when the MCAR correction exceeds 5%,
   directing users to missing_data='exclude' for exact results.

## Path selection

The default is missing_data='include', which uses per-site valid
data for frequency computation.  The Gram path is O(n^2 m).  The
tiled path computes exact per-pair r^2 but is O(n m^2) -- infeasible
at chromosome scale (~5 hours for N=2940, M=8M).  Selection is
automatic:

  | Condition                              | Path  | Accuracy    |
  |----------------------------------------|-------|-------------|
  | No missing data                        | Gram  | exact       |
  | missing_data='exclude'                 | Gram  | exact       |
  | missing_data='project'                 | tiled | exact       |
  | missing_data='include', n*m^2 < 5e11   | tiled | exact       |
  | missing_data='include', n*m^2 >= 5e11  | Gram  | approximate |

The last row is the only approximate path.  It activates when the
exact tiled path would take ~50s+ of GPU compute.

For users who want exact results at chromosome scale with missing
data: missing_data='exclude' drops sites with any missing genotype
first, then runs the Gram path on clean data.  This gives exact
results at full O(n^2 m) speed -- the only cost is losing some
sites, which is often acceptable since ZnS only uses segregating
sites anyway.
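The dispatch table above can be sketched as a small function. Everything here is illustrative -- the function name, signature, and budget constant mirror the description above, not the PR's actual API:

```python
def pick_zns_path(n, m, missing_data="include", any_missing=False,
                  tiled_budget=5e11):
    """Hypothetical sketch of the automatic path selection described above."""
    if missing_data == "project":
        return "tiled"      # exact, always tiled
    if not any_missing or missing_data == "exclude":
        return "gram"       # exact: all retained sites share the same n
    # include mode with missing data: tiled is exact but O(n m^2),
    # so fall back to the approximate Gram path above the budget
    return "tiled" if n * m ** 2 < tiled_budget else "gram"

assert pick_zns_path(100, 100, missing_data="project") == "tiled"
assert pick_zns_path(2940, 8_000_000, any_missing=False) == "gram"   # exact
assert pick_zns_path(200, 10_000, any_missing=True) == "tiled"       # exact
assert pick_zns_path(2940, 8_000_000, any_missing=True) == "gram"    # approximate
```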

## Why the Gram path cannot be exact with missing data

The Gram trick gives ||K||_F^2 = sum_{ij} (S^T S)_{ij}^2, which is
sum(a).  But exact ZnS with missing data needs
sum_{ij} (S^T S)_{ij}^2 / n_both(i,j)^2, which is sum(a/b).
There is no O(n^2 m) algorithm that computes sum(a/b) from the
(n x n) Gram matrices alone -- it requires the (m x m) matrices.

The MCAR correction approximates n_both(i,j) ~ n_i * n_j / n, a
law-of-large-numbers estimate that improves with sample size.
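The MCAR approximation is easy to sanity-check by simulation (NumPy; illustrative only): with independent missingness at two sites, the observed jointly-valid count tracks n_i * n_j / n closely at realistic sample sizes.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 2000
miss_rate = 0.10

# independent (MCAR) missingness masks at two sites
valid_i = rng.random(n) > miss_rate
valid_j = rng.random(n) > miss_rate

n_i, n_j = valid_i.sum(), valid_j.sum()
n_both = (valid_i & valid_j).sum()
approx = n_i * n_j / n

# under MCAR the approximation is typically good to ~1% at this n
assert abs(n_both - approx) / n_both < 0.05
```

Structured missingness (e.g. samples missing in blocks) breaks the independence assumption, which is exactly the failure mode reported later in this thread.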

## Gram accuracy at the switchover (10% missingness)

  |     N | Gram kicks in above | Gram error |
  |-------|---------------------|------------|
  |   200 |       50,000 sites  |       ~2%  |
  |   500 |       31,622 sites  |     0.76%  |
  |  1000 |       22,360 sites  |      0.4%  |
  |  2000 |       15,811 sites  |      0.2%  |
  |  3000 |       12,909 sites  |      0.1%  |
  |  5000 |       10,000 sites  |     0.05%  |

Error scales as ~4/N at 10% missing and shrinks proportionally with
lower missingness.  Below these site counts, the tiled path is used
and results are exact.  Real chromosome-scale datasets typically have
N > 1000 haplotypes (Ag1000G: 2940, 1000 Genomes: 5008), keeping the
Gram error well below 1%.

## Ag1000G 3L validation (2940 haps x 8,248,442 variants)

  - Before: OOM (97GB allocation failure)
  - After: 14.5s, ~3GB peak GPU overhead
  - 50K-variant subset: rel_err = 1.17e-13 (exact)

## Speedup vs tiled (no missing data)

  | n_haps | n_snps | gram   | tiled   | speedup |
  |--------|--------|--------|---------|---------|
  |    100 | 50,000 | 3.8ms  | 1,763ms |    460x |
  |    200 | 50,000 | 6.3ms  | 1,793ms |    284x |

Tested: 416/416 tests pass, Ag1000G 3L full chromosome, synthetic
data at 0-50% missingness with N=200 (exact via tiled) and N=5000.
@kevinkorfmann
Contributor Author

Hi Andy,
Chunked the Gram accumulation over columns just like you suggested -- only an (n_hap x chunk_size) slice of S lives on GPU at a time, chunk size auto-sized to fit available memory. Also switched the segregating-site filtering to stay in int8 instead of casting to float64 upfront.

For missing data: the Gram path now handles it via mean imputation (S=0 for missing entries) with a MCAR correction. For small m where tiled is fast anyway it auto-selects the exact path, and a UserWarning fires when the correction gets significant so users know to try missing_data='exclude' if they want exact results. The fundamental issue is that the Gram trick gives you sum(a) not sum(a/b) -- you'd need the full m x m matrices for exact per-pair n_both normalization. The correction error scales as ~4/N at 10% missing, so for Ag1000G scale (N=2940) it's ~0.1%.

Tested on full Ag1000G 3L: 2940 x 8.2M runs in 14.5s (vs ~5 hours for the tiled path at that scale), ~3GB GPU overhead, matches tiled at machine precision on a 50K subset. 416/416 tests pass. Full details in the commit message.

@andrewkern
Member

Thanks for the revision Kevin -- the chunked Gram accumulation is a great improvement and the OOM issue is resolved.

I ran some validation on real unphased Ag1000G data (3L, ~2940 haplotypes, 74% missingness) comparing the PR branch against main:

| Region               | Mode    | main   | PR                               |
|----------------------|---------|--------|----------------------------------|
| 50kb (6.8k variants) | include | 0.0097 | 0.0097 (tiled fallback, matches) |
| 50kb                 | exclude | 0.0862 | 0.0864 (close)                   |
| 100kb (13k variants) | include | 0.0091 | 193.7 (Gram + MCAR correction)   |
| 100kb                | exclude | 0.0963 | 0.0990 (close)                   |
| 500kb (78k variants) | include | 0.0104 | 392.9 (Gram + MCAR correction)   |
| 500kb                | exclude | 0.2475 | 2.82                             |

There are two problems:

  1. include mode with missing data is really off. The MCAR correction factor you've implemented at 74% missingness is ~400,000x, producing ZnS values well above 1 (ZnS is bounded 0-1). The 50kb case, I think, falls under the threshold and uses the tiled fallback, which is why it matches. The core issue is that the MCAR assumption (n_both(i,j) ~ n_i*n_j/n) breaks down badly when missingness is high or structured, which is common in real data.

  2. exclude mode diverges at 500kb (2.82 vs 0.25, also >1). The reimplemented site filtering in the PR produces a different site set from main's _prepare_segregating path. Something is off in how excluded sites are handled.

So the missing data path needs to be tightened up. Rather than reimplementing the missing data filtering, use the existing _prepare_segregating() which already handles all three modes correctly (exclude, include, project). It returns hap_clean (missing set to 0) and valid_mask with the right sites filtered.

Addresses Andy's review of 9a12a28:

1. include mode at high missingness produced ZnS >> 1 (~193, ~393 on
   real Ag1000G 3L 100kb / 500kb windows). Root cause: the standardization
   S_{ki} = (h_ki - p_i) / sqrt(p_i q_i) gives ||S_i||^2 = n_i (not n)
   under mean imputation, so r(i,j) = (S^T S)_{ij} / sqrt(n_i n_j) -- a
   per-pair factor that no global MCAR scalar can capture. Fixed by
   baking 1/sqrt(n_i) into the standardization:

       B_{ki} = (h_ki - p_i) / sqrt(n_i p_i q_i)   for valid k
              = 0                                    for missing k

   Then ||B_i||^2 = 1 exactly, so the diagonal of R = B^T B is 1 and
   ZnS = (||K||_F^2 - m) / (m(m-1)) is strictly bounded in [0, 1] by
   Cauchy-Schwarz. The MCAR correction block is removed entirely.

   For include mode under missingness this computes mean-imputation
   Pearson r^2 -- a standard estimator that differs from main's hybrid
   (per-site p_i, per-pair p_AB) by O(missing rate). For no missing
   data the two formulas reduce to identical standard Pearson r^2.

2. exclude mode at 500kb diverged from main (2.82 vs 0.25). Root cause:
   _zns_gram reimplemented exclude / segregating filtering and produced
   a different site set than main. Fixed by calling _prepare_segregating
   instead, which already handles all three modes (exclude, include,
   project) and returns hap_clean, valid_mask, m. Exclude mode now
   matches main at machine precision.
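The unit-norm property and the [0, 1] bound in point 1 can be verified numerically. A toy sketch (NumPy standing in for CuPy, arbitrary missingness pattern): for 0/1 data, sum over valid k of (h_ki - p_i)^2 equals n_i p_i q_i exactly, so each column of B has unit norm by construction.

```python
import numpy as np

rng = np.random.default_rng(3)
n, m = 60, 300
hap = rng.integers(0, 2, size=(n, m)).astype(np.float64)
valid = rng.random((n, m)) > 0.3              # ~30% missing, arbitrary pattern
n_i = valid.sum(axis=0)                       # per-site valid sample counts
p = (hap * valid).sum(axis=0) / np.maximum(n_i, 1)
keep = (p > 0) & (p < 1) & (n_i >= 2)         # segregating among valid samples
hap, valid, p, n_i = hap[:, keep], valid[:, keep], p[keep], n_i[keep]
m = hap.shape[1]

# B_{ki} = (h_ki - p_i) / sqrt(n_i p_i q_i) for valid k, else 0
B = np.where(valid, (hap - p) / np.sqrt(n_i * p * (1 - p)), 0.0)

# each column has unit norm, so diag(B^T B) == 1 exactly
assert np.allclose((B ** 2).sum(axis=0), 1.0)

# ZnS from the small (n x n) Gram matrix; |r_ij| <= 1 by Cauchy-Schwarz
# and diag == 1 force the result into [0, 1]
K = B @ B.T
zns = ((K ** 2).sum() - m) / (m * (m - 1))
assert 0.0 <= zns <= 1.0
```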

The chunked Gram accumulation is preserved -- still the original PR's
OOM fix and ~80x to 565x faster than the tiled reference under
missingness.

Validation:
- Real Ag1000G 3L unphased (1838 samples, biallelic, ~17% missing):
  exclude matches main at 1e-15 across 50 / 100 / 500 kb windows;
  include bounded in [0, 1]; 500kb runs in 0.24s without OOM.
- Synthetic sweep at 0 / 5 / 10 / 50 / 74% missing: include bounded
  in [0, 1]; exclude matches main at machine precision; no-missing
  Gram matches tiled at 1e-14.
- pytest tests/test_diploshic_stats.py tests/test_windowed_analysis.py:
  47 passed, 1 skipped.
- validate_against_allel.py: 29 PASS, 0 FAIL.
@kevinkorfmann
Contributor Author

Thanks, Andy. I reproduced this on the unphased gamb 3L Ag1000G data and fixed both failure modes.

The validation table below is from this exact input:

  • file: /sietch_colab/data_share/Ag1000G/Ag3.0/vcf/unphased_vcf/gamb/ag1000g.agam_n1470.merged_variants.sitefilt.vcf.gz
  • chromosome: 3L
  • samples: 1470 diploid samples
  • haplotypes: 2940
  • region start used for the check: 3L:5,000,000
  • resulting variant counts:
    • 50kb: 6830
    • 100kb: 13227
    • 500kb: 78376
  • missingness in those windows was about 73-74%

Current behavior:

  • default is still missing_data='include'
  • include with missing data now uses the exact tiled computation
  • exclude uses the Gram path
  • project uses tiled sigma_D^2

What was wrong:

  • include: the MCAR correction was invalid under heavy/structured missingness, which is why ZnS exploded above 1.
  • exclude: even after restoring _prepare_segregating(), the fast Gram reduction still had a bad diagonal assumption on this unphased input, which is why the 500kb value was inflated.

What I changed:

  • include with missing data now uses the exact tiled computation.
  • exclude uses the fast Gram path, but with the correct diagonal treatment so it matches the tiled reference on the retained complete sites.
  • both paths use _prepare_segregating() semantics for filtering.

Here is the comparison on that exact dataset and region:

| Region                 | Mode    | main   | broken PR | fixed PR  | engine used | runtime  |
|------------------------|---------|--------|-----------|-----------|-------------|----------|
| 50kb (6830 variants)   | include | 0.0097 | 0.0097    | 0.0096579 | tiled       | 0.198s   |
| 50kb (6830 variants)   | exclude | 0.0862 | 0.0864    | 0.0861817 | Gram        | 0.0216s  |
| 100kb (13227 variants) | include | 0.0091 | 193.7     | 0.0091173 | tiled       | 0.199s   |
| 100kb (13227 variants) | exclude | 0.0963 | 0.0990    | 0.0963061 | Gram        | 0.00944s |
| 500kb (78376 variants) | include | 0.0104 | 392.9     | 0.0104272 | tiled       | 5.253s   |
| 500kb (78376 variants) | exclude | 0.2475 | 2.82      | 0.2475494 | Gram        | 0.00878s |

One additional note on performance: I also compared the fixed exclude Gram path to the exact tiled reference on the same retained complete sites. On this dataset, exclude leaves relatively few complete segregating sites, so tiled is actually faster here. Gram remains enabled for exclude, but on this dataset it is not the faster path.

I also added targeted ZnS regressions for:

  • missing-data include
  • exclude parity with tiled exact
  • unphased/multiallelic-style exclude
  • direct vs windowed ZnS parity

I reran the targeted ZnS tests on poppy after the fix and they pass.

@andrewkern
Member

I think your proposal from slack is the right approach. Thinking it through:

The mixed-n problem in _tile_r2_naive: The current include mode computes marginal frequencies p_i from per-site n_valid, but the joint frequency p_AB uses per-pair n (samples valid at both sites). This mixes marginal frequencies computed from different sample sets than the joint frequency, which biases r² when missingness patterns differ between sites.

Proposed design (2 paths, no project mode):

  • include: Tiled path using count-based r² from _tile_counts. The 4-way counts (c1/c2/c3/c4) already use only samples valid at both sites, so all quantities in each pair are computed from the same sample set — no mixed-n bias. The change to _zns_tiled_exact would be:
c1, c2, c3, c4, n = _tile_counts(hi, vi, hj, vj)
p_i = (c1 + c2) / n
p_j = (c1 + c3) / n
D = c1 / n - p_i * p_j
denom = p_i * (1 - p_i) * p_j * (1 - p_j)
r2_tile = cp.where((denom > 0) & (n >= 2), D**2 / denom, 0.0)
  • exclude: Gram matrix fast path (your existing implementation). All sites have the same n after filtering, so the Gram trick is exact.

  • estimator='sigma_d2': Stays orthogonal — keep the existing _tile_sigma_d2 fallback as-is.

This eliminates the project missing data mode entirely (it was already removed from the public API in #32), keeps the Gram fast path for the common no-missing-data case, and makes include exact via the count-based approach that _tile_counts already supports.
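A self-contained toy of the count-based approach (NumPy; `r2_counts` is an illustrative stand-in for the `_tile_counts`-based change, not actual project API): because every quantity is computed on the samples valid at both sites, the result coincides with squared Pearson correlation on that jointly-valid subset, with no mixed-n bias.

```python
import numpy as np

def r2_counts(hi, vi, hj, vj):
    """Count-based r^2 for one site pair, using only samples valid at
    both sites (toy version of the proposed change)."""
    both = vi & vj
    n = both.sum()
    a, b = hi[both], hj[both]
    c1 = np.sum((a == 1) & (b == 1))          # '11' haplotype count
    p_i, p_j = a.mean(), b.mean()             # marginals on the SAME subset
    D = c1 / n - p_i * p_j
    denom = p_i * (1 - p_i) * p_j * (1 - p_j)
    return (D ** 2 / denom) if (denom > 0 and n >= 2) else 0.0

rng = np.random.default_rng(4)
n = 500
hi = rng.integers(0, 2, size=n)
hj = (hi ^ (rng.random(n) < 0.1)).astype(int)  # correlated second site
vi = rng.random(n) > 0.2                       # independent missingness
vj = rng.random(n) > 0.2

# equals squared Pearson correlation on the jointly-valid subset
both = vi & vj
r = np.corrcoef(hi[both], hj[both])[0, 1]
assert np.isclose(r2_counts(hi, vi, hj, vj), r ** 2)
```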

@andrewkern
Member

Also note -- you should rebase before you start editing.

