Issue: Refactor Data Class to Match v2.0 Contract Specification
Summary
Refactor multidms/data.py to align with the v2.0 Data class contract specification. This involves removing deprecated parameters, simplifying the API, removing scaled arrays functionality, and ensuring wildtype variants are first in each condition's data.
Motivation
The v2.0 refactor shifts responsibility for data aggregation from the Data class to users, removes the scaled arrays functionality (no longer needed for modeling), and ensures data is properly structured for the jaxmodels backend.
Required Changes
1. Remove collapse_identical_variants Parameter
Files affected:
multidms/data.py: Lines 81-86 (docstring), 194, 233, 264, 271-281
Changes:
- Remove the parameter from the `__init__` signature (line 194)
- Remove the `self._collapse_identical_variants` assignment (line 233)
- Remove the weight overwrite comment referencing it (line 264)
- Remove the entire aggregation block (lines 271-281):

```python
if self._collapse_identical_variants:
    agg_dict = {
        "weight": "sum",
        "func_score": self._collapse_identical_variants,
    }
    df = (
        variants_df[cols]
        .assign(weight=1)
        .groupby(["condition", "aa_substitutions"], as_index=False)
        .aggregate(agg_dict)
    )
```
- Update docstring to remove parameter documentation
Migration for users:
```python
# v1.x
data = Data(df, collapse_identical_variants='mean')  # REMOVED

# v2.0
df_collapsed = df.groupby(['condition', 'aa_substitutions']).agg({'func_score': 'mean'}).reset_index()
data = Data(df_collapsed, reference='A')
```
2. Ensure Wildtype is First Variant in Each Condition
Context: The jaxmodels.Data.from_multidms() method (line 64 in jaxmodels.py) assumes the wildtype is the first variant in each condition's data. It uses index [0] to get x_wt and [1:] to get all other variants. This is critical for the modeling to work correctly.
Critical Note: For non-reference conditions with non-identical sites, the wildtype (empty aa_substitutions) gets converted to bundle mutations in var_wrt_ref (e.g., "G3P"). Therefore, we must check aa_substitutions, NOT var_wrt_ref, when validating/sorting for wildtype.
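To make the distinction concrete, here is a small hedged sketch (toy rows, not project data; the "G3P" bundle notation follows the example in the note above) showing why only `aa_substitutions` reliably identifies the wildtype:

```python
import pandas as pd

# Hypothetical rows for a non-reference condition whose wildtype differs
# from the reference at site 3, so its var_wrt_ref is non-empty after conversion.
df = pd.DataFrame({
    "condition": ["b", "b"],
    "aa_substitutions": ["", "M1A"],    # original input; '' marks the wildtype
    "var_wrt_ref": ["G3P", "M1A G3P"],  # after conversion, WT is no longer empty
})

# A var_wrt_ref check would miss the wildtype entirely; aa_substitutions works:
wt_mask = df["aa_substitutions"].str.strip() == ""
print(wt_mask.tolist())  # [True, False]
```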
Changes needed:
- Validate that a wildtype exists for each condition BEFORE the `var_wrt_ref` conversion (around line 433)
- Sort the wildtype to be first BEFORE the conversion loop
- Raise an error if no explicit wildtype exists for a condition
Implementation:
Add validation and sorting BEFORE line 433 (before the var_wrt_ref conversion loop):
```python
# Validate wildtype exists and sort to first position for each condition
def validate_and_sort_wt_first(group):
    """Ensure wildtype exists and sort it first."""
    condition = group['condition'].iloc[0]
    wt_mask = group['aa_substitutions'].str.strip() == ''
    if not wt_mask.any():
        raise ValueError(
            f"No wildtype variant found for condition '{condition}'. "
            f"Please include a row with empty 'aa_substitutions' for this condition."
        )
    return pd.concat([group[wt_mask], group[~wt_mask]])

df = df.groupby('condition', group_keys=False).apply(validate_and_sort_wt_first).reset_index(drop=True)

# Now proceed with var_wrt_ref conversion...
df = df.assign(var_wrt_ref=df["aa_substitutions"])
```
This approach:
- Validates wildtype existence and sorts in a single pass (DRY)
- Uses `aa_substitutions` (original input) rather than `var_wrt_ref` (computed)
- Runs BEFORE the conversion loop, so the wildtype is properly first in all subsequent processing
3. Remove training_data Property (Keep Only arrays)
Current state: Two properties return the same data:
- `Data.arrays` (lines 636-638)
- `Data.training_data` (lines 641-643), an alias
Changes:
- Remove the `training_data` property (lines 641-643)
- Keep only `arrays` as the canonical property name
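A minimal sketch of the resulting shape (hypothetical class body; `_arrays` as the internal attribute is an assumption, not the actual multidms source):

```python
class Data:
    def __init__(self, arrays):
        self._arrays = arrays  # assumed internal storage name

    @property
    def arrays(self):
        """Canonical accessor for the training arrays."""
        return self._arrays

    # The v1.x alias is removed entirely in v2.0:
    # @property
    # def training_data(self):
    #     return self.arrays

data = Data({"X": {}, "y": {}, "w": {}})
print(hasattr(data, "training_data"))  # False
```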
4. Remove Scaled Arrays Functionality
Context: The scaled_arrays functionality was used for the old v1 modeling approach. The v2 jaxmodels backend handles bundle mutations differently and does not need this.
Files affected:
- `multidms/data.py`
- `multidms/utils.py` (remove the `rereference` function entirely)
Changes in data.py:
- Remove the `self._scaled_arrays` initialization (line 461):

```python
self._scaled_arrays = {"X": {}, "y": y, "w": w}  # REMOVE
```

- Remove the scaled array computation (lines 497-499):

```python
self._scaled_arrays["X"][condition] = rereference(
    X[condition], self._bundle_idxs[condition]
)  # REMOVE
```

- Remove the `scaled_arrays` property (lines 646-648)

- Remove the `scaled_training_data` property (lines 651-653)

- Update the `times_seen` computation (line 511) to compute from raw mutation occurrences (unscaled arrays); the `times_seen` column represents how many variants in the dataset carry each mutation. Change line 511 from:

```python
times_seen = pd.Series(self._scaled_arrays["X"][condition].sum(0).todense())
```

to:

```python
times_seen = pd.Series(X[condition].sum(0).todense())
```

- Remove the `rereference` function from `multidms/utils.py`; it is no longer needed after removing scaled arrays.

- Update the imports in `data.py` (line 21):

```python
# Current
from multidms.utils import rereference, split_subs
# After refactor
from multidms.utils import split_subs
```
5. Simplify Sparse Array Storage
Context: Currently, Data converts binarymap sparse arrays to jax.experimental.sparse.BCOO format (line 471-472). However, jaxmodels.Data.from_multidms() immediately converts these back to scipy.sparse.csr_array (line 68) because BCOO slicing has issues.
Research findings:
- `data.py` lines 471-472: converts to BCOO
- `jaxmodels.py` lines 66-72: converts BCOO back to scipy, slices, then converts back to BCOO
Decision: Store as scipy sparse in Data class, let jaxmodels handle the BCOO conversion as needed.
Changes:
- In `data.py`, change lines 471-473 from:

```python
X[condition] = sparse.BCOO.from_scipy_sparse(
    cond_bmap.binary_variants.tocoo()
)
```

to:

```python
X[condition] = cond_bmap.binary_variants.tocsr()  # Keep as scipy sparse
```

- Update the assertion on line 474 accordingly (or remove it if it does not apply to scipy sparse)

- Update `jaxmodels.Data.from_multidms()` to handle scipy sparse directly:

```python
# Simplified - X is already scipy sparse
X = multidms_data.arrays["X"][condition]
X = X[1:]  # exclude WT (scipy sparse supports this)
X = BCOO.from_scipy_sparse(X)
```

- The `single_mut_encodings` cached property (lines 701-718) uses `sparse.BCOO.fromdense`; keep it as-is, since BCOO is appropriate for this cached property.
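A small sketch of why CSR storage makes the wildtype/rest split trivial (toy matrix, assuming scipy >= 1.8 for `csr_array`; the row-0-is-wildtype layout comes from change 2 above):

```python
import numpy as np
from scipy import sparse

# Toy variant-by-mutation matrix with the wildtype as row 0.
X = sparse.csr_array(np.array([
    [0, 0],
    [1, 0],
    [1, 1],
]))

x_wt = X[[0]].toarray()  # wildtype encoding (all zeros)
X_rest = X[1:]           # CSR row slicing works directly, unlike BCOO
print(X_rest.shape)  # (2, 2)
```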
6. Update Validation and Error Messages
Contract specifies the error message format:

```
ValueError: DataFrame missing required columns: {missing}.
Expected: condition, aa_substitutions, func_score.
Found: {present}.
```
Changes:
Add an explicit check for the required columns at the VERY START of `__init__` (immediately after the docstring, line 202), before any DataFrame operations:

```python
def __init__(self, variants_df, reference, ...):
    """See main class docstring."""
    # Validate required columns FIRST (before any DataFrame access)
    required_cols = ["condition", "aa_substitutions", "func_score"]
    missing_cols = [col for col in required_cols if col not in variants_df.columns]
    if missing_cols:
        raise ValueError(
            f"DataFrame missing required columns: {missing_cols}. "
            f"Expected: {required_cols}. "
            f"Found: {list(variants_df.columns)}."
        )
    # Then proceed with existing validation...
    if pd.isnull(variants_df["condition"]).any():
        ...
```
This must be placed BEFORE line 204 (`if pd.isnull(variants_df["condition"]).any()`), which already accesses DataFrame columns.
7. Keep weights Behavior As-Is
The current behavior only assigns weights if a "weight" column exists in variants_df (lines 261-264 and 483-484). This is the desired behavior; no changes are needed here.
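For clarity, the keep-as-is behavior amounts to the following pattern (illustrative helper `maybe_weights` is hypothetical, not the actual multidms implementation):

```python
import pandas as pd

def maybe_weights(variants_df: pd.DataFrame):
    """Return per-variant weights only if the caller supplied a 'weight' column."""
    if "weight" in variants_df.columns:
        return variants_df["weight"].to_numpy()
    return None  # no weights assigned otherwise

df = pd.DataFrame({"aa_substitutions": ["", "M1A"], "func_score": [0.0, -1.0]})
print(maybe_weights(df))  # None
print(maybe_weights(df.assign(weight=[1, 2])))  # [1 2]
```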
Test Updates
File: tests/test_data.py
Required changes:
- Remove/update `test_non_identical_conversion()` (lines 170-186): it uses `collapse_identical_variants="mean"`. Update it to pre-aggregate data:

```python
def test_non_identical_conversion():
    # Pre-aggregate the data
    df_collapsed = TEST_FUNC_SCORES.groupby(
        ['condition', 'aa_substitutions']
    ).agg({'func_score': 'mean'}).reset_index()
    data = multidms.Data(
        df_collapsed,
        alphabet=multidms.AAS_WITHSTOP,
        reference="a",
        assert_site_integrity=True,
    )
    # ... rest of test assertions
```

- Add a test for the missing-wildtype error:

```python
def test_missing_wildtype_error():
    """Test that error is raised when wildtype missing for a condition."""
    # Remove all wildtype rows (empty aa_substitutions)
    df = TEST_FUNC_SCORES.query("aa_substitutions != ''")
    with pytest.raises(ValueError, match="No wildtype variant found"):
        multidms.Data(df, reference='a', alphabet=multidms.AAS_WITHSTOP)
```

- Add a test for the wildtype-first invariant:

```python
def test_wildtype_first_in_conditions():
    """Test that wildtype is the first variant for each condition."""
    # Create data where wildtype is NOT first in input
    df = TEST_FUNC_SCORES.copy()
    df = df.iloc[::-1].reset_index(drop=True)  # Reverse order
    data = multidms.Data(df, reference='a', alphabet=multidms.AAS_WITHSTOP)
    # Verify wildtype (empty aa_substitutions) is first for each condition
    for condition in data.conditions:
        cond_df = data.variants_df[data.variants_df['condition'] == condition]
        first_aa_subs = cond_df.iloc[0]['aa_substitutions']
        assert first_aa_subs.strip() == '', (
            f"Wildtype should be first for condition {condition}, "
            f"but got '{first_aa_subs}'"
        )
```

- Add a test for the error message format:

```python
def test_missing_columns_error_message():
    """Test that missing columns produce properly formatted error."""
    df = pd.DataFrame({'condition': ['a'], 'func_score': [1.0]})  # missing aa_substitutions
    with pytest.raises(ValueError, match="DataFrame missing required columns"):
        multidms.Data(df, reference='a')
```

- Add a backward-compatibility test:

```python
def test_backward_compatibility():
    """Ensure non-breaking changes remain compatible with v1.x usage."""
    data = multidms.Data(
        TEST_FUNC_SCORES,
        alphabet=multidms.AAS_WITHSTOP,
        reference="a",
    )
    # Verify all v1.x properties still work
    assert data.reference == "a"
    assert len(data.conditions) == 2
    assert len(data.mutations) > 0
    assert 'condition' in data.variants_df.columns
    assert 'mutation' in data.mutations_df.columns
    assert data.arrays is not None
```

- Remove any tests that use `scaled_arrays` or `scaled_training_data` (currently in commented-out v1 tests)
Notebook Updates
Files affected:
- `notebooks/model_collection.ipynb` (line 233)
- `notebooks/fit_delta_BA1_example.ipynb` (lines 427, 2372)
Changes: Replace `collapse_identical_variants` usage with explicit pre-aggregation:

```python
# Before (v1.x)
data = multidms.Data(
    df,
    collapse_identical_variants="mean",
    reference="reference",
)

# After (v2.0)
df_collapsed = df.groupby(['condition', 'aa_substitutions']).agg({
    'func_score': 'mean'
}).reset_index()
data = multidms.Data(df_collapsed, reference="reference")
```
Implementation Checklist
Constructor Changes
- Add required-columns validation at the very start of `__init__` (before line 204)
- Remove the `collapse_identical_variants` parameter from the constructor signature
- Remove the `self._collapse_identical_variants` assignment
- Remove the aggregation logic block
- Update the constructor docstring to remove `collapse_identical_variants` documentation
Wildtype Handling
- Add a `validate_and_sort_wt_first()` function before line 433
- Apply sorting/validation before the `var_wrt_ref` conversion loop
Property Removals
- Remove
training_dataproperty (keep onlyarrays) - Remove
scaled_arraysproperty - Remove
scaled_training_dataproperty
Scaled Arrays Removal
- Remove the `self._scaled_arrays` initialization (line 461)
- Remove the scaled array computation in the condition loop (lines 497-499)
- Update the `times_seen` computation to use unscaled arrays
- Remove the `rereference` function from `multidms/utils.py`
- Update the import in `data.py` to remove `rereference`
Sparse Array Simplification
- Change `X[condition]` storage from BCOO to scipy sparse (`.tocsr()`)
- Remove or update the assertion on line 474
- Update `jaxmodels.Data.from_multidms()` to handle scipy sparse directly
Tests
- Update `test_non_identical_conversion()` to pre-aggregate data
- Add `test_missing_wildtype_error()`
- Add `test_wildtype_first_in_conditions()`
- Add `test_missing_columns_error_message()`
- Add `test_backward_compatibility()`
Notebooks
- Update `notebooks/model_collection.ipynb`
- Update `notebooks/fit_delta_BA1_example.ipynb`
Breaking Changes Summary
| Change | Impact | Migration |
|---|---|---|
| `collapse_identical_variants` removed | Users must aggregate data before creating a `Data` object | `df.groupby(['condition', 'aa_substitutions']).agg({'func_score': 'mean'})` |
| `training_data` property removed | Use `arrays` instead | Replace `.training_data` with `.arrays` |
| `scaled_arrays` property removed | No longer available | Not needed for v2 modeling |
| `scaled_training_data` property removed | No longer available | Not needed for v2 modeling |
| Wildtype must exist for each condition | Error raised if missing | Include a row with empty `aa_substitutions` for each condition |
| Wildtype sorted to first position | Data reordered internally | No action needed; happens automatically |
| `times_seen_*` computation changed | Values computed from raw arrays instead of scaled | Values now represent actual mutation occurrences, not bundle-adjusted counts |
| Sparse arrays stored as scipy | Advanced users accessing `data.arrays["X"]` see a different type | Convert to BCOO if needed: `BCOO.from_scipy_sparse(X)` |
Related Files
- `multidms/data.py` (primary)
- `multidms/jaxmodels.py` (consumer of Data)
- `multidms/model.py` (consumer of Data)
- `multidms/utils.py` (remove the `rereference` function)
- `tests/test_data.py`
- `notebooks/model_collection.ipynb`
- `notebooks/fit_delta_BA1_example.ipynb`
- Contract: `specs/001-jaxmodels-refactor/contracts/data_class.md`