-
Notifications
You must be signed in to change notification settings - Fork 2
Validate multihit branch lengths #124
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
The issues that ended up causing the discrepancy were:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR implements changes to validate multihit branch lengths and improve branch length optimization along with related refactoring. Key changes include:
- Addition of new utility functions (flatten_codon_idxs, unflatten_codon_idxs, nt_idx_tensor_of_str) and corresponding tests.
- Refactored optimization routines in molevol.py using SciPy’s bracket and minimize_scalar functions.
- Consistent updating of branch length handling and multihit correction across modules (multihit, dxsm, dnsm, ddsm, dasm, framework).
Reviewed Changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/test_sequences.py | Added tests for new codon index conversion functions. |
| setup.py | Added the scipy dependency. |
| netam/sequences.py | Introduced new codon indexing utilities supporting multihit workflows. |
| netam/multihit.py | Updated HitClassDataset to subclass BranchLengthDataset and refactored branch length handling. |
| netam/molevol.py | Overhauled optimization functions and added new codon reshaping utilities. |
| netam/models.py | Updated forward docstring to reflect new input expectations. |
| netam/hit_class.py | Modified multihit correction using the new normalization routine. |
| netam/framework.py | Added parallel branch length optimization with multiprocessing. |
| netam/dxsm.py, dnsm.py, ddsm.py, dasm.py | Replaced deprecated branch length references with the updated ones. |
matsen
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Merge when happy
netam/dasm.py
Outdated
|
|
||
| # We have to clamp the predictions to avoid log(0) issues. | ||
| preds = torch.clamp(preds, min=torch.finfo(preds.dtype).eps) | ||
| preds = torch.clamp(preds, min=torch.finfo(preds.dtype).tiny) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
check?
|
Look how wrong the branch lengths were! (I think they were more wrong with multihit adjustment, but we can't make that comparison because we also changed how the multihit model gets applied) These are on 50k pcps from the Jaffe heavy chain dataset. |
add branch length tests remove old comment update comment add lots of tests maybe fixed? fix multihit normalization and add tests tweaks to branch lengths finalized branch length optimization fixed tests again cleanup fix copilot suggestion dasm training performs well cleanup and fix tests format lint remove unnecessary warning
0cec219 to
bb412a1
Compare
See #123 for discussion, referencing a test in this PR