
Commit a6cba80

jgallowa07 and claude authored
Implement add_phenotypes_to_df for predictions on new variants (#181)
* Implement add_phenotypes_to_df for predictions on new variants (#173)

  This PR implements Model.add_phenotypes_to_df() to enable predictions on new
  variant data not seen during training, addressing issue #173.

  **Core Implementation:**
  - Implemented add_phenotypes_to_df() method in Model class
  - Converts input DataFrames to jaxmodels.Data format for predictions
  - Handles substitution conversion to reference frame
  - Validates mutations and raises informative errors for unseen mutations
  - Preserves all input DataFrame columns in output

  **Verbosity Control (Bonus Feature):**
  - Added verbose parameter to jaxmodels.fit() and Model.fit()
  - Enables silent fitting for doctests and automated workflows
  - Wrapped all progress print statements with verbose checks

  **Testing:**
  - Added 13 comprehensive unit tests covering all functionality
  - Includes explicit parameter validation test
  - All 36 model tests pass (13 new)
  - Includes working doctest example

  **Code Quality:**
  - Ruff linting: ✓ All checks passed
  - Black formatting: ✓ All checks passed
  - Full test coverage with edge cases

  Files changed:
  - multidms/model.py: Core implementation (+167 lines)
  - multidms/jaxmodels.py: Verbose parameter (+77 lines)
  - tests/test_model.py: Comprehensive tests (+272 lines)

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude <[email protected]>

* Fix Black action failing on Python 3.9

  Pin psf/black action to v24.10.0 instead of @stable to fix TypeError with
  union type syntax (str | None) on Python 3.9.

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude Opus 4.5 <[email protected]>

---------

Co-authored-by: Claude <[email protected]>
1 parent 912f406 commit a6cba80
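
For orientation, here is a minimal usage sketch of the API this commit adds, adapted from the doctest introduced in multidms/model.py (shown in the diff below). The toy conditions, substitutions, and scores are illustrative only; the call signatures are taken directly from this commit.

import pandas as pd
from multidms import Data, Model

# Train on a small toy dataset (values are illustrative).
df_train = pd.DataFrame({
    "condition": ["a", "a", "b", "b"],
    "aa_substitutions": ["", "M1A", "", "M1A"],
    "func_score": [0.0, 1.2, 0.1, 1.5],
})
data = Data(df_train, reference="a")
model = Model(data, ge_type="Identity", l2reg=0.01)
_ = model.fit(maxiter=5, warmstart=False, verbose=False)  # new verbose flag silences progress output

# Predict functional scores for variants not present as rows during training.
df_new = pd.DataFrame({
    "condition": ["a", "b"],
    "aa_substitutions": ["M1A", "M1A"],
})
result = model.add_phenotypes_to_df(df_new)  # adds a 'predicted_func_score' column
print(result[["condition", "aa_substitutions", "predicted_func_score"]])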

File tree

4 files changed: +477 −41 lines changed


.github/workflows/build_test_package.yml

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ jobs:
           pip install -e ".[dev]"

       - name: Black Format Check
-        uses: psf/black@stable
+        uses: psf/black@24.10.0
         with:
           options: "--check"
           src: "."

multidms/jaxmodels.py

Lines changed: 45 additions & 32 deletions
@@ -371,6 +371,7 @@ def fit(
     beta_init: dict[str, Float[Array, " n_mutations"]] | None = None,
     alpha_init: dict[str, Float] | None = None,
     beta_clip_range: tuple[Float, Float] | None = None,
+    verbose: bool = True,
 ) -> tuple[Model, list[float]]:
     r"""
     Fit a model to data.
@@ -406,6 +407,8 @@ def fit(
             If None, no clipping is applied. Example: (-10.0, 10.0).
             This constrains mutation effect parameters during optimization
             to prevent extreme values.
+        verbose: Whether to print progress information during fitting (default: True).
+            If False, suppresses all print output.

     Returns:
         Tuple of (fitted model, loss trajectory).
@@ -571,7 +574,8 @@ def prox_block(β_block, hyperparameters, scaling=1.0):

     try:
         for k in range(block_iters):
-            print(f"iter {k + 1}:")
+            if verbose:
+                print(f"iter {k + 1}:")
             obj_old = objective_total(
                 model,
                 data_sets,
@@ -589,13 +593,16 @@ def prox_block(β_block, hyperparameters, scaling=1.0):
                 model_calibration, model_rest, data_sets, scale=scale
             )
             model = eqx.combine(model_calibration, model_rest)
-            print(
-                f" calibration block: error={state_calibration.error:.2e}, "
-                f"stepsize={state_calibration.stepsize:.1e}, "
-                f"iter={state_calibration.iter_num}"
-            )
-            for d in model.φ:
-                print(f" {d}: α={model.α[d]:.2f}, θ={jnp.exp(model.logθ[d]):.2f}")
+            if verbose:
+                print(
+                    f" calibration block: error={state_calibration.error:.2e}, "
+                    f"stepsize={state_calibration.stepsize:.1e}, "
+                    f"iter={state_calibration.iter_num}"
+                )
+                for d in model.φ:
+                    print(
+                        f" {d}: α={model.α[d]:.2f}, θ={jnp.exp(model.logθ[d]):.2f}"
+                    )

             # β0 block
             model_β0, model_rest = eqx.partition(model, filter_spec=filter_spec_β0)
@@ -607,12 +614,13 @@ def prox_block(β_block, hyperparameters, scaling=1.0):
                 beta0_ridge=beta0_ridge,
             )
             model = eqx.combine(model_β0, model_rest)
-            print(
-                f" β0 block: error={state_β0.error:.2e}, "
-                f"stepsize={state_β0.stepsize:.1e}, iter={state_β0.iter_num}"
-            )
-            for d in model.φ:
-                print(f" {d}: β0={model.φ[d].β0:.2f}")
+            if verbose:
+                print(
+                    f" β0 block: error={state_β0.error:.2e}, "
+                    f"stepsize={state_β0.stepsize:.1e}, iter={state_β0.iter_num}"
+                )
+                for d in model.φ:
+                    print(f" {d}: β0={model.φ[d].β0:.2f}")

             # determine bundle idxs (mutations that are non-wt in any condition)
             bundle_idxs = jax.lax.associative_scan(
@@ -644,11 +652,12 @@ def prox_block(β_block, hyperparameters, scaling=1.0):
                     model,
                     model.φ[d].β.at[idxs].set(β_block[d]),
                 )
-            print(
-                f" β_nonbundle: error={state_nonbundle.error:.2e}, "
-                f"stepsize={state_nonbundle.stepsize:.1e}, "
-                f"iter={state_nonbundle.iter_num}"
-            )
+            if verbose:
+                print(
+                    f" β_nonbundle: error={state_nonbundle.error:.2e}, "
+                    f"stepsize={state_nonbundle.stepsize:.1e}, "
+                    f"iter={state_nonbundle.iter_num}"
+                )

             # β bundle block
             idxs = jnp.where(bundle_idxs)[0]
@@ -674,19 +683,21 @@ def prox_block(β_block, hyperparameters, scaling=1.0):
                     model,
                     model.φ[d].β.at[idxs].set(β_block[d]),
                 )
-            print(
-                f" β_bundle: error={state_bundle.error:.2e}, "
-                f"stepsize={state_bundle.stepsize:.1e}, "
-                f"iter={state_bundle.iter_num}"
-            )
+            if verbose:
+                print(
+                    f" β_bundle: error={state_bundle.error:.2e}, "
+                    f"stepsize={state_bundle.stepsize:.1e}, "
+                    f"iter={state_bundle.iter_num}"
+                )

             # diagnostics
-            for d in model.φ:
-                if d != model.reference_condition:
-                    sparsity = (
-                        model.φ[d].β - model.φ[model.reference_condition].β == 0
-                    ).mean()
-                    print(f" {d} sparsity={sparsity:.1%}")
+            if verbose:
+                for d in model.φ:
+                    if d != model.reference_condition:
+                        sparsity = (
+                            model.φ[d].β - model.φ[model.reference_condition].β == 0
+                        ).mean()
+                        print(f" {d} sparsity={sparsity:.1%}")

             obj = objective_total(
                 model,
@@ -696,9 +707,11 @@ def prox_block(β_block, hyperparameters, scaling=1.0):
                 scale=scale,
                 beta0_ridge=beta0_ridge,
             )
-            print(f" {obj=:.2e}")
+            if verbose:
+                print(f" {obj=:.2e}")
             objective_error = abs(obj_old - obj) / max(abs(obj_old), abs(obj), 1)
-            print(f" {objective_error=:.2e}")
+            if verbose:
+                print(f" {objective_error=:.2e}")

             # store loss for trajectory
             loss_trajectory.append(float(obj))

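The hunks above gate every per-iteration progress print in the block-coordinate fit loop behind the new verbose flag. For comparison, a sketch of the stdout-redirection workaround this flag makes unnecessary; the fit arguments mirror the doctest added in multidms/model.py, and everything else here is illustrative rather than part of the commit.

import contextlib
import io

# Before this commit: the only way to silence fitting output was to swallow stdout.
with contextlib.redirect_stdout(io.StringIO()):
    _ = model.fit(maxiter=5, warmstart=False)

# With this commit: pass verbose=False and the per-iteration block diagnostics
# (calibration, β0, β_nonbundle, β_bundle, sparsity, objective error) are skipped.
_ = model.fit(maxiter=5, warmstart=False, verbose=False)
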
multidms/model.py

Lines changed: 159 additions & 8 deletions
@@ -137,6 +137,7 @@ def fit(
         ge_kwargs: dict = None,
         cal_kwargs: dict = None,
         loss_kwargs: dict = None,
+        verbose: bool = True,
     ):
         """
         Fit the model to data.
@@ -164,6 +165,8 @@ def fit(
             Keyword arguments for calibration (α) optimizer (e.g., tol, maxiter, maxls).
         loss_kwargs : dict, optional
             Keyword arguments for the loss function (e.g., δ for Huber loss).
+        verbose : bool
+            Whether to print progress information during fitting (default: True).

         Returns
         -------
@@ -219,6 +222,7 @@ def fit(
             ge_kwargs=ge_kwargs,
             cal_kwargs=cal_kwargs,
             loss_kwargs=loss_kwargs,
+            verbose=verbose,
         )

         return self
@@ -314,29 +318,176 @@ def get_variants_df(self, phenotype_as_effect: bool = True) -> pd.DataFrame:
     def add_phenotypes_to_df(
         self,
         df: pd.DataFrame,
-        phenotype_as_effect: bool = True,
+        substitutions_col: str = "aa_substitutions",
+        condition_col: str = "condition",
+        predicted_phenotype_col: str = "predicted_func_score",
+        overwrite_cols: bool = False,
     ) -> pd.DataFrame:
         """
         Add model predictions to a DataFrame of variants.

         Parameters
         ----------
         df : pd.DataFrame
-            DataFrame with 'condition' and 'aa_substitutions' columns.
-        phenotype_as_effect : bool
-            If True, report effects. If False, report raw latent phenotypes.
+            DataFrame with columns specified by `condition_col` and
+            `substitutions_col`. Additional columns will be preserved in output.
+        substitutions_col : str
+            Column in `df` giving variants as substitution strings.
+            Default is 'aa_substitutions'.
+        condition_col : str
+            Column in `df` giving the condition for each variant.
+            Values must exist in the model's conditions. Default is 'condition'.
+        predicted_phenotype_col : str
+            Name of column to add containing predicted functional scores.
+            Default is 'predicted_func_score'.
+        overwrite_cols : bool
+            If the specified predicted phenotype column already exists in `df`,
+            overwrite it? If False, raise an error.

         Returns
         -------
         pd.DataFrame
-            Input DataFrame with added prediction columns.
+            A copy of `df` with predictions added.
+
+        Raises
+        ------
+        ValueError
+            If model is not fitted, required columns are missing, indices are
+            not unique, conditions are invalid, or substitutions contain
+            mutations not seen during training.
+
+        Example
+        -------
+        >>> import pandas as pd
+        >>> from multidms import Data, Model
+        >>> df_train = pd.DataFrame({
+        ...     'condition': ['a', 'a', 'b', 'b'],
+        ...     'aa_substitutions': ['', 'M1A', '', 'M1A'],
+        ...     'func_score': [0.0, 1.2, 0.1, 1.5]
+        ... })
+        >>> data = Data(df_train, reference='a')  # doctest: +ELLIPSIS
+        >>> model = Model(data, ge_type='Identity', l2reg=0.01)
+        >>> _ = model.fit(maxiter=5, warmstart=False, verbose=False)
+        >>> df_new = pd.DataFrame({
+        ...     'condition': ['a', 'b'],
+        ...     'aa_substitutions': ['M1A', 'M1A']
+        ... })
+        >>> result = model.add_phenotypes_to_df(df_new)
+        >>> 'predicted_func_score' in result.columns
+        True
+        >>> len(result)
+        2
         """
         if self._jax_model is None:
             raise ValueError("Model has not been fitted. Call fit() first.")

-        # See issue #173 for implementing prediction on new variants
-        # and calling predict_score()
-        raise NotImplementedError("add_phenotypes_to_df is not yet implemented in v2.0")
+        # Validate input
+        if substitutions_col not in df.columns:
+            raise ValueError(f"`df` lacks column '{substitutions_col}'")
+        if condition_col not in df.columns:
+            raise ValueError(f"`df` lacks column '{condition_col}'")
+        if not df.index.is_unique:
+            raise ValueError("`df` must have unique indices")
+
+        # Check for invalid conditions
+        invalid_conditions = set(df[condition_col]) - set(self._data.conditions)
+        if invalid_conditions:
+            raise ValueError(
+                f"Invalid conditions in df: {invalid_conditions}. "
+                f"Valid conditions: {self._data.conditions}"
+            )
+
+        # Return copy
+        ret = df.copy()
+
+        # Check if column exists and handle overwrite
+        if predicted_phenotype_col in ret.columns and not overwrite_cols:
+            raise ValueError(
+                f"`df` already contains column '{predicted_phenotype_col}'. "
+                "Set overwrite_cols=True to overwrite."
+            )
+
+        # Initialize prediction column
+        ret[predicted_phenotype_col] = np.nan
+
+        # Get reference binarymap for encoding
+        ref_bmap = self._data.binarymaps[self._data.reference]
+
+        # Process each condition separately
+        for condition, condition_df in df.groupby(condition_col):
+            # Convert substitutions to reference frame if needed
+            variant_subs = condition_df[substitutions_col]
+            if condition not in self._data.reference_sequence_conditions:
+                variant_subs = condition_df.apply(
+                    lambda x: self._data.convert_subs_wrt_ref_seq(
+                        condition, x[substitutions_col]
+                    ),
+                    axis=1,
+                )
+
+            # Build binary variant matrix
+            row_ind = []  # row indices of elements that are one
+            col_ind = []  # column indices of elements that are one
+            unseen_mutations = set()
+
+            for ivariant, subs in enumerate(variant_subs):
+                try:
+                    for isub in ref_bmap.sub_str_to_indices(subs):
+                        row_ind.append(ivariant)
+                        col_ind.append(isub)
+                except ValueError:
+                    # Extract the individual mutations that are unseen
+                    if subs:  # non-empty string
+                        for mut in subs.split():
+                            if mut not in self._data.mutations:
+                                unseen_mutations.add(mut)
+
+            # If there are unseen mutations, raise an error
+            if unseen_mutations:
+                raise ValueError(
+                    f"Variants contain mutations not seen during training: "
+                    f"{sorted(unseen_mutations)}"
+                )
+
+            # Create sparse matrix
+            import scipy.sparse
+            from jax.experimental import sparse as jsparse
+
+            X = jsparse.BCOO.from_scipy_sparse(
+                scipy.sparse.csr_matrix(
+                    (np.ones(len(row_ind), dtype="int8"), (row_ind, col_ind)),
+                    shape=(len(condition_df), ref_bmap.binarylength),
+                    dtype="int8",
+                )
+            )
+
+            # Create jaxmodels.Data object for this condition
+            # We need x_wt from the training data
+            x_wt = self._jax_data_sets[condition].x_wt
+
+            # Create a temporary Data object with dummy functional scores
+            import multidms.jaxmodels as jaxmodels
+
+            temp_data = jaxmodels.Data(
+                x_wt=x_wt,
+                X=X,
+                functional_scores=np.zeros(len(condition_df)),  # dummy values
+            )
+
+            # Make predictions using jaxmodels
+            temp_data_sets = {condition: temp_data}
+            predictions = self._jax_model.predict_score(temp_data_sets)
+
+            # Extract predictions for this condition
+            phenotype_predictions = np.array(predictions[condition])
+            assert len(phenotype_predictions) == len(condition_df)
+
+            # Add predictions to result dataframe
+            ret.loc[
+                condition_df.index.values, predicted_phenotype_col
+            ] = phenotype_predictions
+
+        return ret

     def __repr__(self):
         """String representation."""

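Since add_phenotypes_to_df raises ValueError when a variant contains mutations absent from the training data, a caller may want to trap that case explicitly. A hedged sketch follows; the mutation 'M1W' and the surrounding variable names are hypothetical, while the keyword names come from the signature added above.

import pandas as pd

df_new = pd.DataFrame({
    "condition": ["a"],
    "aa_substitutions": ["M1W"],  # hypothetical mutation not present in training data
})
try:
    result = model.add_phenotypes_to_df(
        df_new,
        substitutions_col="aa_substitutions",
        condition_col="condition",
        predicted_phenotype_col="predicted_func_score",
        overwrite_cols=False,
    )
except ValueError as err:
    # e.g. "Variants contain mutations not seen during training: ['M1W']"
    print(f"Skipping prediction: {err}")
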