This commit refactors the SARIX model classes to eliminate code duplication by making SARIXFourierModel inherit from SARIXModel.

Changes to src/idmodels/sarix.py:
- Add helper methods to SARIXModel base class:
  - _get_sarix_module(): Returns the sarix module to use
  - _get_extra_sarix_params(df): Returns extra parameters for SARIX constructor
- Modify SARIXModel.run() to use helper methods for extensibility
- Refactor SARIXFourierModel to inherit from SARIXModel
- Remove duplicate __init__ and run() methods from SARIXFourierModel
- Override helper methods in SARIXFourierModel to provide Fourier-specific behavior
- Add _np_percentile() helper function for easier test mocking

Changes to tests/integration/test_sarix.py:
- Add test_sarix_shared_sigma_pooling_multiple_batches() to verify shared sigma pooling works correctly with multiple locations

Changes to pyproject.toml:
- Update sarix dependency to use reichlab/sarix instead of elray1/sarix

Benefits:
- Reduces code by ~65 lines by eliminating duplication
- Makes the inheritance relationship explicit
- Keeps base class generic and extensible
- Easier to add new SARIX model variants in the future

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
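The hook-method structure described in this commit can be sketched roughly as follows. This is only an illustration under stated assumptions: StubSARIX and stub_sarix_module are hypothetical stand-ins for the real sarix library, and the `p` setting and the `fourier_terms` parameter name are invented for the example; only the class names and the two helper method names come from the commit message.

```python
from types import SimpleNamespace
from typing import Any


class StubSARIX:
    """Hypothetical stand-in for the real SARIX class; it only records its arguments."""

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs


# Hypothetical stand-in for the imported sarix module.
stub_sarix_module = SimpleNamespace(SARIX=StubSARIX)


class SARIXModel:
    def __init__(self, p: int = 1) -> None:
        self.p = p  # illustrative setting, not the real model config

    def _get_sarix_module(self) -> Any:
        # Hook: which module provides the SARIX class.
        return stub_sarix_module

    def _get_extra_sarix_params(self, df: Any) -> dict:
        # Hook: the base model passes no extra constructor parameters.
        return {}

    def run(self, df: Any) -> StubSARIX:
        # Shared logic lives here once; subclasses only override the hooks.
        sarix_module = self._get_sarix_module()
        params = {"p": self.p, **self._get_extra_sarix_params(df)}
        return sarix_module.SARIX(**params)


class SARIXFourierModel(SARIXModel):
    def _get_extra_sarix_params(self, df: Any) -> dict:
        # Hypothetical Fourier-specific parameter, named for illustration only.
        return {"fourier_terms": 2}


base_fit = SARIXModel(p=2).run(df=None)
fourier_fit = SARIXFourierModel(p=2).run(df=None)
```

In this sketch the subclass inherits __init__ and run() unchanged, which is the duplication the commit reports removing.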
Adds the uv lock file to ensure all dependencies are pinned to specific versions for reproducible builds across different environments. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit fixes a failing test for SARIX models with shared sigma pooling and updates the sarix dependency to resolve an upstream bug.

Problem:
--------
The test `test_sarix_shared_sigma_pooling_multiple_batches` was failing with a reshape error when using `sigma_pooling='shared'` with multiple locations. The error occurred in the sarix library at line 392 where it incorrectly used `theta` instead of `sigma` when reshaping arrays:

TypeError: cannot reshape array of shape (100, 5, 6) into shape (100, 1, 1)

Root Cause:
-----------
The installed sarix package (v0.0.1 from elray1/sarix) contained a bug where the variable name was wrong in the sigma pooling code block. The reichlab/sarix repository had already fixed this bug in v0.2.0.

Fixes:
------
1. Updated sarix dependency from elray1/sarix to reichlab/sarix (v0.2.0)
2. Updated requires-python from >=3.9 to >=3.11 (required by newer sarix)
3. Regenerated uv.lock with updated dependencies
4. Regenerated requirements.txt and requirements-dev.txt
5. Added explicit string conversion for output_type_id in CSV output
6. Fixed test assertion to handle pandas type inference for output_type_id

Test Results:
-------------
✅ test_sarix - PASSED
✅ test_sarix_shared_sigma_pooling_multiple_batches - PASSED (was failing)
✅ test_drop_level_feats - PASSED

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
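The kind of reshape failure described above can be reproduced in a few lines. This is a sketch under assumptions, not the sarix code: it uses NumPy for illustration (which raises ValueError rather than the TypeError quoted above), takes the `theta` shape from the error message, and assumes a `sigma` array holding one shared value per draw.

```python
import numpy as np

n_draws = 100
theta = np.zeros((n_draws, 5, 6))  # shape taken from the error message
sigma = np.ones((n_draws, 1))      # assumed shape: one shared sigma per draw

# Reshaping the wrong variable fails: 3000 elements cannot fit into 100 slots.
try:
    theta.reshape((n_draws, 1, 1))
    reshape_failed = False
except ValueError:  # NumPy's exception type for an impossible reshape
    reshape_failed = True

# Reshaping the intended variable succeeds because the element counts match.
sigma_reshaped = sigma.reshape((n_draws, 1, 1))
```

The failure only appears when `theta` has more than one element per draw, which is consistent with the test needing multiple locations to trigger it.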
Should we increment the version of idmodels?
Noting that this PR is now dependent on incoming changes to sarix. reichlab/sarix#6
I'm a little confused. If this PR depends on another repo's PR, and needs the commit hash from that PR for requirements-dev.txt and requirements.txt, then this PR is premature, right? The commit hash used in this PR is:
I also note that the ruff check has failed.
Update sarix dependency to specific commit 35eea2379a9790e0457b1aed41d13509e5d5056f and fix ruff import sorting error in test_sarix.py. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
Regenerate requirements.txt and requirements-dev.txt to include the pinned sarix commit 35eea2379a9790e0457b1aed41d13509e5d5056f. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
I updated the PR and the checks are now passing. Sorry for the false start, @matthewcornell.
Should we make this a new version? I'm not sure where I'd change that, but it seems like this warrants a minor version bump.
Update version from 0.0.1 to 0.1.0 to reflect new features including Fourier pooling support, bug fixes, and dependency updates. Add CHANGELOG.md following Keep a Changelog format to document current and historical changes, providing a structured place for tracking future releases. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
matthewcornell
left a comment
All LGTM except one question: SARIXModel._get_sarix_module() and SARIXFourierModel._get_sarix_module() both return the same module (from sarix import sarix). While previous iterations might have returned different values, this now seems obsolete because _get_extra_sarix_params() is the only thing that differs, right? If so, we should remove the two _get_sarix_module() functions and the sarix_module variable, and change line 50 to: sarix_fit_all_locs_theta_pooled = sarix.SARIX(
Both SARIXModel and SARIXFourierModel returned the same sarix module, making _get_sarix_module() unnecessary. The only difference between the classes is _get_extra_sarix_params(), so we can directly call sarix.SARIX() instead of going through an extra abstraction layer. This simplifies the code without changing functionality. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This PR points to the new sarix package on reichlab and also adds a subclass that is the Fourier version of the model.