add Gram matrix fast path for ZnS computation #25
kevinkorfmann wants to merge 4 commits into kr-colab:main
Conversation
I tested this on full Ag1000G chromosome 3L data (2,940 haplotypes x 8,248,442 variants) and it OOMs:
cupy.cuda.memory.OutOfMemoryError: Out of memory allocating 97,001,678,336 bytes
(allocated so far: 48,533,833,216 bytes)
The issue is that _prepare_segregating creates the full haplotype matrix as float64, and then the standardization step S = (hap_clean - p) * inv_sqrt_pq creates another full-size (n_hap x m) float64 array. At 2940 x ~8M segregating sites, that's ~185GB for S alone.
The Gram trick itself seems cool -- K = S @ S.T is only 2940 x 2940 -- but you need to accumulate K in chunks over columns (sites) so that only a (n_hap x chunk_size) slice of S is ever in memory at once. Something like:
K = cp.zeros((n_hap, n_hap), dtype=cp.float64)
for col_start in range(0, m, chunk_size):
    col_end = min(col_start + chunk_size, m)
    S_chunk = (hap_clean[:, col_start:col_end] - p[col_start:col_end]) * inv_sqrt_pq[col_start:col_end]
    K += S_chunk @ S_chunk.T
The existing _memutil.py chunking infrastructure may be useful here.
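For reference, here is a minimal self-contained NumPy sketch of that chunked accumulation (hypothetical names; the real code would use CuPy arrays and the project's chunk sizing):

```python
import numpy as np

def gram_accumulate(hap, p, inv_sqrt_pq, chunk_size=4096):
    """Accumulate K = S @ S.T without ever materializing the full S.

    Only an (n_hap, chunk_size) slice of the standardized matrix S
    exists at any time; K itself is just (n_hap, n_hap).
    """
    n_hap, m = hap.shape
    K = np.zeros((n_hap, n_hap))
    for start in range(0, m, chunk_size):
        end = min(start + chunk_size, m)
        # standardize only this column slice
        S_chunk = (hap[:, start:end] - p[start:end]) * inv_sqrt_pq[start:end]
        K += S_chunk @ S_chunk.T
    return K
```

The chunked sum is algebraically identical to the full `S @ S.T`, differing only in floating-point summation order.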
Also: the missing data fallback to the tiled path is correct, but it means the fast path won't activate on datasets with any missingness. This isn't ideal -- users with real-world data (which commonly has missing genotypes) won't see the speedup unless they exclude sites, and it means we have to keep a fallback code path in the project -- more code to maintain.
Please test on full-scale data (e.g., 1000 Genomes or Ag1000G full chromosome arms) before resubmitting. The unit tests all pass at small scale but the memory issue only shows up with real genomic data. Also please deal with the missing data issue so that this could replace the existing code fully.
The Gram-trick ZnS computation OOMed on Ag1000G 3L (2940 haps x
8,248,442 variants):
cupy.cuda.memory.OutOfMemoryError: Out of memory allocating
97,001,678,336 bytes
Root cause: _prepare_segregating cast the full haplotype matrix to
float64, then the standardization step S = (hap - p) * inv_sqrt_pq
created another full-size (n_hap x m) float64 array. At 2940 x ~8M
segregating sites that is ~185GB for S alone.
## What changed
1. Segregating-site filtering now uses chunked_dac_and_n (keeps int8)
instead of _prepare_segregating (which cast to float64 upfront).
2. S is never fully materialized. The loop chunks over columns
(sites), builds an (n_hap x chunk_size) float64 S_chunk, and
accumulates K += S_chunk @ S_chunk.T. chunk_size is auto-sized
by estimate_variant_chunk_size to fit ~40% of free GPU memory.
K itself is (n_hap x n_hap) -- only 69MB for N=2940.
3. Missing data is handled via mean imputation: for missing entries,
S_ki = 0 (equivalent to imputing the site mean p_i, which
contributes nothing to covariance). A MCAR correction factor
of (n^2 * E[1/n_i^2])^2 compensates for dividing by n (total
samples) instead of n_both(i,j) (valid at both sites) per pair.
4. For small m with missing data, the exact tiled O(m^2) path is
used automatically (see path selection below).
5. A UserWarning is emitted when the MCAR correction exceeds 5%,
directing users to missing_data='exclude' for exact results.
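The chunk sizing rule in (2) can be sketched as follows (hypothetical signature; the real estimate_variant_chunk_size may differ):

```python
def estimate_chunk_size(n_hap, free_bytes, frac=0.40, itemsize=8):
    # Assumption from the description above: the float64 S_chunk of shape
    # (n_hap, chunk) should occupy at most ~40% of free GPU memory.
    budget = int(free_bytes * frac)
    return max(1, budget // (n_hap * itemsize))
```

For example, with 2,940 haplotypes and 10 GB free this allows chunks of roughly 170k sites per accumulation step.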
## Path selection
The default is missing_data='include', which uses per-site valid
data for frequency computation. The Gram path is O(n^2 m). The
tiled path computes exact per-pair r^2 but is O(n m^2) -- infeasible
at chromosome scale (~5 hours for N=2940, M=8M). Selection is
automatic:
| Condition | Path | Accuracy |
|----------------------------------------|-------|-------------|
| No missing data | Gram | exact |
| missing_data='exclude' | Gram | exact |
| missing_data='project' | tiled | exact |
| missing_data='include', n*m^2 < 5e11 | tiled | exact |
| missing_data='include', n*m^2 >= 5e11 | Gram | approximate |
The last row is the only approximate path. It activates when the
exact tiled path would take ~50s+ of GPU compute.
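The selection logic in the table can be expressed as a small dispatch function (a sketch with hypothetical names, not the actual implementation):

```python
def select_zns_path(n_hap, m, has_missing, missing_data="include",
                    tiled_budget=5e11):
    """Mirror of the path-selection table: returns "gram" or "tiled"."""
    if not has_missing or missing_data == "exclude":
        return "gram"                      # exact
    if missing_data == "project":
        return "tiled"                     # exact
    # missing_data == "include": exact tiled only while n * m^2 is small
    return "tiled" if n_hap * m * m < tiled_budget else "gram"
```

Only the final branch returns an approximate path, matching the last table row.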
For users who want exact results at chromosome scale with missing
data: missing_data='exclude' drops sites with any missing genotype
first, then runs the Gram path on clean data. This gives exact
results at full O(n^2 m) speed -- the only cost is losing some
sites, which is often acceptable since ZnS only uses segregating
sites anyway.
## Why the Gram path cannot be exact with missing data
The Gram trick gives ||K||_F^2 = sum_{ij} (S^T S)_{ij}^2, which is
sum(a). But exact ZnS with missing data needs
sum_{ij} (S^T S)_{ij}^2 / n_both(i,j)^2, which is sum(a/b).
There is no O(n^2 m) algorithm that computes sum(a/b) from the
(n x n) Gram matrices alone -- it requires the (m x m) matrices.
The MCAR correction approximates n_both(i,j) ~ n_i * n_j / n, a
law-of-large-numbers estimate that improves with sample size.
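The quality of the n_both(i,j) ~ n_i * n_j / n approximation under MCAR is easy to check numerically (a standalone NumPy illustration, not project code):

```python
import numpy as np

rng = np.random.default_rng(42)
n = 2940                                  # haplotypes (Ag1000G 3L scale)
miss = 0.10                               # 10% MCAR missingness per site
valid_i = rng.random(n) > miss            # validity mask at site i
valid_j = rng.random(n) > miss            # validity mask at site j
n_i, n_j = valid_i.sum(), valid_j.sum()
n_both = (valid_i & valid_j).sum()        # samples valid at BOTH sites
approx = n_i * n_j / n                    # law-of-large-numbers estimate
rel_err = abs(n_both - approx) / n_both
```

Under MCAR the relative error of the estimate shrinks as 1/sqrt(n), consistent with the claim that the correction improves with sample size.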
## Gram accuracy at the switchover (10% missingness)
| N | Gram kicks in above | Gram error |
|-------|---------------------|------------|
| 200 | 50,000 sites | ~2% |
| 500 | 31,622 sites | 0.76% |
| 1000 | 22,360 sites | 0.4% |
| 2000 | 15,811 sites | 0.2% |
| 3000 | 12,909 sites | 0.1% |
| 5000 | 10,000 sites | 0.05% |
Error scales as ~4/N at 10% missing and shrinks proportionally with
lower missingness. Below these site counts, the tiled path is used
and results are exact. Real chromosome-scale datasets typically have
N > 1000 haplotypes (Ag1000G: 2940, 1000 Genomes: 5008), keeping the
Gram error well below 1%.
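The "Gram kicks in above" column is just the solution of n * m^2 = 5e11 for m, which can be computed directly (a sketch; the threshold comes from the path-selection rule above):

```python
import math

def gram_switchover_sites(n_hap, budget=5e11):
    # Roughly m = sqrt(budget / n_hap); above this site count the
    # n * m^2 >= budget rule routes 'include' mode to the Gram path.
    return math.isqrt(int(budget) // n_hap)
```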
## Ag1000G 3L validation (2940 haps x 8,248,442 variants)
- Before: OOM (97GB allocation failure)
- After: 14.5s, ~3GB peak GPU overhead
- 50K-variant subset: rel_err = 1.17e-13 (exact)
## Speedup vs tiled (no missing data)
| n_haps | n_snps | gram | tiled | speedup |
|--------|--------|--------|---------|---------|
| 100 | 50,000 | 3.8ms | 1,763ms | 460x |
| 200 | 50,000 | 6.3ms | 1,793ms | 284x |
Tested: 416/416 tests pass, Ag1000G 3L full chromosome, synthetic
data at 0-50% missingness with N=200 (exact via tiled) and N=5000.
Hi Andy,

For missing data: the Gram path now handles it via mean imputation (S = 0 for missing entries) with a MCAR correction. For small m where tiled is fast anyway, it auto-selects the exact path.

Tested on full Ag1000G 3L: 2940 x 8.2M runs in 14.5s (vs ~5 hours for the tiled path at that scale), ~3GB GPU overhead, matches tiled at machine precision on a 50K subset. 416/416 tests pass. Full details in the commit message.
Thanks for the revision Kevin -- the chunked Gram accumulation is a great improvement and the OOM issue is resolved. I ran some validation on real unphased Ag1000G data (3L, ~2940 haplotypes, 74% missingness) comparing the PR branch against main.
There are two problems:
So the missing data path needs to be tightened up. Rather than reimplementing the missing data filtering, reuse the existing helper.
Addresses Andy's review of 9a12a28:

1. include mode at high missingness produced ZnS >> 1 (~193, ~393 on real Ag1000G 3L 100kb / 500kb windows). Root cause: the standardization S_{ki} = (h_ki - p_i) / sqrt(p_i q_i) gives ||S_i||^2 = n_i (not n) under mean imputation, so r(i,j) = (S^T S)_{ij} / sqrt(n_i n_j) -- a per-pair factor that no global MCAR scalar can capture. Fixed by baking 1/sqrt(n_i) into the standardization:

   B_{ki} = (h_ki - p_i) / sqrt(n_i p_i q_i)  for valid k
          = 0                                  for missing k

   Then ||B_i||^2 = 1 exactly, so the diagonal of R = B^T B is 1 and ZnS = (||K||_F^2 - m) / (m(m-1)) is strictly bounded in [0, 1] by Cauchy-Schwarz. The MCAR correction block is removed entirely. For include mode under missingness this computes mean-imputation Pearson r^2 -- a standard estimator that differs from main's hybrid (per-site p_i, per-pair p_AB) by O(missing rate). For no missing data the two formulas reduce to identical standard Pearson r^2.

2. exclude mode at 500kb diverged from main (2.82 vs 0.25). Root cause: _zns_gram reimplemented exclude / segregating filtering and produced a different site set than main. Fixed by calling _prepare_segregating instead, which already handles all three modes (exclude, include, project) and returns hap_clean, valid_mask, m. Exclude mode now matches main at machine precision.

The chunked Gram accumulation is preserved -- still the original PR's OOM fix and ~80x to 565x faster than the tiled reference under missingness.

Validation:
- Real Ag1000G 3L unphased (1838 samples, biallelic, ~17% missing): exclude matches main at 1e-15 across 50 / 100 / 500 kb windows; include bounded in [0, 1]; 500kb runs in 0.24s without OOM.
- Synthetic sweep at 0 / 5 / 10 / 50 / 74% missing: include bounded in [0, 1]; exclude matches main at machine precision; no-missing Gram matches tiled at 1e-14.
- pytest tests/test_diploshic_stats.py tests/test_windowed_analysis.py: 47 passed, 1 skipped.
- validate_against_allel.py: 29 PASS, 0 FAIL.
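The fix in (1) can be sketched in NumPy (a hypothetical helper; the real _zns_gram uses CuPy, chunked accumulation, and _prepare_segregating):

```python
import numpy as np

def zns_gram_include(hap, valid):
    """Mean-imputation ZnS via B_{ki} = (h_ki - p_i) / sqrt(n_i p_i q_i).

    Assumes every column is segregating (0 < p_i < 1) and has at least
    one valid sample, so ||B_i||^2 = 1 and the result lies in [0, 1].
    """
    hap = hap.astype(float)
    n_i = valid.sum(axis=0)                       # valid samples per site
    p = (hap * valid).sum(axis=0) / n_i           # per-site frequency
    denom = np.sqrt(n_i * p * (1.0 - p))
    B = np.where(valid, (hap - p) / denom, 0.0)   # 0 for missing entries
    m = hap.shape[1]
    R = B.T @ B                                   # diag(R) == 1 exactly
    return (np.sum(R ** 2) - m) / (m * (m - 1))
```

Because each column of B has unit norm, every entry of R is a correlation bounded by 1, which is what restores the [0, 1] bound on ZnS.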
Force-pushed from cabfc8d to 9a12a28
Thanks, Andy. I reproduced this on the unphased gamb 3L Ag1000G data and fixed both failure modes. The validation table below is from this exact input:
Current behavior:
What was wrong:
What I changed:
Here is the comparison on that exact dataset and region:
One additional note on performance: I also compared the fixed exclude Gram path to the exact tiled reference on the same retained complete sites. On this dataset, exclude leaves relatively few complete sites. I also added targeted ZnS regressions for:
I reran the targeted ZnS tests on poppy after the fix and they pass.
I think your proposal from Slack is the right approach. Thinking it through: the mixed-n problem in include mode goes away if each tile computes its own per-pair counts. Proposed design (2 paths, no global n):
c1, c2, c3, c4, n = _tile_counts(hi, vi, hj, vj)
p_i = (c1 + c2) / n
p_j = (c1 + c3) / n
D = c1 / n - p_i * p_j
denom = p_i * (1 - p_i) * p_j * (1 - p_j)
r2_tile = cp.where((denom > 0) & (n >= 2), D**2 / denom, 0.0)
This eliminates the mixed-n problem.
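A self-contained NumPy sketch of per-tile counts feeding that formula (tile_counts is a hypothetical analogue of _tile_counts, shown here only to make the snippet runnable):

```python
import numpy as np

def tile_counts(hi, vi, hj, vj):
    # Per-pair 2x2 haplotype counts using only samples valid at BOTH
    # sites of each (i, j) pair, so every tile gets its own n.
    both = vi[:, :, None] & vj[:, None, :]            # (n_hap, ti, tj)
    a = hi[:, :, None] == 1
    b = hj[:, None, :] == 1
    c1 = (both & a & b).sum(axis=0)                   # derived / derived
    c2 = (both & a & ~b).sum(axis=0)
    c3 = (both & ~a & b).sum(axis=0)
    c4 = (both & ~a & ~b).sum(axis=0)
    return c1, c2, c3, c4, c1 + c2 + c3 + c4

def tile_r2(hi, vi, hj, vj):
    # Exact per-pair r^2 with per-pair n, as in the snippet above.
    c1, c2, c3, c4, n = tile_counts(hi, vi, hj, vj)
    with np.errstate(divide="ignore", invalid="ignore"):
        p_i = (c1 + c2) / n
        p_j = (c1 + c3) / n
        D = c1 / n - p_i * p_j
        denom = p_i * (1 - p_i) * p_j * (1 - p_j)
        return np.where((denom > 0) & (n >= 2), D ** 2 / denom, 0.0)
```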
Also note -- you should rebase before you start editing.
Summary
Adds a Gram matrix (||K||_F^2) fast path for ZnS computation.
Benchmarks (NVIDIA A100, 200 haplotypes x 50k variants)
Also speeds up zx (calls zns 3x internally).
Test plan
- pixi run pytest tests/ -k 'zns or ld or omega or windowed' (68 passed)
- pixi run python debug/bench_zns.py for speed + correctness comparison