Commit 32f120d

jgallowa07 and claude authored
Complete Model v2 wrapper implementation with shift parameters (#180)
* Complete Model v2 wrapper implementation with shift parameters (#172)

Implements comprehensive Model wrapper for jaxmodels backend with wide format API, shift parameters, and extensive test coverage.

Key Features:
- Wide format get_mutations_df(): One row per mutation with beta_{condition} and shift_{condition} columns (matches v1 API)
- Shift parameters automatically calculated as β_condition - β_reference
- Extended Model.fit() to expose advanced jaxmodels parameters via kwargs (beta_clip_range, ge_kwargs, cal_kwargs, loss_kwargs)
- 24 comprehensive tests in test_model_v2.py covering all Model v2 functionality
- Updated jaxmodels_simulation_fits.ipynb to use Model wrapper API

Changes:
- multidms/model.py: Complete v2 Model wrapper implementation
- multidms/jaxmodels.py: Fixed type hints
- tests/test_model_v2.py: 24 comprehensive tests (NEW)
- tests/test_data.py: Commented out v1 Model tests
- notebooks/jaxmodels/jaxmodels_simulation_fits.ipynb: Updated to Model API

Tests: 61 passed, 1 skipped
All linting and formatting checks pass

Closes #172
Related: #173, #178, #179

* Fix CI: Exclude deprecated files from ruff linting

Add biophysical.py and model_v1_backup.py to ruff exclude list in pyproject.toml. These are deprecated/backup files that should not block CI.

* Remove deprecated files from repository

Remove biophysical.py and model_v1_backup.py from version control as they are no longer needed for v2 implementation. These were causing CI linting failures and are superseded by the new jaxmodels-based Model wrapper. Users can access v1 functionality by using multidms v1.x releases.

Changes:
- Removed multidms/biophysical.py (deprecated with ImportError)
- Removed multidms/model_v1_backup.py (backup of old implementation)
- Cleaned up pyproject.toml excludes (no longer needed)

* Optimize test_model.py for faster execution

- Replace test_model_v2.py with optimized test_model.py
- Add session-scoped data fixtures to avoid repeated creation
- Add module-scoped fitted model fixtures shared across tests
- Reduce iteration counts where sufficient (maxiter 5→3, 20→10)
- Expected ~3-4x speedup from fixture reuse and reduced iterations
- All 28 tests retained with 61 passed, 1 skipped

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>

* Fix docs build errors in CI

- Fix CHANGELOG.rst link syntax (proper RST format)
- Fix jaxtyping array shape annotation (space not comma)

Fixes errors:
- "Unknown target name" in CHANGELOG.rst line 33
- "Axes should be separated with spaces, not commas" in jaxmodels.py

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>

* remove unnecessary spec files, and simulated data

* remove old REFACTOR PLAN

* remove unnecessary claude commands

* update gitignore

---------

Co-authored-by: Claude <[email protected]>
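
The wide layout and the shift relation named in the commit message (shift = β_condition - β_reference) can be illustrated with a small sketch. This is plain Python, not the multidms implementation: the helper `to_wide_rows`, the condition names, the mutation labels, and the beta values are all hypothetical. Leaving out a shift column for the reference condition is also an assumption of this sketch, not something the commit message states.

```python
# Illustrative only: plain Python, not the multidms API.
# All names and numbers below are hypothetical.


def to_wide_rows(betas, reference):
    """Build one row per mutation with beta_{condition} and shift_{condition} keys.

    betas: dict mapping condition name -> dict mapping mutation -> beta value.
    reference: name of the reference condition.
    """
    rows = []
    for mutation in sorted(betas[reference]):
        row = {"mutation": mutation}
        for condition, condition_betas in betas.items():
            row[f"beta_{condition}"] = condition_betas[mutation]
            if condition != reference:
                # shift = beta_condition - beta_reference
                row[f"shift_{condition}"] = (
                    condition_betas[mutation] - betas[reference][mutation]
                )
        rows.append(row)
    return rows


betas = {
    "ref": {"M1A": 0.5, "G2C": -1.0},
    "cond2": {"M1A": 0.75, "G2C": -1.0},
}
rows = to_wide_rows(betas, reference="ref")
for row in rows:
    print(row)
```

Each printed dict corresponds to one row of a wide table: a mutation that behaves identically in both conditions gets a shift of zero, and a nonzero shift marks a condition-specific effect.
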
1 parent c927da0 commit 32f120d

20 files changed

+1275056 -2794 lines changed

.gitignore

Lines changed: 14 additions & 1 deletion
@@ -7,6 +7,19 @@ Attic/
 docs/multidms*rst
 .vscode
 
+
+
+# ecperimental files
+notebooks/Smaller_beta_output/
+notebooks/linear_epistasis_output/
+notebooks/separate_offset_output/
+
+# Specify kit related
+.claude/
+.specify/
+REFACTOR_PLAN.md
+specs/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
@@ -171,4 +184,4 @@ cython_debug/
 
 notebooks/test_dump.pkl
 notebooks/jaxmodels/papermill_results
-notebooks/jaxmodels/output
+notebooks/jaxmodels/output

CHANGELOG.rst

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ The format is based on `Keep a Changelog <https://keepachangelog.com>`_.
 0.4.0
 -----
 - new simulation validation analysis and plotting functions (at the time of re-submission)
-- fixes bug described in `#130 https://github.com/matsengrp/multidms/issues/130`_, having to do with pandas groupby.apply 2.2.0 behavior change.
+- fixes bug described in `#130 <https://github.com/matsengrp/multidms/issues/130>`_, having to do with pandas groupby.apply 2.2.0 behavior change.
 - updates python version requirements to 3.9 or newer, as 3.8 did not work with the new pandas version, 2.2.0 bug patch described above.
 - supresses the cpu warning from jax.
 - adds `ModelCollection.add_validation_loss <https://github.com/matsengrp/multidms/blob/b0e7cbe96216e1307d070adc531fe51a960ec32a/multidms/model_collection.py#L569>`_, `ModelCollection.get_conditional_loss_df <https://github.com/matsengrp/multidms/blob/b0e7cbe96216e1307d070adc531fe51a960ec32a/multidms/model_collection.py#L627>`_, `Model.conditional_loss <https://github.com/matsengrp/multidms/blob/b0e7cbe96216e1307d070adc531fe51a960ec32a/multidms/model.py#L379>`_, and `Model.get_df_loss <https://github.com/matsengrp/multidms/blob/b0e7cbe96216e1307d070adc531fe51a960ec32a/multidms/model.py#L568>`_ methods, which can all be used quite easily to perform cross validation analysis.
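
The one-line change in this hunk wraps the URL in angle brackets, which is how reStructuredText marks an embedded URI inside a hyperlink reference. A minimal sketch of a check that flags the broken form is given below. The helper name `find_unbracketed_links` is hypothetical, and the regular expression is a rough approximation for illustration, not a replacement for the docutils parser.

```python
import re

# Approximate pattern: a backquoted hyperlink reference (`text URL`_) whose
# URL is NOT wrapped in <...>. Spans containing < or > are skipped entirely.
UNBRACKETED_LINK = re.compile(r"`[^`<>]*\bhttps?://[^`<>\s]+`_")


def find_unbracketed_links(rst_text: str) -> list[str]:
    """Return hyperlink references that embed a bare URL without angle brackets."""
    return UNBRACKETED_LINK.findall(rst_text)


old_line = (
    "- fixes bug described in "
    "`#130 https://github.com/matsengrp/multidms/issues/130`_, having to do with"
)
new_line = (
    "- fixes bug described in "
    "`#130 <https://github.com/matsengrp/multidms/issues/130>`_, having to do with"
)

print(find_unbracketed_links(old_line))  # flags the pre-fix form
print(find_unbracketed_links(new_line))  # nothing to flag after the fix
```

A check along these lines could run before the Sphinx build, which is where the "Unknown target name" error mentioned in the commit message was raised.
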

CLAUDE.md

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
+# CLAUDE.md
+
+This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+## Project Overview
+
+`multidms` is a Python package for joint modeling of multiple deep mutational scanning (DMS) experiments. It uses JAX for high-performance computing and automatic differentiation to fit global-epistasis models that estimate individual mutation effects and how they differ between experimental conditions.
+
+## Development Commands
+
+### Core Development Workflow
+```bash
+# Install development dependencies
+pip install -e ".[dev]"
+
+# Code quality checks
+ruff check . # Lint code
+black . # Format code
+
+# Testing
+pytest --doctest-modules -vv # Run all tests including doctests
+pytest tests/ # Run unit tests only
+pytest --doctest-modules multidms tests # Full test suite (as used in CI)
+
+# Documentation
+make -C docs clean # Clean docs build
+make -C docs html # Build documentation
+
+# Version management
+bumpver update --patch # Bump patch version
+bumpver update --minor # Bump minor version
+bumpver update --major # Bump major version
+```
+
+### Testing Notes
+- The test suite is minimal with only basic Data class tests in `tests/test_data.py`
+- Doctests are integrated throughout the codebase and run with pytest
+- CI runs tests on Python 3.9, 3.10, 3.11 on Ubuntu and macOS
+
+## Package Architecture
+
+### Core Classes and Entry Points
+- **`multidms.Data`** - Handles data preprocessing and one-hot encoding of variant substitutions
+- **`multidms.Model`** - Main model class for fitting DMS experiments using JAX-based optimization
+- **`multidms.ModelCollection`** - Interface for fitting multiple models in parallel
+- **`multidms.fit_models`** - Function for parallel model fitting across collections
+
+### Key Modules
+- **`multidms.biophysical`** - Core biophysical model equations, transformations, and mathematical foundations
+- **`multidms.model_collection`** - Parallel model fitting and analysis workflows
+- **`multidms.plot`** - Interactive plotting functionality using matplotlib/seaborn/altair
+- **`multidms.utils`** - Data transformation utilities and helper functions
+
+### Dependencies and Architecture Patterns
+- **JAX ecosystem**: Core computational framework with jaxopt for optimization
+- **Data handling**: pandas for DataFrames, numpy for arrays (version pinned ≤1.26.0)
+- **Optimization**: Uses generalized lasso with bit-flipping algorithms via pylops/pyproximal
+- **Visualization**: Multi-library approach (matplotlib, seaborn, altair) for different plot types
+- **Scientific computing**: scipy for statistical functions, polyclonal for related modeling
+
+### Code Style and Conventions
+- **Formatting**: Black with line length 89 (matches ruff configuration)
+- **Linting**: Ruff with specific rule selections (E, F, UP, D) and custom ignores for docstring styles
+- **Documentation**: NumPy-style docstrings throughout
+- **Type hints**: Used where appropriate, with typing_extensions for compatibility
+
+### Development Patterns
+- Models compose biophysical equations from `multidms.biophysical` module
+- Heavy use of JAX transformations (jit, grad, vmap) for performance
+- Parameter initialization and transformation handled through dedicated methods
+- Cross-validation and simulation validation workflows built into model classes
+
+### File Organization
+- Main package code in `multidms/` with flat module structure
+- Jupyter notebooks in `notebooks/` demonstrate usage and validation
+- Sphinx documentation in `docs/` with linked notebook examples
+- Minimal test suite in `tests/` (expansion needed)
+
+### CI/CD and Release Process
+- GitHub Actions handle testing, linting, documentation builds
+- Automated PyPI publishing on tagged releases
+- Version management via bumpver tool with coordinated updates across files
+- Multi-platform testing ensures compatibility across development environments
+
+## Active Technologies
+- Python 3.9, 3.10, 3.11 (multi-version CI support) + JAX ≥0.4.29, jaxopt, equinox, pandas ≥2.2.0, numpy ≤1.26.0 (001-jaxmodels-refactor)
+- N/A (library operates on in-memory pandas DataFrames) (001-jaxmodels-refactor)
+- Python 3.9, 3.10, 3.11 (CI tested on all three versions) + JAX ≥0.4.29, equinox, jaxopt, pandas ≥2.2.0, numpy ≤1.26.0 (001-jaxmodels-refactor)
+- N/A (in-memory DataFrame processing) (001-jaxmodels-refactor)
+
+## Recent Changes
+- 001-jaxmodels-refactor: Added Python 3.9, 3.10, 3.11 (multi-version CI support) + JAX ≥0.4.29, jaxopt, equinox, pandas ≥2.2.0, numpy ≤1.26.0
