Refactor NPE-C loss logic to Strategy Pattern#1755
Refactor NPE-C loss logic to Strategy Pattern#1755Sumit6307 wants to merge 1 commit intosbi-dev:mainfrom
Conversation
|
Hi @janfb, please have a look at this PR. Thanks! |
|
Hi @Sumit6307 , thank you for this PR as well. I assume you have read my comment under the PR (#1756 (comment)). That said, this contribution looks very good on a high level. I suggest the following:
|
@janfb Yes, I have read your comment under #1756. I’m glad to hear that the contribution looks good at a high level. I’ll go through #1241 and @michaeldeistler’s proposal in detail and review how my PR aligns with the planned new organization of SNPE methods. Based on that, I’ll update the implementation and clarify the alignment where needed. I’ll also run ruff and pyright locally and fix the remaining linting and type-checking issues so that CI passes. Please let me know if you’d prefer me to continue iterating on this within the current PR, or if it would be better to first open or move the discussion to a dedicated issue. Thanks for the guidance. |
|
Sounds good @Sumit6307 , let's discuss the implementation here in the PR. |
Summary
This PR refactors the NPE_C trainer to use the Strategy Pattern for its loss calculation logic. It isolates the "atomic" (sample-based) and "non-atomic" (analytical Gaussian) loss implementations into separate strategy classes, significantly improving the readability and maintainability of the core trainer class.
Motivation
The NPE_C class previously contained complex, non-trivial logic for switching between two very different loss calculation methods. This tightly coupled the mathematical details of the loss functions with the training loop. Validating the Single Responsibility Principle, this refactor:
Changes
_log_prob_proposal_posterior_atomic,_log_prob_proposal_posterior_mog, and _automatic_posterior_transformation.Verification