diff --git a/locator/core.py b/locator/core.py
index f376768..fd22951 100644
--- a/locator/core.py
+++ b/locator/core.py
@@ -165,7 +165,7 @@ def __init__(self, config=None):  # noqa: C901
                 - **width** (*int*): Width of neural network layers.
                 - **nlayers** (*int*): Number of neural network layers.
                 - **dropout_prop** (*float*): Dropout proportion.
-                - **pca_components** (*int*): If set, prepend a PCA-initialized linear projection of this width as the first layer and fine-tune it. Recommended when n_SNPs >> n_samples. Default None (disabled).
+                - **pca_components** (*int or "auto"*): If set, prepend a PCA-initialized linear projection of this width as the first layer and fine-tune it. Use ``"auto"`` to pick the width from the genotype-PCA scree elbow. Recommended when n_SNPs >> n_samples. Default None (disabled).
                 - **pca_finetune** (*bool*): Whether to unfreeze the PCA projection for a low-learning-rate fine-tuning phase. Default True. False keeps the projection frozen at its PCA initialization.
                 - **pca_finetune_lr** (*float*): Learning rate for the PCA fine-tuning phase. Default 1e-4.
                 - **keras_verbose** (*int*): Verbosity level for Keras training.
diff --git a/locator/pca.py b/locator/pca.py
index 42550de..5412baa 100644
--- a/locator/pca.py
+++ b/locator/pca.py
@@ -80,3 +80,48 @@ def compute_pca_projection_gram(genotype_matrix, n_components):
     W = tf.linalg.matmul(Xc, evecs, transpose_a=True) / scale
     bias = tf.reshape(-tf.linalg.matmul(mean, W), [-1])
     return W, bias
+
+
+def scree_elbow(genotype_matrix):
+    """Pick a PCA rank from the genotype-PCA scree elbow.
+
+    Computes the explained-variance spectrum (Gram-matrix eigenvalues over the
+    samples) and returns the rank at the chord-distance elbow: the point of the
+    explained-variance curve farthest below the straight line joining its first
+    and last points. For genotype data the spectrum drops steeply then flattens,
+    so this is a stable, data-driven projection width.
+
+    Parameters
+    ----------
+    genotype_matrix : tf.Tensor or np.ndarray
+        Genotype matrix of shape ``(n_samples, n_snps)``; training samples only.
+
+    Returns
+    -------
+    int
+        The elbow rank, at least 1.
+    """
+    X = tf.cast(genotype_matrix, tf.float32)
+    mean = tf.reduce_mean(X, axis=0, keepdims=True)
+    Xc = X - mean
+    gram = tf.linalg.matmul(Xc, Xc, transpose_b=True)
+    evals = np.clip(tf.linalg.eigvalsh(gram).numpy()[::-1], 0.0, None)
+    # Centring caps the rank at min(n_samples - 1, n_snps). Enforce the cap
+    # explicitly: the Gram matrix is float32, so round-off on a null eigenvalue
+    # can exceed the relative cut-off below and be counted as a component.
+    max_rank = min(len(evals) - 1, int(X.shape[1]))
+    evals = evals[: max(1, max_rank)]
+    total = evals.sum()
+    if total <= 0.0:
+        return 1
+    # Drop whatever is still numerically negligible.
+    evr = evals[evals > total * 1e-9] / total
+    # Fewer than three points, or a flat spectrum, has no elbow: keep them all.
+    if len(evr) < 3 or evr[0] == evr[-1]:
+        return max(1, len(evr))
+    # Rescale both axes to [0, 1] so the chord distance is scale-free.
+    x = np.arange(len(evr), dtype=np.float64) / (len(evr) - 1)
+    y = (evr - evr.min()) / (evr.max() - evr.min())
+    chord = y[0] + x * (y[-1] - y[0])
+    # argmax is a 0-based index into the spectrum; ranks are 1-based.
+    return int(np.argmax(chord - y)) + 1
diff --git a/locator/training.py b/locator/training.py
index 22e7f78..a57950e 100644
--- a/locator/training.py
+++ b/locator/training.py
@@ -31,7 +31,7 @@
     loss_with_range_penalty,
     rasterize_species_range,
 )
