
Refactor Data Class to Match v2.0 Contract Specification #182

@jgallowa07

Description

Summary

Refactor multidms/data.py to align with the v2.0 Data class contract specification. This involves removing deprecated parameters, simplifying the API, removing the scaled arrays functionality, and ensuring wildtype variants come first in each condition's data.

Motivation

The v2.0 refactor shifts responsibility for data aggregation from the Data class to users, removes the scaled arrays functionality (no longer needed for modeling), and ensures data is properly structured for the jaxmodels backend.


Required Changes

1. Remove collapse_identical_variants Parameter

Files affected:

  • multidms/data.py: Lines 81-86 (docstring), 194, 233, 264, 271-281

Changes:

  1. Remove parameter from __init__ signature (line 194)
  2. Remove self._collapse_identical_variants assignment (line 233)
  3. Remove weight overwrite comment referencing it (line 264)
  4. Remove the entire aggregation block (lines 271-281):
    if self._collapse_identical_variants:
        agg_dict = {
            "weight": "sum",
            "func_score": self._collapse_identical_variants,
        }
        df = (
            variants_df[cols]
            .assign(weight=1)
            .groupby(["condition", "aa_substitutions"], as_index=False)
            .aggregate(agg_dict)
        )
  5. Update docstring to remove parameter documentation

Migration for users:

# v1.x
data = Data(df, collapse_identical_variants='mean')  # REMOVED

# v2.0
df_collapsed = df.groupby(['condition', 'aa_substitutions']).agg({'func_score': 'mean'}).reset_index()
data = Data(df_collapsed, reference='A')
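
If the input carries other columns that Data consumes (for example the optional weight column covered in section 7 below), fold them into the same aggregation. A hypothetical sketch that also reproduces the removed v1 weight-counting behavior (assign weight=1, then sum):

# Hypothetical: collapse variants while reproducing v1's weight counting
df_collapsed = (
    df.assign(weight=1)
    .groupby(['condition', 'aa_substitutions'], as_index=False)
    .agg({'func_score': 'mean', 'weight': 'sum'})
)
data = Data(df_collapsed, reference='A')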

2. Ensure Wildtype is First Variant in Each Condition

Context: The jaxmodels.Data.from_multidms() method (line 64 in jaxmodels.py) assumes the wildtype is the first variant in each condition's data. It uses index [0] to get x_wt and [1:] to get all other variants. This is critical for the modeling to work correctly.

Critical Note: For non-reference conditions with non-identical sites, the wildtype (empty aa_substitutions) gets converted to bundle mutations in var_wrt_ref (e.g., "G3P"). Therefore, we must check aa_substitutions, NOT var_wrt_ref, when validating/sorting for wildtype.
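
To make the invariant concrete, here is a minimal sketch of the downstream access pattern, assuming multidms_data is a constructed Data object (variable names are illustrative; the real code lives in jaxmodels.Data.from_multidms()):

# Illustrative sketch of the ordering assumption that jaxmodels relies on
for condition in multidms_data.conditions:
    X_cond = multidms_data.arrays["X"][condition]
    x_wt = X_cond[0]    # row 0 must be the wildtype variant
    X_mut = X_cond[1:]  # remaining rows are the non-wildtype variants

# For a non-reference condition with a non-identical site, the wildtype row has
#   aa_substitutions == ""  (original input; use this to identify the wildtype)
#   var_wrt_ref == "G3P"    (computed bundle mutation; do NOT key on this)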

Changes needed:

  1. Validate that wildtype exists for each condition BEFORE the var_wrt_ref conversion (around line 433)
  2. Sort wildtype to be first BEFORE the conversion loop
  3. Raise an error if no explicit wildtype exists for a condition

Implementation:
Add validation and sorting BEFORE line 433 (before the var_wrt_ref conversion loop):

# Validate wildtype exists and sort to first position for each condition
def validate_and_sort_wt_first(group):
    """Ensure wildtype exists and sort it first."""
    condition = group['condition'].iloc[0]
    wt_mask = group['aa_substitutions'].str.strip() == ''

    if not wt_mask.any():
        raise ValueError(
            f"No wildtype variant found for condition '{condition}'. "
            f"Please include a row with empty 'aa_substitutions' for this condition."
        )

    return pd.concat([group[wt_mask], group[~wt_mask]])

df = df.groupby('condition', group_keys=False).apply(validate_and_sort_wt_first).reset_index(drop=True)

# Now proceed with var_wrt_ref conversion...
df = df.assign(var_wrt_ref=df["aa_substitutions"])

This approach:

  • Validates wildtype existence and sorts in a single pass (DRY)
  • Uses aa_substitutions (original input) rather than var_wrt_ref (computed)
  • Runs BEFORE the conversion loop so wildtype is properly first in all subsequent processing

3. Remove training_data Property (Keep Only arrays)

Current state: Two properties return the same data:

  • Data.arrays (lines 636-638)
  • Data.training_data (lines 641-643), an alias of arrays

Changes:

  1. Remove the training_data property (lines 641-643)
  2. Keep only arrays as the canonical property name
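
Migration is a direct rename:

# v1.x
arrays = data.training_data  # REMOVED

# v2.0
arrays = data.arrays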

4. Remove Scaled Arrays Functionality

Context: The scaled_arrays functionality was used for the old v1 modeling approach. The v2 jaxmodels backend handles bundle mutations differently and does not need this.

Files affected:

  • multidms/data.py
  • multidms/utils.py (remove rereference function entirely)

Changes in data.py:

  1. Remove self._scaled_arrays initialization (line 461):

    self._scaled_arrays = {"X": {}, "y": y, "w": w}  # REMOVE
  2. Remove the scaled array computation (lines 497-499):

    self._scaled_arrays["X"][condition] = rereference(
        X[condition], self._bundle_idxs[condition]
    )  # REMOVE
  3. Remove scaled_arrays property (lines 646-648)

  4. Remove scaled_training_data property (lines 651-653)

  5. Update times_seen computation (line 511): compute it from raw mutation occurrences (unscaled arrays). The times_seen column represents how many variants in the dataset contain that mutation.

    Change line 511 from:

    times_seen = pd.Series(self._scaled_arrays["X"][condition].sum(0).todense())

    to:

    times_seen = pd.Series(X[condition].sum(0).todense())

    Note: once section 5 below stores X[condition] as scipy sparse, .sum(0) returns a dense result with no .todense() method, so this line will likely need to become pd.Series(np.asarray(X[condition].sum(0)).ravel()).
  6. Remove the rereference function from multidms/utils.py; it is no longer needed after removing scaled arrays.

  7. Update imports in data.py (line 21):

    # Current
    from multidms.utils import rereference, split_subs
    
    # After refactor
    from multidms.utils import split_subs

5. Simplify Sparse Array Storage

Context: Currently, Data converts binarymap sparse arrays to jax.experimental.sparse.BCOO format (lines 471-472). However, jaxmodels.Data.from_multidms() immediately converts these back to scipy.sparse.csr_array (line 68) because BCOO slicing has issues.

Research findings:

  • data.py lines 471-472: converts to BCOO
  • jaxmodels.py lines 66-72: converts BCOO back to scipy, slices, then converts back to BCOO

Decision: Store as scipy sparse in the Data class and let jaxmodels handle BCOO conversion as needed.

Changes:

  1. In data.py, change lines 471-473 from:

    X[condition] = sparse.BCOO.from_scipy_sparse(
        cond_bmap.binary_variants.tocoo()
    )

    to:

    X[condition] = cond_bmap.binary_variants.tocsr()  # Keep as scipy sparse
  2. Update the assertion on line 474 accordingly (or remove if not applicable to scipy sparse)

  3. Update jaxmodels.Data.from_multidms() to handle scipy sparse directly:

    # Simplified - X is already scipy sparse
    X = multidms_data.arrays["X"][condition]
    X = X[1:]  # exclude WT (scipy sparse supports this)
    X = BCOO.from_scipy_sparse(X)
  4. The single_mut_encodings cached property (lines 701-718) uses sparse.BCOO.fromdense; keep it as-is, since BCOO is appropriate for this cached property.


6. Update Validation and Error Messages

The contract specifies this error message format:

ValueError: DataFrame missing required columns: {missing}.
Expected: condition, aa_substitutions, func_score.
Found: {present}.

Changes:
Add explicit check for required columns at the VERY START of __init__ (immediately after the docstring, line 202), before any DataFrame operations:

def __init__(self, variants_df, reference, ...):
    """See main class docstring."""

    # Validate required columns FIRST (before any DataFrame access)
    required_cols = ["condition", "aa_substitutions", "func_score"]
    missing_cols = [col for col in required_cols if col not in variants_df.columns]
    if missing_cols:
        raise ValueError(
            f"DataFrame missing required columns: {', '.join(missing_cols)}. "
            f"Expected: {', '.join(required_cols)}. "
            f"Found: {', '.join(variants_df.columns)}."
        )

    # Then proceed with existing validation...
    if pd.isnull(variants_df["condition"]).any():
        ...

This must be placed BEFORE line 204 (if pd.isnull(variants_df["condition"]).any()) which already tries to access DataFrame columns.


7. Keep weights Behavior As-Is

The current behavior only assigns weights if a "weight" column exists in variants_df (lines 261-264 and 483-484). This is the desired behavior; no changes are needed here.
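
For reference, a minimal sketch of the logic being kept; the helper name condition_weights is hypothetical, and the real logic is inline in data.py:

import jax.numpy as jnp
import pandas as pd

def condition_weights(cond_df: pd.DataFrame):
    """Hypothetical helper showing the unchanged rule: weights are
    assigned only when the input frame carries a 'weight' column."""
    if "weight" in cond_df.columns:
        return jnp.array(cond_df["weight"].values)
    return None  # no 'weight' column: no weights for this condition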


Test Updates

File: tests/test_data.py

Required changes:

  1. Update test_non_identical_conversion() (lines 170-186): it currently uses collapse_identical_variants="mean"; update it to pre-aggregate the data:

    def test_non_identical_conversion():
        # Pre-aggregate the data
        df_collapsed = TEST_FUNC_SCORES.groupby(
            ['condition', 'aa_substitutions']
        ).agg({'func_score': 'mean'}).reset_index()
    
        data = multidms.Data(
            df_collapsed,
            alphabet=multidms.AAS_WITHSTOP,
            reference="a",
            assert_site_integrity=True,
        )
        # ... rest of test assertions
  2. Add test for missing wildtype error:

    def test_missing_wildtype_error():
        """Test that error is raised when wildtype missing for a condition."""
        # Remove all wildtype rows (empty aa_substitutions)
        df = TEST_FUNC_SCORES.query("aa_substitutions != ''")
        with pytest.raises(ValueError, match="No wildtype variant found"):
            multidms.Data(df, reference='a', alphabet=multidms.AAS_WITHSTOP)
  3. Add test for wildtype-first invariant:

    def test_wildtype_first_in_conditions():
        """Test that wildtype is the first variant for each condition."""
        # Create data where wildtype is NOT first in input
        df = TEST_FUNC_SCORES.copy()
        df = df.iloc[::-1].reset_index(drop=True)  # Reverse order
    
        data = multidms.Data(df, reference='a', alphabet=multidms.AAS_WITHSTOP)
    
        # Verify wildtype (empty aa_substitutions) is first for each condition
        for condition in data.conditions:
            cond_df = data.variants_df[data.variants_df['condition'] == condition]
            first_aa_subs = cond_df.iloc[0]['aa_substitutions']
            assert first_aa_subs.strip() == '', (
                f"Wildtype should be first for condition {condition}, "
                f"but got '{first_aa_subs}'"
            )
  4. Add test for error message format:

    def test_missing_columns_error_message():
        """Test that missing columns produce properly formatted error."""
        df = pd.DataFrame({'condition': ['a'], 'func_score': [1.0]})  # missing aa_substitutions
        with pytest.raises(ValueError, match="DataFrame missing required columns"):
            multidms.Data(df, reference='a')
  5. Add backward compatibility test:

    def test_backward_compatibility():
        """Ensure non-breaking changes remain compatible with v1.x usage."""
        data = multidms.Data(
            TEST_FUNC_SCORES,
            alphabet=multidms.AAS_WITHSTOP,
            reference="a",
        )
    
        # Verify all v1.x properties still work
        assert data.reference == "a"
        assert len(data.conditions) == 2
        assert len(data.mutations) > 0
        assert 'condition' in data.variants_df.columns
        assert 'mutation' in data.mutations_df.columns
        assert data.arrays is not None
  6. Remove any tests that use scaled_arrays or scaled_training_data (currently in commented v1 tests)


Notebook Updates

Files affected:

  • notebooks/model_collection.ipynb (line 233)
  • notebooks/fit_delta_BA1_example.ipynb (lines 427, 2372)

Changes: Replace collapse_identical_variants usage with explicit pre-aggregation:

# Before (v1.x)
data = multidms.Data(
    df,
    collapse_identical_variants="mean",
    reference="reference",
)

# After (v2.0)
df_collapsed = df.groupby(['condition', 'aa_substitutions']).agg({
    'func_score': 'mean'
}).reset_index()
data = multidms.Data(df_collapsed, reference="reference")

Implementation Checklist

Constructor Changes

  • Add required-columns validation at the very start of __init__ (before line 204)
  • Remove collapse_identical_variants parameter from constructor signature
  • Remove self._collapse_identical_variants assignment
  • Remove aggregation logic block
  • Update constructor docstring to remove collapse_identical_variants documentation

Wildtype Handling

  • Add validate_and_sort_wt_first() function before line 433
  • Apply sorting/validation before var_wrt_ref conversion loop

Property Removals

  • Remove training_data property (keep only arrays)
  • Remove scaled_arrays property
  • Remove scaled_training_data property

Scaled Arrays Removal

  • Remove self._scaled_arrays initialization (line 461)
  • Remove scaled array computation in condition loop (lines 497-499)
  • Update times_seen computation to use unscaled arrays
  • Remove rereference function from multidms/utils.py
  • Update import in data.py to remove rereference

Sparse Array Simplification

  • Change X[condition] storage from BCOO to scipy sparse (.tocsr())
  • Remove or update assertion on line 474
  • Update jaxmodels.Data.from_multidms() to handle scipy sparse directly

Tests

  • Update test_non_identical_conversion() to pre-aggregate data
  • Add test_missing_wildtype_error()
  • Add test_wildtype_first_in_conditions()
  • Add test_missing_columns_error_message()
  • Add test_backward_compatibility()

Notebooks

  • Update notebooks/model_collection.ipynb
  • Update notebooks/fit_delta_BA1_example.ipynb

Breaking Changes Summary

| Change | Impact | Migration |
| --- | --- | --- |
| collapse_identical_variants removed | Users must aggregate data before creating the Data object | df.groupby(['condition', 'aa_substitutions']).agg({'func_score': 'mean'}) |
| training_data property removed | Use arrays instead | Replace .training_data with .arrays |
| scaled_arrays property removed | No longer available | Not needed for v2 modeling |
| scaled_training_data property removed | No longer available | Not needed for v2 modeling |
| Wildtype must exist for each condition | Error raised if missing | Include a row with empty aa_substitutions for each condition |
| Wildtype sorted to first position | Data reordered internally | No action needed; happens automatically |
| times_seen_* computation changed | Values computed from raw arrays instead of scaled | Values now represent actual mutation occurrences, not bundle-adjusted counts |
| Sparse arrays stored as scipy | data.arrays["X"] entries are scipy sparse rather than BCOO | Convert if needed: BCOO.from_scipy_sparse(X) |

Related Files

  • multidms/data.py (primary)
  • multidms/jaxmodels.py (consumer of Data)
  • multidms/model.py (consumer of Data)
  • multidms/utils.py (remove rereference function)
  • tests/test_data.py
  • notebooks/model_collection.ipynb
  • notebooks/fit_delta_BA1_example.ipynb
  • Contract: specs/001-jaxmodels-refactor/contracts/data_class.md
