Skip to content

Refactor NPE-C loss logic to Strategy Pattern#1755

Open
Sumit6307 wants to merge 1 commit intosbi-dev:mainfrom
Sumit6307:refactor/npe-c-strategy
Open

Refactor NPE-C loss logic to Strategy Pattern#1755
Sumit6307 wants to merge 1 commit intosbi-dev:mainfrom
Sumit6307:refactor/npe-c-strategy

Conversation

@Sumit6307
Copy link

Summary

This PR refactors the NPE_C trainer to use the Strategy Pattern for its loss calculation logic. It isolates the "atomic" (sample-based) and "non-atomic" (analytical Gaussian) loss implementations into separate strategy classes, significantly improving the readability and maintainability of the core trainer class.

Motivation

The NPE_C class previously contained complex, non-trivial logic for switching between two very different loss calculation methods. This tightly coupled the mathematical details of the loss functions with the training loop. Validating the Single Responsibility Principle, this refactor:

  • Decouples loss implementation from the training orchestration.
  • Makes it easier to extend NPE_C with new loss types in the future (e.g., different divergences) without modifying the main loop.
  • Improves code clarity by removing over 100 lines of mixed-concern code from npe_c.py.

Changes

  • New File: sbi/inference/trainers/npe/npe_c_loss.py
    • Implemented AtomicLoss: Encapsulates the sampling-based loss logic.
    • Implemented NonAtomicGaussianLoss: Encapsulates the analytical MoG loss logic including automatic posterior transformation.
  • Modified: sbi/inference/trainers/npe/npe_c.py
    • Removed private methods _log_prob_proposal_posterior_atomic, _log_prob_proposal_posterior_mog, and _automatic_posterior_transformation.
    • Instantiates the appropriate strategy in train() based on the proposal and prior type.
    • Delegates execution to the strategy object.

Verification

  • Logic Preservation: Manually verified that the mathematical operations and logic flow in the new strategy classes match the original implementation exactly line-by-line.
  • Static Analysis: Verified imports and syntax correctness.
  • Behavior: The refactor is purely structural; no changes were made to the underlying mathematical definitions of the NPE-C loss.

@Sumit6307
Copy link
Author

Hi @janfb, please have a look at this PR. Thanks!

@janfb
Copy link
Contributor

janfb commented Feb 6, 2026

Hi @Sumit6307 , thank you for this PR as well. I assume you have read my comment under the PR (#1756 (comment)). That said, this contribution looks very good on a high level. I suggest the following:

@Sumit6307
Copy link
Author

Hi @Sumit6307 , thank you for this PR as well. I assume you have read my comment under the PR (#1756 (comment)). That said, this contribution looks very good on a high level. I suggest the following:

@janfb
Hi, thank you for the feedback!

Yes, I have read your comment under #1756. I’m glad to hear that the contribution looks good at a high level.

I’ll go through #1241 and @michaeldeistler’s proposal in detail and review how my PR aligns with the planned new organization of SNPE methods. Based on that, I’ll update the implementation and clarify the alignment where needed.

I’ll also run ruff and pyright locally and fix the remaining linting and type-checking issues so that CI passes.

Please let me know if you’d prefer me to continue iterating on this within the current PR, or if it would be better to first open or move the discussion to a dedicated issue.

Thanks for the guidance.

@janfb
Copy link
Contributor

janfb commented Feb 6, 2026

Sounds good @Sumit6307 , let's discuss the implementation here in the PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants