@CharlelieLrt
Collaborator

Earth2Studio Pull Request

Description

This PR updates the CorrDiffTaiwan model wrapper to the latest APIs and adds performance optimizations. It should speed up inference by up to 10x in the best case, thanks to Apex group norm and compilation.

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

Dependencies

Signed-off-by: Charlelie Laurent <[email protected]>
@CharlelieLrt
Collaborator Author

@NickGeneva a few questions on the CorrDiffTaiwan wrapper:

  • I do not understand why we need to compute a grid and concatenate it to the input before the forward pass. This is not done in PhysicsNeMo.
  • Currently PhysicsNeMo uses the stochastic_sampler by default, but the wrapper uses the deterministic sampler. I don't think that's a problem, but the stochastic_sampler should at least be provided as an option.
    LMK what you think.
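
One way the sampler choice raised in the second bullet could be exposed is as a keyword option that selects between two callables. The sketch below is framework-free and uses hypothetical stand-ins (`deterministic_sampler`, `stochastic_sampler_stub`, `SAMPLERS`, `run_sampler`); it does not reproduce the PhysicsNeMo or Earth2Studio APIs.

```python
from __future__ import annotations

import random
from typing import Callable


def deterministic_sampler(x: list[float], steps: int, seed: int) -> list[float]:
    """Hypothetical stand-in: halves the signal at each step, no noise."""
    for _ in range(steps):
        x = [0.5 * v for v in x]
    return x


def stochastic_sampler_stub(x: list[float], steps: int, seed: int) -> list[float]:
    """Hypothetical stand-in: same update plus seeded Gaussian noise."""
    rng = random.Random(seed)
    for _ in range(steps):
        x = [0.5 * v + rng.gauss(0.0, 0.01) for v in x]
    return x


# Hypothetical registry mapping an option name to a sampler callable.
SAMPLERS: dict[str, Callable[[list[float], int, int], list[float]]] = {
    "deterministic": deterministic_sampler,
    "stochastic": stochastic_sampler_stub,
}


def run_sampler(x, sampler_type="deterministic", steps=4, seed=0):
    """Dispatch on the option name; unknown names fail loudly."""
    if sampler_type not in SAMPLERS:
        raise ValueError(f"Unknown sampler_type: {sampler_type!r}")
    return SAMPLERS[sampler_type](list(x), steps, seed)
```

Keeping "deterministic" as the default would preserve the wrapper's current behavior while letting callers opt in to the stochastic path.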

@classmethod
@check_optional_dependencies()
- def load_model(cls, package: Package) -> DiagnosticModel:
+ def load_model(cls, package: Package, device: str | None = None) -> DiagnosticModel:
Collaborator

The model wrappers are modules themselves, so users are expected to use wrapper.to(device).

So load_model (unless it's a GPU-only model) just places things on the CPU, and users then move the model to the device. Can this be done with the optimization settings here?

Collaborator Author

Yes, I could do that and move the .to(device) to the Triton server instead. There is one caveat, however: the .to(memory_format=channels_last) has to be done after the .to(device). So the user code (here, the Triton server) would be responsible for doing both the .to(device) and the .to(memory_format=channels_last). Would that be okay?

Collaborator

I think this is one of the few instances where it might be beneficial to allow the user to specify the device here.

Collaborator Author

I'll leave the device argument here for now
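
The ordering discussed in this thread (device placement first, then the channels_last memory format) can be sketched as below. The helper name `prepare_for_inference` is hypothetical and a plain `nn.Conv2d` stands in for the model. This only illustrates the order of the two `.to(...)` calls and is not the wrapper's actual `load_model`.

```python
from __future__ import annotations

import torch
from torch import nn


def prepare_for_inference(model: nn.Module, device: str | None = None) -> nn.Module:
    """Hypothetical helper: move to the device, then switch memory format."""
    target = torch.device(device) if device is not None else torch.device("cpu")
    # Step 1: place parameters and buffers on the requested device.
    model = model.to(target)
    # Step 2: only afterwards convert 4D parameters to channels_last.
    model = model.to(memory_format=torch.channels_last)
    return model.eval()


# Stand-in model; "cpu" keeps the sketch runnable without a GPU.
model = prepare_for_inference(nn.Conv2d(3, 8, kernel_size=3), device="cpu")
print(model.weight.is_contiguous(memory_format=torch.channels_last))
```

Accepting `device` in the loader keeps both steps in one place, so callers such as the Triton server do not have to remember the ordering themselves.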

x = (x - self.in_center) / self.in_scale

# Create grid channels
# TODO: why do we need this grid concatenated to the input?
Collaborator

I think it's because the original CorrDiff had a spatial encoding, which was this grid.

Collaborator Author

@CharlelieLrt Oct 31, 2025

Ok. I'm not sure which codebase the "original corrdiff" was trained with, but PhysicsNeMo training does not do that. So I think this does not work with newer checkpoints trained with PhysicsNeMo. I'll leave it as is for now, as I'm not sure how to support both.
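
To make the grid discussion concrete, here is a framework-free sketch of what "concatenating grid channels" means: two extra channels holding row and column coordinates are appended to the input channels. The helper `add_grid_channels` and the [-1, 1] coordinate convention are assumptions for illustration; the wrapper's actual grid may be built differently.

```python
def add_grid_channels(x):
    """x is channels-first nested lists [C][H][W]; returns [C + 2][H][W]."""
    height, width = len(x[0]), len(x[0][0])

    def coord(i, n):
        # Map index i in [0, n - 1] to [-1, 1]; a single point maps to 0.
        return 0.0 if n == 1 else -1.0 + 2.0 * i / (n - 1)

    # One channel varies along rows, the other along columns.
    grid_y = [[coord(i, height) for _ in range(width)] for i in range(height)]
    grid_x = [[coord(j, width) for j in range(width)] for _ in range(height)]
    return x + [grid_y, grid_x]


x = [[[0.0] * 3 for _ in range(2)]]  # 1 channel on a 2 x 3 grid
out = add_grid_channels(x)
print(len(x), "->", len(out), "channels")
```

The channel count changes from C to C + 2, which is the likely source of the compatibility concern: a checkpoint trained without the grid would expect C input channels.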

@NickGeneva
Collaborator

Can you also add a line in the changelog about these updates? Thanks!

@NickGeneva added the "4 - In Review" (Currently Under Review) label Aug 18, 2025
@dallasfoster self-assigned this Aug 26, 2025
@greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR updates the CorrDiffTaiwan model wrapper with significant performance optimizations and API modernization. The changes introduce Apex group normalization, FP16 precision, channels_last memory format, and torch.compile optimization to achieve up to 10x inference speed improvements. The refactoring replaces the manual sampling loop with new standardized APIs from physicsnemo.utils.corrdiff using diffusion_step and regression_step functions. The implementation adds a device parameter to load_model for flexible hardware deployment and streamlines seed generation logic. These optimizations align with PyTorch best practices for deep learning inference acceleration and modernize the model wrapper to use more maintainable APIs from the physicsnemo library.

Important Files Changed

Filename | Score | Overview
earth2studio/models/dx/corrdiff.py | 2/5 | Major refactoring with performance optimizations and API updates, but contains critical issues with parameter handling and validation logic

Confidence score: 2/5

  • This PR contains significant improvements but has several critical issues that need to be addressed before merging
  • Score reflects complex optimization changes with multiple technical problems including missing seed attribute access, incorrect parameter passing in regression method, broken validation logic, and missing output variable references
  • Pay close attention to the unet_regression method parameter handling and the inference pipeline's seed and validation logic

Sequence Diagram

sequenceDiagram
    participant User
    participant CorrDiffTaiwan
    participant PhysicsNemoModule
    participant diffusion_step
    participant regression_step
    participant interpolator
    participant torch

    User->>CorrDiffTaiwan: "load_model(package)"
    CorrDiffTaiwan->>PhysicsNemoModule: "from_checkpoint(diffusion.mdlus)"
    PhysicsNemoModule-->>CorrDiffTaiwan: "residual_model"
    CorrDiffTaiwan->>PhysicsNemoModule: "from_checkpoint(regression.mdlus)"
    PhysicsNemoModule-->>CorrDiffTaiwan: "regression_model"
    CorrDiffTaiwan->>torch: "compile(residual_model)"
    torch-->>CorrDiffTaiwan: "compiled_residual"
    CorrDiffTaiwan->>torch: "compile(regression_model)"
    torch-->>CorrDiffTaiwan: "compiled_regression"
    CorrDiffTaiwan-->>User: "initialized_model"

    User->>CorrDiffTaiwan: "__call__(x, coords)"
    CorrDiffTaiwan->>CorrDiffTaiwan: "_forward(x[i])"
    CorrDiffTaiwan->>interpolator: "_interpolate(x)"
    interpolator-->>CorrDiffTaiwan: "interpolated_x"
    CorrDiffTaiwan->>CorrDiffTaiwan: "normalize and add grid channels"
    CorrDiffTaiwan->>regression_step: "unet_regression(regression_model, x)"
    regression_step-->>CorrDiffTaiwan: "mean_hr"
    CorrDiffTaiwan->>diffusion_step: "diffusion_step(residual_model, x, mean_hr)"
    diffusion_step-->>CorrDiffTaiwan: "res_hr"
    CorrDiffTaiwan->>CorrDiffTaiwan: "combine: mean_hr + res_hr"
    CorrDiffTaiwan->>CorrDiffTaiwan: "denormalize output"
    CorrDiffTaiwan-->>User: "output_tensor, output_coords"
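
The inference half of the diagram above can be summarized in a few lines. This sketch uses flat lists of floats and hypothetical stand-ins (`regression_stub`, `residual_stub`) for the two networks. The center and scale values are made up, and none of it reflects the real `physicsnemo.utils.corrdiff` signatures.

```python
def normalize(x, center, scale):
    return [(v - center) / scale for v in x]


def denormalize(y, center, scale):
    return [v * scale + center for v in y]


def regression_stub(x):
    """Hypothetical stand-in for the regression network (mean prediction)."""
    return [2.0 * v for v in x]


def residual_stub(x, mean_hr):
    """Hypothetical stand-in for the diffusion model (residual prediction)."""
    return [0.5 for _ in mean_hr]


def forward(x, in_center, in_scale, out_center, out_scale):
    x_norm = normalize(x, in_center, in_scale)        # normalize input
    mean_hr = regression_stub(x_norm)                 # regression step
    res_hr = residual_stub(x_norm, mean_hr)           # diffusion step
    y_norm = [m + r for m, r in zip(mean_hr, res_hr)] # combine mean + residual
    return denormalize(y_norm, out_center, out_scale) # denormalize output


y = forward([10.0, 14.0], in_center=10.0, in_scale=2.0,
            out_center=1.0, out_scale=4.0)
print(y)
```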

Additional Comments (2)

  1. earth2studio/models/dx/corrdiff.py, line 814 (link)

    logic: Validation logic error: should use or instead of and for the conditions

  2. earth2studio/models/dx/corrdiff.py, line 816 (link)

    logic: Same validation issue: should use or instead of and
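
The `and` versus `or` distinction in the two comments above is easy to miss. The flagged lines are not shown in this thread, so the check below is a hypothetical example of the pattern rather than the wrapper's code: a guard meant to reject input when either condition fails will silently pass when the conditions are joined with `and`.

```python
def validate_buggy(shape, expected_channels):
    # Bug: raises only when BOTH conditions are violated at once.
    if len(shape) != 4 and shape[1] != expected_channels:
        raise ValueError(f"Unexpected input shape {shape}")


def validate_fixed(shape, expected_channels):
    # Fix: raises when EITHER condition is violated.
    if len(shape) != 4 or shape[1] != expected_channels:
        raise ValueError(f"Unexpected input shape {shape}")


bad_shape = (1, 3, 8, 8)  # 4D, but wrong channel count (4 expected)

validate_buggy(bad_shape, expected_channels=4)  # passes silently

try:
    validate_fixed(bad_shape, expected_channels=4)
    fixed_rejected = False
except ValueError:
    fixed_rejected = True
```

With `or`, short-circuit evaluation also means the channel check is skipped whenever the rank check already fails, so `shape[1]` is never read on a malformed shape.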

1 file reviewed, 7 comments

@CharlelieLrt
Collaborator Author

/blossom-ci

@greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The current changes continue to address the CorrDiffTaiwan model wrapper updates that were identified in previous reviews. The developer has been working on fixing several critical issues including the missing seed attribute, incorrect variable references, and API compatibility problems with the unet_regression method. However, based on the context provided, many of the fundamental issues identified in previous reviews appear to still be present, particularly around the self.seed attribute reference, the missing lead_time_label parameter in unet_regression calls, and inconsistencies between self.output_variables and OUT_VARIABLES usage.

Important Files Changed

Filename | Score | Overview
earth2studio/models/dx/corrdiff.py | 2/5 | Updates to CorrDiffTaiwan model wrapper with API changes and performance optimizations, but several critical issues remain unresolved

Confidence score: 2/5

  • This PR requires careful review due to multiple unresolved critical issues that will cause runtime failures
  • Score reflects persistent problems with missing attributes (self.seed), incorrect API calls (missing lead_time_label), and variable reference inconsistencies (self.output_variables vs OUT_VARIABLES)
  • Pay close attention to earth2studio/models/dx/corrdiff.py as it contains multiple breaking changes that need to be addressed before merge

1 file reviewed, no comments

Signed-off-by: Charlelie Laurent <[email protected]>
@greptile-apps bot left a comment

Greptile Overview

Greptile Summary

Note: This review covers only the changes made since the last review, not the entire PR.

This most recent change adds a single line to the CHANGELOG.md file documenting the CorrDiffTaiwan model wrapper updates. The entry follows the project's Keep a Changelog format and is placed appropriately under the "Changed" section for version 0.10.0. This change addresses one of the incomplete checklist items in the PR description, ensuring the changelog is updated to reflect the PhysicsNeMo API updates and performance optimizations mentioned in the PR.

Important Files Changed

Filename | Score | Overview
CHANGELOG.md | 5/5 | Added single line documenting CorrDiffTaiwan model wrapper updates under version 0.10.0

Confidence score: 5/5

  • This changelog addition is safe to merge with no risk of breaking changes
  • Score reflects a simple, well-formatted documentation update that follows established conventions
  • No files require special attention as this is just a changelog entry

1 file reviewed, no comments

@CharlelieLrt
Collaborator Author

/blossom-ci

@NickGeneva
Collaborator

/blossom-ci

Signed-off-by: Charlelie Laurent <[email protected]>
@greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The latest changes focus primarily on code style improvements and one critical bug fix in the CorrDiffTaiwan model wrapper. The developer added trailing commas to function parameters following Python best practices, improved string formatting for better readability, and most importantly fixed a variable reference bug where self.output_variables was incorrectly changed to OUT_VARIABLES on line 1077. However, this fix is incomplete as it should be len(OUT_VARIABLES) based on the context where it's used as a parameter expecting a count of output channels. The changes demonstrate attention to code consistency and formatting standards but don't address several critical issues identified in previous reviews.
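
The difference between passing `OUT_VARIABLES` and `len(OUT_VARIABLES)` can be illustrated with a small sketch. The variable names in the list and the helper `allocate_output` are placeholders, not the wrapper's actual values or API. The point is only that a parameter expecting a channel count should receive an integer, not the list itself.

```python
# Placeholder values; the real OUT_VARIABLES list is not shown in this thread.
OUT_VARIABLES = ["t2m", "u10m", "v10m"]


def allocate_output(num_channels, height, width):
    """Hypothetical helper that expects a channel COUNT, not a name list."""
    if not isinstance(num_channels, int):
        raise TypeError(
            f"num_channels must be an int, got {type(num_channels).__name__}"
        )
    return [[[0.0] * width for _ in range(height)] for _ in range(num_channels)]


# Correct: pass the number of output channels.
out = allocate_output(len(OUT_VARIABLES), height=2, width=2)

# Incorrect: passing the list itself is rejected by the type check.
try:
    allocate_output(OUT_VARIABLES, height=2, width=2)
    list_rejected = False
except TypeError:
    list_rejected = True
```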

Important Files Changed

Filename | Score | Overview
earth2studio/models/dx/corrdiff.py | 3/5 | Added trailing commas for consistency, improved string formatting, and partially fixed variable reference bug but incomplete implementation remains

Confidence score: 3/5

  • This PR has mixed safety - some improvements but leaves critical issues unaddressed that could cause runtime failures
  • Score reflects partial progress on bug fixes but incomplete implementation and unaddressed critical issues from previous reviews including missing self.seed attribute, missing lead_time_label parameter, and incorrect rank_batches structure
  • Pay close attention to earth2studio/models/dx/corrdiff.py as it contains an incomplete variable fix and multiple unaddressed critical issues that could cause AttributeErrors and parameter mismatches at runtime

1 file reviewed, no comments

@CharlelieLrt
Collaborator Author

/blossom-ci

Labels

4 - In Review Currently Under Review

3 participants