-
Notifications
You must be signed in to change notification settings - Fork 74
Updates to CorrDiffTaiwan model wrapper #455
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
5cce8d2
ec255fd
2d7b597
1920574
f7e079a
605187a
137a0b5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -883,7 +883,7 @@ def load_default_package(cls) -> Package: | |
|
|
||
| @classmethod | ||
| @check_optional_dependencies() | ||
| def load_model(cls, package: Package) -> DiagnosticModel: | ||
| def load_model(cls, package: Package, device: str | None = None) -> DiagnosticModel: | ||
| """Load diagnostic from package""" | ||
|
|
||
| if StackedRandomGenerator is None or deterministic_sampler is None: | ||
|
|
@@ -900,15 +900,33 @@ def load_model(cls, package: Package) -> DiagnosticModel: | |
| str( | ||
| checkpoint_zip.parent | ||
| / Path("corrdiff_inference_package/checkpoints/diffusion.mdlus") | ||
| ) | ||
| ).eval() | ||
| ), | ||
| override_args={"use_apex_gn": True} | ||
| ) | ||
| residual.use_fp16, residual.profile_mode = True, False | ||
| residual = residual.eval() | ||
| if device is not None: | ||
| residual = residual.to(device) | ||
| residual = residual.to(memory_format=torch.channels_last) | ||
|
|
||
| regression = PhysicsNemoModule.from_checkpoint( | ||
| str( | ||
| checkpoint_zip.parent | ||
| / Path("corrdiff_inference_package/checkpoints/regression.mdlus") | ||
| ) | ||
| ).eval() | ||
| ), | ||
| override_args={"use_apex_gn": True} | ||
| ) | ||
| regression.use_fp16, regression.profile_mode = True, False | ||
| regression = regression.eval() | ||
| if device is not None: | ||
| regression = regression.to(device) | ||
| regression = regression.to(memory_format=torch.channels_last) | ||
|
|
||
| # Compile models | ||
| torch._dynamo.config.cache_size_limit = 264 | ||
| torch._dynamo.reset() | ||
| residual = torch.compile(residual) | ||
| regression = torch.compile(regression) | ||
|
|
||
| store = zarr.storage.LocalStore( | ||
| str( | ||
|
|
@@ -997,10 +1015,11 @@ def _forward(self, x: torch.Tensor) -> torch.Tensor: | |
| x = self._interpolate(x) | ||
|
|
||
| # Add sample dimension | ||
| x = x.unsqueeze(0) | ||
| x = x.unsqueeze(0).to(memory_format=torch.channels_last) | ||
| x = (x - self.in_center) / self.in_scale | ||
|
|
||
| # Create grid channels | ||
| # TODO: why do we need this grid concatenated to the input? | ||
|
||
| x1 = np.sin(np.linspace(0, 2 * np.pi, 448)) | ||
| x2 = np.cos(np.linspace(0, 2 * np.pi, 448)) | ||
| y1 = np.sin(np.linspace(0, 2 * np.pi, 448)) | ||
|
|
@@ -1018,43 +1037,44 @@ def _forward(self, x: torch.Tensor) -> torch.Tensor: | |
| # Concat Grids | ||
| x = torch.cat((x, grid), dim=1) | ||
|
|
||
| # Repeat for sample size | ||
| sample_seeds = torch.arange(self.number_of_samples) | ||
| x = x.repeat(self.number_of_samples, 1, 1, 1) | ||
| # Create seeds for each sample | ||
| seed = self.seed if self.seed is not None else np.random.randint(2**32) | ||
CharlelieLrt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| gen = torch.Generator(device=x.device) | ||
| gen.manual_seed(seed) | ||
| sample_seeds = torch.randint(0, 2**32, (self.number_of_samples,), device=x.device, generator=gen) | ||
CharlelieLrt marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| # Create latents | ||
| rnd = StackedRandomGenerator(x.device, sample_seeds) | ||
| sampler_fn = partial( | ||
| deterministic_sampler, | ||
| num_steps=self.number_of_steps, | ||
| solver=self.solver, | ||
| ) | ||
|
|
||
| # Get high-res image shape | ||
| coord = self.output_coords(self.input_coords()) | ||
| img_resolution_x = coord["lat"].shape[0] | ||
| img_resolution_y = coord["lon"].shape[1] | ||
| latents = rnd.randn( | ||
| [ | ||
| self.number_of_samples, | ||
| self.regression_model.img_out_channels, | ||
| img_resolution_x, | ||
| img_resolution_y, | ||
| ], | ||
| device=x.device, | ||
| ) | ||
| H_hr = coord["lat"].shape[0] | ||
| W_hr = coord["lon"].shape[1] | ||
|
|
||
| mean = self.unet_regression( | ||
| mean_hr = self.unet_regression( | ||
| self.regression_model, | ||
| torch.zeros_like(latents), | ||
| x, | ||
| num_steps=self.number_of_steps, | ||
| img_lr=x, | ||
| output_channels=len(self.output_variables), | ||
CharlelieLrt marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| number_of_samples=self.number_of_samples, | ||
| ) | ||
CharlelieLrt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| res = deterministic_sampler( | ||
| self.residual_model, | ||
| latents, | ||
| x, | ||
| randn_like=rnd.randn_like, | ||
| num_steps=self.number_of_steps, | ||
| solver=self.solver, | ||
|
|
||
| res_hr = diffusion_step( | ||
| net=self.residual_model, | ||
| sampler_fn=sampler_fn, | ||
| img_shape=(H_hr, W_hr), | ||
| img_out_channels=len(self.output_variables), | ||
CharlelieLrt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| rank_batches=sample_seeds, | ||
CharlelieLrt marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| img_lr=x, | ||
| device=x.device, | ||
| mean_hr=mean_hr, | ||
| ) | ||
| x = mean + res | ||
| x = self.out_scale * x + self.out_center | ||
| return x | ||
|
|
||
| x_hr = mean_hr + res_hr | ||
| x_hr = self.out_scale * x_hr + self.out_center | ||
| return x_hr | ||
|
|
||
| @batch_func() | ||
| def __call__( | ||
|
|
@@ -1078,79 +1098,43 @@ def __call__( | |
| @staticmethod | ||
| def unet_regression( | ||
| net: torch.nn.Module, | ||
| latents: torch.Tensor, | ||
| img_lr: torch.Tensor, | ||
| class_labels: torch.Tensor = None, | ||
| randn_like: Callable = torch.randn_like, | ||
| num_steps: int = 8, | ||
| sigma_min: float = 0.0, | ||
| sigma_max: float = 0.0, | ||
| rho: int = 7, | ||
| S_churn: float = 0, | ||
| S_min: float = 0, | ||
| S_max: float = float("inf"), | ||
| S_noise: float = 0.0, | ||
| lead_time_label: torch.Tensor, | ||
| output_channels: int, | ||
| number_of_samples: int, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Perform U-Net regression with temporal sampling. | ||
| Perform U-Net regression. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| net : torch.nn.Module | ||
| U-Net model for regression. | ||
| latents : torch.Tensor | ||
| Latent representation. | ||
| img_lr : torch.Tensor) | ||
| Low-resolution input image. | ||
| class_labels : torch.Tensor, optional | ||
| Class labels for conditional generation. | ||
| randn_like : function, optional | ||
| Function for generating random noise. | ||
| num_steps : int, optional | ||
| Number of time steps for temporal sampling. | ||
| sigma_min : float, optional | ||
| Minimum noise level. | ||
| sigma_max : float, optional | ||
| Maximum noise level. | ||
| rho : int, optional | ||
| Exponent for noise level interpolation. | ||
| S_churn : float, optional | ||
| Churning parameter. | ||
| S_min : float, optional | ||
| Minimum churning value. | ||
| S_max : float, optional | ||
| Maximum churning value. | ||
| S_noise : float, optional | ||
| Noise level for churning. | ||
| img_lr : torch.Tensor | ||
| Low-resolution input image of shape (1, C_in, H_hr, W_hr). | ||
| lead_time_label : torch.Tensor | ||
| Lead time label of shape (1, 1, 1, 1), or (1,). | ||
| output_channels : int | ||
| Number of output channels C_out. | ||
| number_of_samples : int | ||
| Number of samples to generate for the single input batch element. | ||
| Only used to expand the shape of the output tensor. | ||
|
|
||
| Returns | ||
| ------- | ||
| torch.Tensor: Predicted output at the next time step. | ||
| torch.Tensor: Predicted output with shape (number_of_samples, C_out, | ||
| H_hr, W_hr). | ||
| """ | ||
| # Adjust noise levels based on what's supported by the network. | ||
| sigma_min = max(sigma_min, 0) | ||
| sigma_max = min(sigma_max, np.inf) | ||
|
|
||
| # Time step discretization. | ||
| step_indices = torch.arange( | ||
| num_steps, dtype=torch.float64, device=latents.device | ||
| mean_hr = regression_step( | ||
| net=net, | ||
| img_lr=img_lr, | ||
| latents_shape=( | ||
| number_of_samples, | ||
| output_channels, | ||
| img_lr.shape[-2], | ||
| img_lr.shape[-1], | ||
| ), | ||
| lead_time_label=lead_time_label | ||
| ) | ||
| t_steps = ( | ||
| sigma_max ** (1 / rho) | ||
| + step_indices | ||
| / (num_steps - 1) | ||
| * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) | ||
| ) ** rho | ||
| t_steps = torch.cat( | ||
| [net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])] | ||
| ) # t_N = 0 | ||
|
|
||
| # conditioning | ||
| x_lr = img_lr | ||
|
|
||
| # Main sampling loop. | ||
| x_hat = latents.to(torch.float64) * t_steps[0] | ||
|
|
||
| x_next = net(x_hat, x_lr).to(torch.float64) | ||
|
|
||
| return x_next | ||
| return mean_hr | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The model wrappers are modules themselves, so users are expected to use wrapper.to(device).
So load_model (unless its gpu only model) just places things on the cpu and then users move it to the device. Can this be done with the optimization settings here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I could do that and move the
.to(device)in the Triton server instead. There's however one caveat: the.to(memory_format=channels_last)has to be done after the.to(device). So the user code (and here the Triton server) would have the respnsability of doing both the.to(device)and the.to(memory_formet=channels_last). Would that be okay?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is one of the few instances where it might be beneficial to allow the user to specify the device here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll leave the
deviceargument here for now