Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion locator/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def __init__(self, config=None): # noqa: C901
- **width** (*int*): Width of neural network layers.
- **nlayers** (*int*): Number of neural network layers.
- **dropout_prop** (*float*): Dropout proportion.
- **pca_components** (*int*): If set, prepend a PCA-initialized linear projection of this width as the first layer and fine-tune it. Recommended when n_SNPs >> n_samples. Default None (disabled).
- **pca_components** (*int or "auto"*): If set, prepend a PCA-initialized linear projection of this width as the first layer and fine-tune it. Use ``"auto"`` to pick the width from the genotype-PCA scree elbow. Recommended when n_SNPs >> n_samples. Default None (disabled).
- **pca_finetune** (*bool*): Whether to unfreeze the PCA projection for a low-learning-rate fine-tuning phase. Default True. False keeps the projection frozen at its PCA initialization.
- **pca_finetune_lr** (*float*): Learning rate for the PCA fine-tuning phase. Default 1e-4.
- **keras_verbose** (*int*): Verbosity level for Keras training.
Expand Down
37 changes: 37 additions & 0 deletions locator/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,40 @@ def compute_pca_projection_gram(genotype_matrix, n_components):
W = tf.linalg.matmul(Xc, evecs, transpose_a=True) / scale
bias = tf.reshape(-tf.linalg.matmul(mean, W), [-1])
return W, bias


def scree_elbow(genotype_matrix):
"""Pick a PCA rank from the genotype-PCA scree elbow.

Computes the explained-variance spectrum (Gram-matrix eigenvalues over the
samples) and returns the rank at the chord-distance elbow: the point of the
explained-variance curve farthest below the straight line joining its first
and last points. For genotype data the spectrum drops steeply then flattens,
so this is a stable, data-driven projection width.

Parameters
----------
genotype_matrix : tf.Tensor or np.ndarray
Genotype matrix of shape ``(n_samples, n_snps)``; training samples only.

Returns
-------
int
The elbow rank, at least 1.
"""
X = tf.cast(genotype_matrix, tf.float32)
mean = tf.reduce_mean(X, axis=0, keepdims=True)
Xc = X - mean
gram = tf.linalg.matmul(Xc, Xc, transpose_b=True)
evals = np.clip(tf.linalg.eigvalsh(gram).numpy()[::-1], 0.0, None)
total = evals.sum()
if total <= 0.0:
return 1
# Keep the meaningful components; a centred n-sample matrix has rank n-1.
evr = evals[evals > total * 1e-9] / total
if len(evr) < 3 or evr[0] == evr[-1]:
return max(1, len(evr))
x = np.arange(len(evr), dtype=np.float64) / (len(evr) - 1)
y = (evr - evr.min()) / (evr.max() - evr.min())
chord = y[0] + x * (y[-1] - y[0])
return int(np.argmax(chord - y)) + 1
35 changes: 33 additions & 2 deletions locator/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
loss_with_range_penalty,
rasterize_species_range,
)
from .pca import compute_pca_projection_gram
from .pca import compute_pca_projection_gram, scree_elbow
from .sample_weights import weight_samples


Expand Down Expand Up @@ -643,7 +643,7 @@ def loss_fn(y_true, y_pred): # noqa: F811

self._loss_fn = loss_fn

pca_components = self.config.get("pca_components")
pca_components = self._resolve_pca_components()
inner = create_network(
input_shape=input_shape,
width=self.config.get("width", 256),
Expand Down Expand Up @@ -704,6 +704,37 @@ def _get_genotype_table(self):
self._genotype_table_src = self.filtered_genotypes
return self._genotype_table

def _resolve_pca_components(self):
"""Resolve the pca_components config value to a concrete width.

Returns None (no projection) or an int. The string ``"auto"`` is
resolved to the genotype-PCA scree elbow of the training split and
written back to the config, so every fold and the saved metadata of a
run share one rank.
"""
pca_components = self.config.get("pca_components")
if pca_components is None or isinstance(pca_components, int):
return pca_components
if pca_components != "auto":
raise ValueError(
f"pca_components must be None, an int, or 'auto'; got {pca_components!r}"
)
index_set = getattr(self, "index_set", None)
if index_set is None or getattr(self, "filtered_genotypes", None) is None:
raise ValueError(
"pca_components='auto' needs training data; pass an explicit "
"integer when building an architecture to load weights into"
)
train_geno = tf.gather(
self._get_genotype_table(),
np.asarray(index_set.train, dtype=np.int32),
axis=0,
)
rank = scree_elbow(train_geno)
self.config["pca_components"] = rank
print(f"pca_components='auto': using scree-elbow rank {rank}")
return rank

def _inject_pca_weights(self, model, pca_components):
"""Initialize the pca_projection layer with PCA loadings, gate closed.

Expand Down
28 changes: 28 additions & 0 deletions tests/test_pca_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,17 @@ def test_train_builds_pca_projection(genotype_data, basic_config):
assert layer.units == 8


def test_pca_components_auto_resolves_to_elbow(genotype_data, basic_config):
"""pca_components='auto' picks a concrete projection width from the scree."""
genotypes, samples, _, _, _ = genotype_data
loc = Locator(_pca_config(basic_config, pca_components="auto", max_epochs=2))
loc.train(genotypes=genotypes, samples=samples)

resolved = loc.config["pca_components"]
assert isinstance(resolved, int) and resolved >= 1
assert loc.model.get_layer(PCA_LAYER_NAME).units == resolved


def test_frozen_projection_keeps_pca_loadings(genotype_data, basic_config):
"""With pca_finetune=False the projection stays at its PCA initialization."""
genotypes, samples, _, _, _ = genotype_data
Expand Down Expand Up @@ -151,3 +162,20 @@ def test_gradient_gate_controls_gradient_flow():
assert np.allclose(y.numpy(), x.numpy())
# Gradient reaching the input is scaled by the gate.
assert np.allclose(grad.numpy(), expected_grad)


def test_scree_elbow_finds_low_rank_structure():
"""scree_elbow returns a small rank for data with a few dominant axes.

Defined last: it runs TensorFlow ops (see the note above).
"""
from locator.pca import scree_elbow

rng = np.random.default_rng(0)
n, p, k = 60, 2000, 3
factors = rng.standard_normal((n, k))
loadings = rng.standard_normal((k, p))
X = factors @ loadings * 10.0 + rng.standard_normal((n, p))

rank = scree_elbow(X)
assert 2 <= rank <= 8
Loading