willdumm (Contributor) commented Mar 5, 2025

See #123 for discussion, referencing a test in this PR

willdumm (Contributor, Author) commented Apr 23, 2025

The issues that ended up causing the discrepancy were:

  • Gradient-based optimization gave consistently incorrect results (though close to correct on real data). We now use the gradient-free "Brent" method from scipy.optimize.minimize_scalar, with some contortions to avoid numerical issues: our log-probabilities are quite flat with respect to branch length, especially when the optimal branch length is very close to 0.
  • We were unnecessarily clamping linear-space probabilities to a floor that was too large. We now use torch.finfo(dtype).tiny instead of torch.finfo(dtype).eps. The point of the clamp is to avoid taking the log of something too close to zero, and the smaller value is still adequate for that purpose.
  • There are also two completely independent code paths computing codon probabilities, one for branch length optimization and one for loss computation (in DASM). This PR doesn't fix that, but it tests both paths. A future PR should fix this; see issue #134, "Independent code paths for codon probability computations".
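
As a rough illustration of the first point, here is a minimal sketch of gradient-free branch length fitting with scipy.optimize.minimize_scalar and method="brent". It is not the code from this PR: the toy likelihood (each site mutates independently with probability 1 - exp(-t)), the function names, the starting bracket, and the choice to search over log(t) so the optimizer never proposes a negative branch length are all assumptions made for the example.

```python
import math
import sys

from scipy.optimize import minimize_scalar

# Smallest positive normal float64, used as a floor before taking logs.
TINY = sys.float_info.min


def neg_log_likelihood(log_t, n_sites, n_mutated):
    """Negative log-likelihood of the toy model as a function of log(branch length)."""
    t = math.exp(log_t)
    # Per-site mutation probability 1 - exp(-t), computed stably for small t.
    p_mut = max(-math.expm1(-t), TINY)
    # log(1 - p_mut) equals -t exactly in this model, so no second log is needed.
    return -(n_mutated * math.log(p_mut) - (n_sites - n_mutated) * t)


def optimize_branch_length(n_sites, n_mutated):
    """Fit the branch length without gradients using Brent's method."""
    result = minimize_scalar(
        neg_log_likelihood,
        # Two starting points (in log space) for SciPy's bracket search.
        bracket=(math.log(1e-4), math.log(1e-1)),
        args=(n_sites, n_mutated),
        method="brent",
    )
    return math.exp(result.x)


estimate = optimize_branch_length(n_sites=300, n_mutated=3)
closed_form = -math.log1p(-3 / 300)  # analytic optimum of the toy model
print(estimate, closed_form)
```

The toy model has a closed-form optimum, which makes it easy to check the optimizer. With zero observed mutations the optimum sits at t = 0 and this log-space search has no interior minimum, so that case would need separate handling.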

willdumm marked this pull request as ready for review April 23, 2025 23:42
willdumm requested review from Copilot and matsen April 23, 2025 23:43
Copilot AI left a comment

Pull Request Overview

This PR implements changes to validate multihit branch lengths and improve branch length optimization along with related refactoring. Key changes include:

  • Addition of new utility functions (flatten_codon_idxs, unflatten_codon_idxs, nt_idx_tensor_of_str) and corresponding tests.
  • Refactored optimization routines in molevol.py using SciPy’s bracket and minimize_scalar functions.
  • Consistent updating of branch length handling and multihit correction across modules (multihit, dxsm, dnsm, ddsm, dasm, framework).

Reviewed Changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 3 comments.

File | Description
--- | ---
tests/test_sequences.py | Added tests for new codon index conversion functions.
setup.py | Added the scipy dependency.
netam/sequences.py | Introduced new codon indexing utilities supporting multihit workflows.
netam/multihit.py | Updated HitClassDataset to subclass BranchLengthDataset and refactored branch length handling.
netam/molevol.py | Overhauled optimization functions and added new codon reshaping utilities.
netam/models.py | Updated forward docstring to reflect new input expectations.
netam/hit_class.py | Modified multihit correction using the new normalization routine.
netam/framework.py | Added parallel branch length optimization with multiprocessing.
netam/dxsm.py, dnsm.py, ddsm.py, dasm.py | Replaced deprecated branch length references with the updated ones.

matsen (Contributor) left a comment

Merge when happy

netam/dasm.py Outdated

  # We have to clamp the predictions to avoid log(0) issues.
- preds = torch.clamp(preds, min=torch.finfo(preds.dtype).eps)
+ preds = torch.clamp(preds, min=torch.finfo(preds.dtype).tiny)

check?
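
For context on what the clamp floor does numerically, here is a small sketch. It uses Python's built-in float64 constants as stand-ins for torch.finfo(dtype).eps and torch.finfo(dtype).tiny so that it runs without torch; that substitution, the helper name clamped_log, and the example probability are assumptions for illustration only. For float32 the corresponding values are roughly 1.19e-7 (eps) and 1.18e-38 (tiny).

```python
import math
import sys

EPS = sys.float_info.epsilon  # float64 analogue of torch.finfo(dtype).eps
TINY = sys.float_info.min     # float64 analogue of torch.finfo(dtype).tiny


def clamped_log(p, floor):
    """Log of a probability after clamping it from below (hypothetical helper)."""
    return math.log(max(p, floor))


# A small but perfectly representable probability.
true_prob = 1e-30

exact = math.log(true_prob)                   # about -69.08
log_with_eps = clamped_log(true_prob, EPS)    # about -36.04: the clamp distorts the value
log_with_tiny = clamped_log(true_prob, TINY)  # identical to exact: the clamp is inactive

# Either floor still protects against log(0).
log_of_zero = clamped_log(0.0, TINY)

print(exact, log_with_eps, log_with_tiny, log_of_zero)
```

Clamping at eps silently replaces any probability below about 2.2e-16 (in float64) with eps, which flattens the log-probability exactly where small values matter. Clamping at tiny only intervenes for values that would otherwise underflow toward zero.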

willdumm (Contributor, Author) commented

Look how wrong the branch lengths were! (I think they were more wrong with multihit adjustment, but we can't make that comparison because we also changed how the multihit model gets applied)

These are on 50k PCPs (parent-child pairs) from the Jaffe heavy chain dataset.
branch_length_comparison.pdf

branch_length_comparison_origin.pdf

add branch length tests

remove old comment

update comment

add lots of tests

maybe fixed?

fix multihit normalization and add tests

tweaks to branch lengths

finalized branch length optimization

fixed tests again

cleanup

fix copilot suggestion

dasm training performs well

cleanup and fix tests

format

lint

remove unnecessary warning
willdumm force-pushed the 123-multihit-branch-lengths branch from 0cec219 to bb412a1 April 28, 2025 17:53
willdumm merged commit 0d36aee into main Apr 28, 2025
2 checks passed
willdumm deleted the 123-multihit-branch-lengths branch April 28, 2025 18:04
willdumm restored the 123-multihit-branch-lengths branch April 29, 2025 20:38
willdumm deleted the 123-multihit-branch-lengths branch April 29, 2025 21:23
3 participants