Skip to content

Conversation

@laggui
Copy link
Member

@laggui laggui commented Nov 3, 2025

Warning

Because module.valid() is stateless, the require_grad attribute of module params will not follow when subsequently calling module.train(). For default use cases, this is fine, but if params were manually changed (e.g., froze some parameters) this could lead to unintended user results if the same logic is not applied after moving the model back to autodiff.

This should be documented, but we should also keep this PR pending for now for this reason.

Checklist

  • Confirmed that cargo run-checks command has been executed.

Related Issues/PRs

From discord:

Is there a convenient way to get the model on the auto diff backend from learner.fit or do I basically have to save and reload the model if I want to do auto diff stuff after fitting?

Changes

Added the inverse function to module.valid() <> module.train(), which moves the module and all of its sub-modules to the autodiff backend.

Testing

Single test for batchnorm module to illustrate usage

@codecov
Copy link

codecov bot commented Nov 3, 2025

Codecov Report

❌ Patch coverage is 72.72727% with 30 lines in your changes missing coverage. Please review.
✅ Project coverage is 64.74%. Comparing base (3af2817) to head (1876018).
⚠️ Report is 6 commits behind head on main.

Files with missing lines Patch % Lines
crates/burn-core/src/module/param/primitive.rs 0.00% 15 Missing ⚠️
crates/burn-core/src/module/param/constant.rs 25.00% 9 Missing ⚠️
crates/burn-core/src/module/param/tensor.rs 33.33% 6 Missing ⚠️

❌ Your patch check has failed because the patch coverage (72.72%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.
❌ Your project check has failed because the head coverage (64.74%) is below the target coverage (80.00%). You can increase the head coverage or adjust the target coverage.

Additional details and impacted files
@@           Coverage Diff            @@
##             main    #3975    +/-   ##
========================================
  Coverage   64.73%   64.74%            
========================================
  Files        1178     1182     +4     
  Lines      140212   140527   +315     
========================================
+ Hits        90770    90983   +213     
- Misses      49442    49544   +102     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants