180 changes: 82 additions & 98 deletions earth2studio/models/dx/corrdiff.py
@@ -880,7 +880,7 @@ def load_default_package(cls) -> Package:

@classmethod
@check_optional_dependencies()
-def load_model(cls, package: Package) -> DiagnosticModel:
+def load_model(cls, package: Package, device: str | None = None) -> DiagnosticModel:
Collaborator:

The model wrappers are modules themselves, so users are expected to use wrapper.to(device).

So load_model (unless it's a GPU-only model) just places things on the CPU, and then users move it to the device. Can this be done with the optimization settings here?
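
For reference, a minimal sketch of that usual pattern (`CorrDiffTaiwan` is used here only as a stand-in for whichever wrapper is being loaded):

```python
import torch

from earth2studio.models.dx import CorrDiffTaiwan

package = CorrDiffTaiwan.load_default_package()
model = CorrDiffTaiwan.load_model(package)  # weights land on the CPU

# The wrapper is itself an nn.Module, so the caller moves it:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
```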

Collaborator (author):

Yes, I could do that and move the .to(device) into the Triton server instead. There's one caveat, however: the .to(memory_format=channels_last) has to be done after the .to(device). So the user code (and here the Triton server) would have the responsibility of doing both the .to(device) and the .to(memory_format=channels_last). Would that be okay?
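
A minimal sketch of what that caller-side responsibility would look like (hypothetical helper, not part of this PR):

```python
import torch

def prepare(model: torch.nn.Module, device: torch.device | str) -> torch.nn.Module:
    # Per the caveat above: device placement first, then memory format.
    model = model.to(device)
    model = model.to(memory_format=torch.channels_last)
    return model
```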

Collaborator:

I think this is one of the few instances where it might be beneficial to allow the user to specify the device here.

Collaborator (author):

I'll leave the device argument here for now.
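
With the argument kept, both call styles are available (illustrative sketch; `CorrDiffTaiwan` again stands in for the wrapper being loaded):

```python
from earth2studio.models.dx import CorrDiffTaiwan

package = CorrDiffTaiwan.load_default_package()

# Default: weights stay on CPU; the user moves the module afterwards.
model = CorrDiffTaiwan.load_model(package)
model = model.to("cuda:0")

# Opt-in: place the networks on the GPU at load time.
model = CorrDiffTaiwan.load_model(package, device="cuda:0")
```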

"""Load diagnostic from package"""

if StackedRandomGenerator is None or deterministic_sampler is None:
@@ -897,15 +897,33 @@ def load_model(cls, package: Package) -> DiagnosticModel:
str(
checkpoint_zip.parent
/ Path("corrdiff_inference_package/checkpoints/diffusion.mdlus")
-)
-).eval()
+),
+override_args={"use_apex_gn": True}
+)
+residual.use_fp16, residual.profile_mode = True, False
+residual = residual.eval()
+if device is not None:
+    residual = residual.to(device)
+residual = residual.to(memory_format=torch.channels_last)

regression = PhysicsNemoModule.from_checkpoint(
str(
checkpoint_zip.parent
/ Path("corrdiff_inference_package/checkpoints/regression.mdlus")
-)
-).eval()
+),
+override_args={"use_apex_gn": True}
+)
+regression.use_fp16, regression.profile_mode = True, False
+regression = regression.eval()
+if device is not None:
+    regression = regression.to(device)
+regression = regression.to(memory_format=torch.channels_last)

+# Compile models
+torch._dynamo.config.cache_size_limit = 264
+torch._dynamo.reset()
+residual = torch.compile(residual)
+regression = torch.compile(regression)

# Get dataset for lat/lon grid info and centers/stds'
try:
@@ -1011,10 +1029,11 @@ def _forward(self, x: torch.Tensor) -> torch.Tensor:
x = self._interpolate(x)

# Add sample dimension
-x = x.unsqueeze(0)
+x = x.unsqueeze(0).to(memory_format=torch.channels_last)
x = (x - self.in_center) / self.in_scale

# Create grid channels
+# TODO: why do we need this grid concatenated to the input?
Collaborator:

I think it's because the original CorrDiff had a spatial encoding, which was this grid.

Collaborator (author) CharlelieLrt commented on Oct 31, 2025:

Ok. I'm not sure which codebase the "original corrdiff" was trained with, but the physicsnemo training does not do that. So I think this does not work with newer checkpoints trained with physicsnemo. I'll leave it as is for now, as I'm not sure how to support both.

x1 = np.sin(np.linspace(0, 2 * np.pi, 448))
x2 = np.cos(np.linspace(0, 2 * np.pi, 448))
y1 = np.sin(np.linspace(0, 2 * np.pi, 448))
@@ -1032,43 +1051,44 @@ def _forward(self, x: torch.Tensor) -> torch.Tensor:
# Concat Grids
x = torch.cat((x, grid), dim=1)

-# Repeat for sample size
-sample_seeds = torch.arange(self.number_of_samples)
-x = x.repeat(self.number_of_samples, 1, 1, 1)
+# Create seeds for each sample
+seed = self.seed if self.seed is not None else np.random.randint(2**32)
+gen = torch.Generator(device=x.device)
+gen.manual_seed(seed)
+sample_seeds = torch.randint(0, 2**32, (self.number_of_samples,), device=x.device, generator=gen)

# Create latents
rnd = StackedRandomGenerator(x.device, sample_seeds)
+sampler_fn = partial(
+    deterministic_sampler,
+    num_steps=self.number_of_steps,
+    solver=self.solver,
+)

# Get high-res image shape
coord = self.output_coords(self.input_coords())
-img_resolution_x = coord["lat"].shape[0]
-img_resolution_y = coord["lon"].shape[1]
-latents = rnd.randn(
-    [
-        self.number_of_samples,
-        self.regression_model.img_out_channels,
-        img_resolution_x,
-        img_resolution_y,
-    ],
-    device=x.device,
-)
+H_hr = coord["lat"].shape[0]
+W_hr = coord["lon"].shape[1]

-mean = self.unet_regression(
+mean_hr = self.unet_regression(
    self.regression_model,
-    torch.zeros_like(latents),
-    x,
-    num_steps=self.number_of_steps,
+    img_lr=x,
+    output_channels=len(self.output_variables),
+    number_of_samples=self.number_of_samples,
)
-res = deterministic_sampler(
-    self.residual_model,
-    latents,
-    x,
-    randn_like=rnd.randn_like,
-    num_steps=self.number_of_steps,
-    solver=self.solver,
+res_hr = diffusion_step(
+    net=self.residual_model,
+    sampler_fn=sampler_fn,
+    img_shape=(H_hr, W_hr),
+    img_out_channels=len(self.output_variables),
+    rank_batches=sample_seeds,
+    img_lr=x,
+    device=x.device,
+    mean_hr=mean_hr,
)
-x = mean + res
-x = self.out_scale * x + self.out_center
-return x
+x_hr = mean_hr + res_hr
+x_hr = self.out_scale * x_hr + self.out_center
+return x_hr

@batch_func()
def __call__(
@@ -1092,79 +1112,43 @@ def __call__(
@staticmethod
def unet_regression(
net: torch.nn.Module,
-latents: torch.Tensor,
img_lr: torch.Tensor,
-class_labels: torch.Tensor = None,
-randn_like: Callable = torch.randn_like,
-num_steps: int = 8,
-sigma_min: float = 0.0,
-sigma_max: float = 0.0,
-rho: int = 7,
-S_churn: float = 0,
-S_min: float = 0,
-S_max: float = float("inf"),
-S_noise: float = 0.0,
+lead_time_label: torch.Tensor,
+output_channels: int,
+number_of_samples: int,
) -> torch.Tensor:
"""
-Perform U-Net regression with temporal sampling.
+Perform U-Net regression.

Parameters
----------
net : torch.nn.Module
U-Net model for regression.
-latents : torch.Tensor
-    Latent representation.
-img_lr : torch.Tensor)
-    Low-resolution input image.
-class_labels : torch.Tensor, optional
-    Class labels for conditional generation.
-randn_like : function, optional
-    Function for generating random noise.
-num_steps : int, optional
-    Number of time steps for temporal sampling.
-sigma_min : float, optional
-    Minimum noise level.
-sigma_max : float, optional
-    Maximum noise level.
-rho : int, optional
-    Exponent for noise level interpolation.
-S_churn : float, optional
-    Churning parameter.
-S_min : float, optional
-    Minimum churning value.
-S_max : float, optional
-    Maximum churning value.
-S_noise : float, optional
-    Noise level for churning.
+img_lr : torch.Tensor
+    Low-resolution input image of shape (1, C_in, H_hr, W_hr).
+lead_time_label : torch.Tensor
+    Lead time label of shape (1, 1, 1, 1), or (1,).
+output_channels : int
+    Number of output channels C_out.
+number_of_samples : int
+    Number of samples to generate for the single input batch element.
+    Only used to expand the shape of the output tensor.

Returns
-------
-torch.Tensor: Predicted output at the next time step.
+torch.Tensor: Predicted output with shape (number_of_samples, C_out,
+    H_hr, W_hr).
"""
-# Adjust noise levels based on what's supported by the network.
-sigma_min = max(sigma_min, 0)
-sigma_max = min(sigma_max, np.inf)
-
-# Time step discretization.
-step_indices = torch.arange(
-    num_steps, dtype=torch.float64, device=latents.device
+mean_hr = regression_step(
+    net=net,
+    img_lr=img_lr,
+    latents_shape=(
+        number_of_samples,
+        output_channels,
+        img_lr.shape[-2],
+        img_lr.shape[-1],
+    ),
+    lead_time_label=lead_time_label
)
-t_steps = (
-    sigma_max ** (1 / rho)
-    + step_indices
-    / (num_steps - 1)
-    * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
-) ** rho
-t_steps = torch.cat(
-    [net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]
-) # t_N = 0
-
-# conditioning
-x_lr = img_lr
-
-# Main sampling loop.
-x_hat = latents.to(torch.float64) * t_steps[0]
-
-x_next = net(x_hat, x_lr).to(torch.float64)
-
-return x_next
+return mean_hr