Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for ERA5 model levels and additional variables in ARCO data source
- Changed the HRRR_X and HRRR_Y coordinates of the HRRR data source to match the native
LCC coordinates
- Updated CorrDiffTaiwan model wrapper to use latest PhysicsNeMo APIs.

### Deprecated

Expand Down
188 changes: 89 additions & 99 deletions earth2studio/models/dx/corrdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,8 @@ class CorrDiffTaiwan(torch.nn.Module, AutoModelMixin):
solver: Literal['euler', 'heun']
Discretization of diffusion process. Only 'euler' and 'heun'
are supported. Default is 'euler'
seed: int | None, optional
Random seed for reproducibility. Default is None.
"""

def __init__(
Expand All @@ -798,6 +800,7 @@ def __init__(
number_of_samples: int = 1,
number_of_steps: int = 8,
solver: Literal["euler", "heun"] = "euler",
seed: int | None = None,
):
super().__init__()
self.residual_model = residual_model
Expand All @@ -821,6 +824,8 @@ def __init__(
self.number_of_samples = number_of_samples
self.number_of_steps = number_of_steps
self.solver = solver
self.seed = seed
self.output_variables = OUT_VARIABLES # Default set of output variables

def input_coords(self) -> CoordSystem:
"""Input coordinate system"""
Expand Down Expand Up @@ -883,7 +888,7 @@ def load_default_package(cls) -> Package:

@classmethod
@check_optional_dependencies()
def load_model(cls, package: Package) -> DiagnosticModel:
def load_model(cls, package: Package, device: str | None = None) -> DiagnosticModel:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The model wrappers are modules themselves, so users are expected to use wrapper.to(device).

So load_model (unless its gpu only model) just places things on the cpu and then users move it to the device. Can this be done with the optimization settings here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I could do that and move the .to(device) in the Triton server instead. There's however one caveat: the .to(memory_format=channels_last) has to be done after the .to(device). So the user code (and here the Triton server) would have the respnsability of doing both the .to(device) and the .to(memory_formet=channels_last). Would that be okay?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is one of the few instances where it might be beneficial to allow the user to specify the device here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll leave the device argument here for now

"""Load diagnostic from package"""

if StackedRandomGenerator is None or deterministic_sampler is None:
Expand All @@ -900,15 +905,33 @@ def load_model(cls, package: Package) -> DiagnosticModel:
str(
checkpoint_zip.parent
/ Path("corrdiff_inference_package/checkpoints/diffusion.mdlus")
)
).eval()
),
override_args={"use_apex_gn": True},
)
residual.use_fp16, residual.profile_mode = True, False
residual = residual.eval()
if device is not None:
residual = residual.to(device)
residual = residual.to(memory_format=torch.channels_last)

regression = PhysicsNemoModule.from_checkpoint(
str(
checkpoint_zip.parent
/ Path("corrdiff_inference_package/checkpoints/regression.mdlus")
)
).eval()
),
override_args={"use_apex_gn": True},
)
regression.use_fp16, regression.profile_mode = True, False
regression = regression.eval()
if device is not None:
regression = regression.to(device)
regression = regression.to(memory_format=torch.channels_last)

# Compile models
torch._dynamo.config.cache_size_limit = 264
torch._dynamo.reset()
residual = torch.compile(residual)
regression = torch.compile(regression)

store = zarr.storage.LocalStore(
str(
Expand Down Expand Up @@ -990,14 +1013,14 @@ def _interpolate(self, x: torch.Tensor) -> torch.Tensor:
def _forward(self, x: torch.Tensor) -> torch.Tensor:
if self.solver not in ["euler", "heun"]:
raise ValueError(
f"solver must be either 'euler' or 'heun' but got {self.solver}"
f"solver must be either 'euler' or 'heun', " f"but got {self.solver}"
)

# Interpolate
x = self._interpolate(x)

# Add sample dimension
x = x.unsqueeze(0)
x = x.unsqueeze(0).to(memory_format=torch.channels_last)
x = (x - self.in_center) / self.in_scale

# Create grid channels
Expand All @@ -1018,43 +1041,49 @@ def _forward(self, x: torch.Tensor) -> torch.Tensor:
# Concat Grids
x = torch.cat((x, grid), dim=1)

# Repeat for sample size
sample_seeds = torch.arange(self.number_of_samples)
x = x.repeat(self.number_of_samples, 1, 1, 1)
# Create seeds for each sample
seed = self.seed if self.seed is not None else np.random.randint(2**32)
if seed:
gen = torch.Generator(device=x.device)
gen.manual_seed(seed)
else:
gen = None
sample_seeds = torch.randint(
0, 2**32, (self.number_of_samples,), device=x.device, generator=gen
)

# Create latents
rnd = StackedRandomGenerator(x.device, sample_seeds)
sampler_fn = partial(
deterministic_sampler,
num_steps=self.number_of_steps,
solver=self.solver,
)

# Get high-res image shape
coord = self.output_coords(self.input_coords())
img_resolution_x = coord["lat"].shape[0]
img_resolution_y = coord["lon"].shape[1]
latents = rnd.randn(
[
self.number_of_samples,
self.regression_model.img_out_channels,
img_resolution_x,
img_resolution_y,
],
device=x.device,
)
H_hr = coord["lat"].shape[0]
W_hr = coord["lon"].shape[1]

mean = self.unet_regression(
mean_hr = self.unet_regression(
self.regression_model,
torch.zeros_like(latents),
x,
num_steps=self.number_of_steps,
img_lr=x,
output_channels=len(OUT_VARIABLES),
number_of_samples=self.number_of_samples,
)
res = deterministic_sampler(
self.residual_model,
latents,
x,
randn_like=rnd.randn_like,
num_steps=self.number_of_steps,
solver=self.solver,

res_hr = diffusion_step(
net=self.residual_model,
sampler_fn=sampler_fn,
img_shape=(H_hr, W_hr),
img_out_channels=len(self.output_variables),
rank_batches=[sample_seeds], # Single rank
img_lr=x,
device=x.device,
mean_hr=mean_hr,
)
x = mean + res
x = self.out_scale * x + self.out_center
return x

x_hr = mean_hr + res_hr
x_hr = self.out_scale * x_hr + self.out_center
return x_hr

@batch_func()
def __call__(
Expand All @@ -1078,79 +1107,40 @@ def __call__(
@staticmethod
def unet_regression(
net: torch.nn.Module,
latents: torch.Tensor,
img_lr: torch.Tensor,
class_labels: torch.Tensor = None,
randn_like: Callable = torch.randn_like,
num_steps: int = 8,
sigma_min: float = 0.0,
sigma_max: float = 0.0,
rho: int = 7,
S_churn: float = 0,
S_min: float = 0,
S_max: float = float("inf"),
S_noise: float = 0.0,
output_channels: int,
number_of_samples: int,
) -> torch.Tensor:
"""
Perform U-Net regression with temporal sampling.
Perform U-Net regression.