-from .pca import compute_pca_projection_gram
+from .pca import compute_pca_projection_gram, scree_elbow
 from .sample_weights import weight_samples
 
 
@@ -643,7 +643,7 @@ def loss_fn(y_true, y_pred):  # noqa: F811
 
         self._loss_fn = loss_fn
 
-        pca_components = self.config.get("pca_components")
+        pca_components = self._resolve_pca_components()
         inner = create_network(
             input_shape=input_shape,
             width=self.config.get("width", 256),
@@ -704,6 +704,44 @@ def _get_genotype_table(self):
             self._genotype_table_src = self.filtered_genotypes
         return self._genotype_table
 
+    def _resolve_pca_components(self):
+        """Resolve the pca_components config value to a concrete width.
+
+        Returns None (no projection) or an int. The string ``"auto"`` is
+        resolved to the genotype-PCA scree elbow of the training split and
+        written back to the config, so every fold and the saved metadata of a
+        run share one rank.
+
+        Raises ValueError for any other value, or when ``"auto"`` is requested
+        before an index set and filtered genotypes are available.
+        """
+        pca_components = self.config.get("pca_components")
+        if pca_components is None:
+            return None
+        # NumPy integer scalars (e.g. a width drawn from an array) are valid too;
+        # normalise so callers always receive a plain int.
+        if isinstance(pca_components, (int, np.integer)):
+            return int(pca_components)
+        if pca_components != "auto":
+            raise ValueError(
+                f"pca_components must be None, an int, or 'auto'; got {pca_components!r}"
+            )
+        index_set = getattr(self, "index_set", None)
+        if index_set is None or getattr(self, "filtered_genotypes", None) is None:
+            raise ValueError(
+                "pca_components='auto' needs training data; pass an explicit "
+                "integer when building an architecture to load weights into"
+            )
+        train_geno = tf.gather(
+            self._get_genotype_table(),
+            np.asarray(index_set.train, dtype=np.int32),
+            axis=0,
+        )
+        rank = scree_elbow(train_geno)
+        self.config["pca_components"] = rank
+        print(f"pca_components='auto': using scree-elbow rank {rank}")
+        return rank
+
     def _inject_pca_weights(self, model, pca_components):
         """Initialize the pca_projection layer with PCA loadings, gate closed.
 
diff --git a/tests/test_pca_init.py b/tests/test_pca_init.py
index 35259ad..7b3eb29 100644
--- a/tests/test_pca_init.py
+++ b/tests/test_pca_init.py
@@ -50,6 +50,17 @@ def test_train_builds_pca_projection(genotype_data, basic_config):
     assert layer.units == 8
 
 
+def test_pca_components_auto_resolves_to_elbow(genotype_data, basic_config):
+    """pca_components='auto' picks a concrete projection width from the scree."""
+    genotypes, samples, _, _, _ = genotype_data
+    loc = Locator(_pca_config(basic_config, pca_components="auto", max_epochs=2))
+    loc.train(genotypes=genotypes, samples=samples)
+
+    resolved = loc.config["pca_components"]
+    assert isinstance(resolved, int) and resolved >= 1
+    assert loc.model.get_layer(PCA_LAYER_NAME).units == resolved
+
+
 def test_frozen_projection_keeps_pca_loadings(genotype_data, basic_config):
     """With pca_finetune=False the projection stays at its PCA initialization."""
     genotypes, samples, _, _, _ = genotype_data
@@ -151,3 +162,20 @@ def test_gradient_gate_controls_gradient_flow():
     assert np.allclose(y.numpy(), x.numpy())
     # Gradient reaching the input is scaled by the gate.
     assert np.allclose(grad.numpy(), expected_grad)
+
+
+def test_scree_elbow_finds_low_rank_structure():
+    """scree_elbow returns a small rank for data with a few dominant axes.
+
+    Defined last: it runs TensorFlow ops (see the note above).
+    """
+    from locator.pca import scree_elbow
+
+    rng = np.random.default_rng(0)
+    n, p, k = 60, 2000, 3
+    factors = rng.standard_normal((n, k))
+    loadings = rng.standard_normal((k, p))
+    X = factors @ loadings * 10.0 + rng.standard_normal((n, p))
+
+    rank = scree_elbow(X)
+    assert 2 <= rank <= 8