-
Notifications
You must be signed in to change notification settings - Fork 74
Updates to CorrDiffTaiwan model wrapper #455
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: Charlelie Laurent <[email protected]>
|
@NickGeneva a few questions on the CorrDiffTaiwan wrapper:
|
| @classmethod | ||
| @check_optional_dependencies() | ||
| def load_model(cls, package: Package) -> DiagnosticModel: | ||
| def load_model(cls, package: Package, device: str | None = None) -> DiagnosticModel: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The model wrappers are modules themselves, so users are expected to use wrapper.to(device).
So load_model (unless its gpu only model) just places things on the cpu and then users move it to the device. Can this be done with the optimization settings here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I could do that and move the .to(device) in the Triton server instead. There's however one caveat: the .to(memory_format=channels_last) has to be done after the .to(device). So the user code (and here the Triton server) would have the respnsability of doing both the .to(device) and the .to(memory_formet=channels_last). Would that be okay?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is one of the few instances where it might be beneficial to allow the user to specify the device here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll leave the device argument here for now
earth2studio/models/dx/corrdiff.py
Outdated
| x = (x - self.in_center) / self.in_scale | ||
|
|
||
| # Create grid channels | ||
| # TODO: why do we need this grid concatenated to the input? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think because the original corrdiff had a spatial encoding which was this grid
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok. I'm not sure with which codebase was the "original corrdiff" trained but physicsnemo training does not do that. So, I think this does not work with newer checkpoints trained with physicsnemo. I'll leave it as it is for now, as I'm not sure how to support both.
|
Can you also add a line in the changelog about these updates? Thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Greptile Overview
Greptile Summary
This PR updates the CorrDiffTaiwan model wrapper with significant performance optimizations and API modernization. The changes introduce Apex group normalization, FP16 precision, channels_last memory format, and torch.compile optimization to achieve up to 10x inference speed improvements. The refactoring replaces the manual sampling loop with new standardized APIs from physicsnemo.utils.corrdiff using diffusion_step and regression_step functions. The implementation adds a device parameter to load_model for flexible hardware deployment and streamlines seed generation logic. These optimizations align with PyTorch best practices for deep learning inference acceleration and modernize the model wrapper to use more maintainable APIs from the physicsnemo library.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| earth2studio/models/dx/corrdiff.py | 2/5 | Major refactoring with performance optimizations and API updates, but contains critical issues with parameter handling and validation logic |
Confidence score: 2/5
- This PR contains significant improvements but has several critical issues that need to be addressed before merging
- Score reflects complex optimization changes with multiple technical problems including missing seed attribute access, incorrect parameter passing in regression method, broken validation logic, and missing output variable references
- Pay close attention to the unet_regression method parameter handling and the inference pipeline's seed and validation logic
Sequence Diagram
sequenceDiagram
participant User
participant CorrDiffTaiwan
participant PhysicsNemoModule
participant diffusion_step
participant regression_step
participant interpolator
participant torch
User->>CorrDiffTaiwan: "load_model(package)"
CorrDiffTaiwan->>PhysicsNemoModule: "from_checkpoint(diffusion.mdlus)"
PhysicsNemoModule-->>CorrDiffTaiwan: "residual_model"
CorrDiffTaiwan->>PhysicsNemoModule: "from_checkpoint(regression.mdlus)"
PhysicsNemoModule-->>CorrDiffTaiwan: "regression_model"
CorrDiffTaiwan->>torch: "compile(residual_model)"
torch-->>CorrDiffTaiwan: "compiled_residual"
CorrDiffTaiwan->>torch: "compile(regression_model)"
torch-->>CorrDiffTaiwan: "compiled_regression"
CorrDiffTaiwan-->>User: "initialized_model"
User->>CorrDiffTaiwan: "__call__(x, coords)"
CorrDiffTaiwan->>CorrDiffTaiwan: "_forward(x[i])"
CorrDiffTaiwan->>interpolator: "_interpolate(x)"
interpolator-->>CorrDiffTaiwan: "interpolated_x"
CorrDiffTaiwan->>CorrDiffTaiwan: "normalize and add grid channels"
CorrDiffTaiwan->>regression_step: "unet_regression(regression_model, x)"
regression_step-->>CorrDiffTaiwan: "mean_hr"
CorrDiffTaiwan->>diffusion_step: "diffusion_step(residual_model, x, mean_hr)"
diffusion_step-->>CorrDiffTaiwan: "res_hr"
CorrDiffTaiwan->>CorrDiffTaiwan: "combine: mean_hr + res_hr"
CorrDiffTaiwan->>CorrDiffTaiwan: "denormalize output"
CorrDiffTaiwan-->>User: "output_tensor, output_coords"
Additional Comments (2)
1 file reviewed, 7 comments
Signed-off-by: Charlelie Laurent <[email protected]>
Signed-off-by: Charlelie Laurent <[email protected]>
Signed-off-by: Charlelie Laurent <[email protected]>
|
/blossom-ci |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The current changes continue to address the CorrDiffTaiwan model wrapper updates that were identified in previous reviews. The developer has been working on fixing several critical issues including the missing seed attribute, incorrect variable references, and API compatibility problems with the unet_regression method. However, based on the context provided, many of the fundamental issues identified in previous reviews appear to still be present, particularly around the self.seed attribute reference, the missing lead_time_label parameter in unet_regression calls, and inconsistencies between self.output_variables and OUT_VARIABLES usage.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| earth2studio/models/dx/corrdiff.py | 2/5 | Updates to CorrDiffTaiwan model wrapper with API changes and performance optimizations, but several critical issues remain unresolved |
Confidence score: 2/5
- This PR requires careful review due to multiple unresolved critical issues that will cause runtime failures
- Score reflects persistent problems with missing attributes (self.seed), incorrect API calls (missing lead_time_label), and variable reference inconsistencies (self.output_variables vs OUT_VARIABLES)
- Pay close attention to earth2studio/models/dx/corrdiff.py as it contains multiple breaking changes that need to be addressed before merge
1 file reviewed, no comments
Signed-off-by: Charlelie Laurent <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Greptile Overview
Greptile Summary
Note: This review covers only the changes made since the last review, not the entire PR.
This most recent change adds a single line to the CHANGELOG.md file documenting the CorrDiffTaiwan model wrapper updates. The entry follows the project's Keep a Changelog format and is placed appropriately under the "Changed" section for version 0.10.0. This change addresses one of the incomplete checklist items in the PR description, ensuring the changelog is updated to reflect the PhysicsNeMo API updates and performance optimizations mentioned in the PR.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| CHANGELOG.md | 5/5 | Added single line documenting CorrDiffTaiwan model wrapper updates under version 0.10.0 |
Confidence score: 5/5
- This changelog addition is safe to merge with no risk of breaking changes
- Score reflects a simple, well-formatted documentation update that follows established conventions
- No files require special attention as this is just a changelog entry
1 file reviewed, no comments
|
/blossom-ci |
3 similar comments
|
/blossom-ci |
|
/blossom-ci |
|
/blossom-ci |
Signed-off-by: Charlelie Laurent <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The latest changes focus primarily on code style improvements and one critical bug fix in the CorrDiffTaiwan model wrapper. The developer added trailing commas to function parameters following Python best practices, improved string formatting for better readability, and most importantly fixed a variable reference bug where self.output_variables was incorrectly changed to OUT_VARIABLES on line 1077. However, this fix is incomplete as it should be len(OUT_VARIABLES) based on the context where it's used as a parameter expecting a count of output channels. The changes demonstrate attention to code consistency and formatting standards but don't address several critical issues identified in previous reviews.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| earth2studio/models/dx/corrdiff.py | 3/5 | Added trailing commas for consistency, improved string formatting, and partially fixed variable reference bug but incomplete implementation remains |
Confidence score: 3/5
- This PR has mixed safety - some improvements but leaves critical issues unaddressed that could cause runtime failures
- Score reflects partial progress on bug fixes but incomplete implementation and unaddressed critical issues from previous reviews including missing self.seed attribute, missing lead_time_label parameter, and incorrect rank_batches structure
- Pay close attention to earth2studio/models/dx/corrdiff.py as it contains an incomplete variable fix and multiple unaddressed critical issues that could cause AttributeErrors and parameter mismatches at runtime
1 file reviewed, no comments
|
/blossom-ci |
Earth2Studio Pull Request
Description
This PR updates the CorrDiffTaiwan model wrapper with latest APIs and performance optimizations. It should bring up to 10x inference speed (at best) thanks to Apex group norm and compilation.
Checklist
Dependencies