Parameters
----------
net : torch.nn.Module
U-Net model for regression.
latents : torch.Tensor
Latent representation.
img_lr : torch.Tensor)
Low-resolution input image.
class_labels : torch.Tensor, optional
Class labels for conditional generation.
randn_like : function, optional
Function for generating random noise.
num_steps : int, optional
Number of time steps for temporal sampling.
sigma_min : float, optional
Minimum noise level.
sigma_max : float, optional
Maximum noise level.
rho : int, optional
Exponent for noise level interpolation.
S_churn : float, optional
Churning parameter.
S_min : float, optional
Minimum churning value.
S_max : float, optional
Maximum churning value.
S_noise : float, optional
Noise level for churning.
img_lr : torch.Tensor
Low-resolution input image of shape (1, C_in, H_hr, W_hr).
output_channels : int
Number of output channels C_out.
number_of_samples : int
Number of samples to generate for the single input batch element.
Only used to expand the shape of the output tensor.

Returns
-------
torch.Tensor: Predicted output at the next time step.
torch.Tensor: Predicted output with shape (number_of_samples, C_out,
H_hr, W_hr).
"""
# Adjust noise levels based on what's supported by the network.
sigma_min = max(sigma_min, 0)
sigma_max = min(sigma_max, np.inf)

# Time step discretization.
step_indices = torch.arange(
num_steps, dtype=torch.float64, device=latents.device
mean_hr = regression_step(
net=net,
img_lr=img_lr,
latents_shape=(
number_of_samples,
output_channels,
img_lr.shape[-2],
img_lr.shape[-1],
),
lead_time_label=None,
)
t_steps = (
sigma_max ** (1 / rho)
+ step_indices
/ (num_steps - 1)
* (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
) ** rho
t_steps = torch.cat(
[net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]
) # t_N = 0

# conditioning
x_lr = img_lr

# Main sampling loop.
x_hat = latents.to(torch.float64) * t_steps[0]

x_next = net(x_hat, x_lr).to(torch.float64)

return x_next
return mean_hr