-
Notifications
You must be signed in to change notification settings - Fork 611
Expand file tree
/
Copy pathnn.py
More file actions
371 lines (323 loc) · 13.4 KB
/
nn.py
File metadata and controls
371 lines (323 loc) · 13.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable, Iterable, Mapping
from typing import Any, Literal
import numpy as np
from tensordict import TensorDict
import torch
from physicsnemo.core import Module
from physicsnemo.models.diffusion_unets import StormCastUNet
from physicsnemo.diffusion.preconditioners import EDMPrecond, EDMPreconditioner
from physicsnemo.diffusion.utils import ConcatConditionWrapper
# from physicsnemo.diffusion.samplers import deterministic_sampler
from physicsnemo.models.dit import DiT
from utils.sampler import deterministic_sampler
import utils.apex # do not remove, enables Apex LayerNorm with ShardTensor
def get_preconditioned_unet(
    name: str,
    target_channels: int,
    conditional_channels: int = 0,
    img_resolution: tuple = (512, 640),
    model_type: str | None = None,
    lead_time_steps: int = 0,
    lead_time_channels: int = 4,
    amp_mode: bool = False,
    use_apex_gn: bool = False,
    **model_kwargs,
) -> EDMPrecond | StormCastUNet:
    """
    Create a preconditioner-wrapped SongUNet network.

    Args:
        name: 'regression' or 'diffusion' to select between either model type
        target_channels: The number of channels in the target
        conditional_channels: The number of channels in the conditioning
        img_resolution: resolution of the data (U-Net inputs/outputs)
        model_type: the model class to use, or None to select it automatically
        lead_time_steps: the number of possible lead time steps, if 0 lead time embedding will be disabled
        lead_time_channels: the number of channels to use for each lead time embedding
        amp_mode: whether to use automatic mixed precision
        use_apex_gn: whether to use Apex GroupNorm
        **model_kwargs: any additional parameters forwarded to the model
    Returns:
        EDMPrecond or StormCastUNet: a wrapped torch module net(x+n, sigma, condition, class_labels) -> x
    Raises:
        ValueError: if ``name`` is not 'regression' or 'diffusion'
    """
    if model_type is None:
        # Pick the lead-time-aware U-Net variant only when lead time
        # embedding is enabled.
        model_type = "SongUNetPosLtEmbd" if lead_time_steps else "SongUNet"
    model_params = {
        "img_resolution": img_resolution,
        "img_out_channels": target_channels,
        "model_type": model_type,
        "amp_mode": amp_mode,
        "use_apex_gn": use_apex_gn,
    }
    model_params.update(model_kwargs)
    if lead_time_steps:
        model_params["N_grid_channels"] = 0
        model_params["lead_time_channels"] = lead_time_channels
        model_params["lead_time_steps"] = lead_time_steps
    else:
        # No lead time embedding: don't reserve extra input channels for it.
        lead_time_channels = 0
    if name == "diffusion":
        return EDMPrecond(
            img_channels=target_channels + conditional_channels + lead_time_channels,
            **model_params,
        )
    elif name == "regression":
        return StormCastUNet(
            img_in_channels=conditional_channels + lead_time_channels,
            embedding_type="zero",
            **model_params,
        )
    else:
        # Previously fell through and returned None silently; fail loudly instead.
        raise ValueError(
            f"Unknown model name {name!r}; expected 'regression' or 'diffusion'"
        )
def get_preconditioned_natten_dit(
    target_channels: int,
    conditional_channels: int = 0,
    scalar_condition_channels: int = 0,
    img_resolution: tuple = (512, 640),
    hidden_size: int = 768,
    depth: int = 16,
    num_heads: int = 16,
    patch_size: int = 4,
    attn_kernel_size: int = 31,
    lead_time_steps: int = 0,
    layernorm_backend: Literal["torch", "apex"] = "apex",
    conditioning_embedder: Literal["dit", "edm", "zero"] = "dit",
    **model_kwargs,
) -> EDMPreconditioner:
    """
    Create a preconditioner-wrapped Diffusion Transformer (DiT) network.
    Args:
        target_channels: The number of channels in the target
        conditional_channels: The number of channels in the conditioning
        scalar_condition_channels: The number of scalar condition channels
        img_resolution: Resolution of the data (DiT inputs/outputs)
        hidden_size: The number of channels in the internal layers of the DiT
        depth: The number of transformer blocks in the DiT
        num_heads: number of heads in multi-head attention
        patch_size: the patch size used by the DiT embedder
        attn_kernel_size: the attention neighborhood size
        lead_time_steps: the number of possible lead time steps, if 0 lead time embedding will be disabled
        layernorm_backend: LayerNorm implementation used inside the DiT ('torch' or 'apex')
        conditioning_embedder: embedder used for scalar conditioning ('dit', 'edm' or 'zero')
        **model_kwargs: any additional parameters to the model
    Returns:
        EDMPreconditioner: a wrapped torch module net(x+n, sigma, condition, class_labels) -> x
    """
    # Scalar conditions and one-hot lead time labels are concatenated into a
    # single conditioning vector of this dimension.
    condition_dim = scalar_condition_channels + lead_time_steps
    attn_kwargs = {"attn_kernel": attn_kernel_size}
    dit = DiT(
        input_size=img_resolution,
        in_channels=target_channels + conditional_channels,
        out_channels=target_channels,
        hidden_size=hidden_size,
        depth=depth,
        num_heads=num_heads,
        patch_size=patch_size,
        attention_backend="natten2d",
        layernorm_backend=layernorm_backend,
        attn_kwargs=attn_kwargs,
        condition_dim=condition_dim,
        conditioning_embedder=conditioning_embedder,
        **model_kwargs,
    )
    # ConcatConditionWrapper concatenates the image-like condition with the
    # noisy input before the DiT forward; EDMPreconditioner applies the EDM
    # sigma-dependent scaling around it.
    return EDMPreconditioner(model=ConcatConditionWrapper(dit))
def build_network_condition_and_target(
    background: torch.Tensor,
    state: tuple[torch.Tensor, torch.Tensor],
    invariant_tensor: torch.Tensor | None,
    scalar_conditions: torch.Tensor | None = None,
    lead_time_label: torch.Tensor | None = None,
    regression_net: Module | None = None,
    condition_list: Iterable[str] = ("state", "background"),
    regression_condition_list: Iterable[str] = ("state", "background"),
) -> tuple[torch.Tensor | TensorDict, torch.Tensor, torch.Tensor | None]:
    """Build the condition and target tensors for the network.
    Args:
        background: background tensor
        state: tuple of previous state and target state
        invariant_tensor: invariant tensor or None if no invariant is used
        scalar_conditions: scalar (non-spatial) condition tensor, or None;
            when given, the condition is returned as a TensorDict instead of
            a plain tensor
        lead_time_label: lead time label or None if lead time embedding is not used
        regression_net: regression model, can be None if 'regression' is not in condition_list
        condition_list: list of conditions to include, may include 'state', 'background', 'regression' and 'invariant'
        regression_condition_list: list of conditions for the regression network, may include 'state', 'background', and 'invariant'
            This is only used if regression_net is set.
    Returns:
        A tuple of tensors: (
            condition: model condition concatenated from conditions specified in condition_list,
            target: training target,
            regression: regression model output
        ). The regression model output will be None if 'regression' is not in condition_list.
    """
    if ("regression" in condition_list) and (regression_net is None):
        raise ValueError(
            "regression_net must be provided if 'regression' is in condition_list"
        )
    target = state[1]
    # Map condition names to tensors; None entries are dropped below.
    condition_tensors = {
        "state": state[0],
        "background": background,
        "invariant": invariant_tensor,
        "regression": None,
    }
    with torch.no_grad():
        if "regression" in condition_list:
            # Inference regression model
            condition_tensors["regression"] = regression_model_forward(
                regression_net,
                state[0],
                background,
                invariant_tensor,
                lead_time_label=lead_time_label,
                condition_list=regression_condition_list,
            )
            # With a regression condition, the diffusion model learns the
            # residual between the target and the regression prediction.
            target = target - condition_tensors["regression"]
    # Gather the requested conditions in condition_list order, skipping
    # any that are None.
    condition = [
        y for c in condition_list if (y := condition_tensors[c]) is not None
    ]
    condition = torch.cat(condition, dim=1) if condition else None
    if scalar_conditions is not None:
        # Pack spatial and scalar conditions together; "cond_concat" is
        # omitted entirely when there are no spatial conditions.
        condition = TensorDict(
            {"cond_concat": condition, "cond_vec": scalar_conditions}
            if condition is not None
            else {"cond_vec": scalar_conditions},
            device=state[1].device,
        ).to(dtype=state[1].dtype)
    return (condition, target, condition_tensors["regression"])
def unpack_batch(
    batch: dict[str, Any],
    device: torch.device | str,
    memory_format: torch.memory_format = torch.preserve_format,
) -> tuple[
    torch.Tensor | None,
    list,
    torch.Tensor | None,
    torch.Tensor | None,
    torch.Tensor | None,
]:
    """Unpack a data batch into its components with the correct device and dtypes.

    Args:
        batch: mapping with mandatory key "state" and optional keys
            "background", "mask", "lead_time_label" and "scalar_conditions"
        device: target device for all tensors
        memory_format: memory format forwarded to ``Tensor.to``

    Returns:
        A 5-tuple ``(background, state, mask, lead_time_label,
        scalar_conditions)``. Entries whose key is missing from ``batch``
        are None. ``state`` is normalized to a two-element sequence
        ``[previous_state, target_state]``.
    """
    if isinstance(batch["state"], torch.Tensor):
        # downscaling and unconditional models may return a single tensor as
        # "state"; normalize to [previous, target] with no previous state.
        # NOTE: this mutates the caller's batch dict in place.
        batch["state"] = [None, batch["state"]]
    (background, state, mask) = nested_to(
        (batch.get("background"), batch["state"], batch.get("mask")),
        device=device,
        dtype=torch.float32,
        non_blocking=True,
        memory_format=memory_format,
    )
    lead_time_label = batch.get("lead_time_label")
    if lead_time_label is not None:
        # Lead time labels select an embedding, so they must be integer typed.
        lead_time_label = lead_time_label.to(
            device=device, dtype=torch.int64, non_blocking=True
        )
    scalar_conditions = batch.get("scalar_conditions")
    if scalar_conditions is not None:
        scalar_conditions = scalar_conditions.to(
            device=device, dtype=torch.float32, non_blocking=True
        )
    return (background, state, mask, lead_time_label, scalar_conditions)
def diffusion_model_forward(
    model: Module,
    condition: torch.Tensor,
    shape: Iterable[int],
    dtype: torch.dtype,
    device: torch.device,
    lead_time_label: torch.Tensor | None = None,
    sampler_args: dict[str, Any] | None = None,
) -> torch.Tensor:
    """Helper function to run diffusion model sampling.

    Args:
        model: diffusion model (EDM-preconditioned) to sample from
        condition: conditioning input passed to the sampler as ``img_lr``
        shape: shape of the latent noise / generated sample
        dtype: dtype for the latents and sampling computation
        device: device on which to sample
        lead_time_label: lead time label or None if lead time embedding is not used
        sampler_args: extra keyword arguments forwarded to the sampler
    Returns:
        The generated sample tensor.
    """
    if sampler_args is None:
        # Avoid a mutable default argument shared between calls.
        sampler_args = {}
    # TODO: avoid creating full tensor here when sharding
    latents = torch.randn(*shape, device=device, dtype=dtype)
    # The sampler expects EDM preconditioner attributes; provide permissive
    # defaults when the model does not define them.
    if not hasattr(model, "sigma_min"):
        model.sigma_min = 0.0
    if not hasattr(model, "sigma_max"):
        model.sigma_max = np.inf
    if not hasattr(model, "round_sigma"):
        model.round_sigma = torch.as_tensor
    return deterministic_sampler(
        model,
        latents=latents,
        img_lr=condition,
        lead_time_label=lead_time_label,
        dtype=dtype,
        **sampler_args,
    )
def regression_model_forward(
    model: Module,
    state: torch.Tensor,
    background: torch.Tensor,
    invariant_tensor: torch.Tensor,
    lead_time_label: torch.Tensor | None = None,
    condition_list: Iterable[str] = ("state", "background"),
) -> torch.Tensor:
    """Run the regression model forward pass in inference.

    Builds the concatenated condition from ``condition_list`` and calls the
    model, forwarding the lead time label only when one is given.
    """
    # Target slot is None: only the condition output is needed here.
    condition, _, _ = build_network_condition_and_target(
        background,
        (state, None),
        invariant_tensor,
        lead_time_label=lead_time_label,
        condition_list=condition_list,
    )
    extra_kwargs = {}
    if lead_time_label is not None:
        extra_kwargs["lead_time_label"] = lead_time_label
    return model(condition, **extra_kwargs)
def regression_loss_fn(
    net: Module,
    images: torch.Tensor,
    condition: torch.Tensor,
    class_labels: None = None,
    lead_time_label: torch.Tensor | None = None,
    augment_pipe: Callable | None = None,
    return_model_outputs: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Helper function for training the StormCast regression model, so that it has a similar call signature as
    the EDMLoss and the same training loop can be used to train both regression and diffusion models
    Args:
        net: physicsnemo.models.diffusion.StormCastUNet
        images: Target data, shape [batch_size, target_channels, w, h]
        condition: input to the model, shape=[batch_size, condition_channel, w, h]
        class_labels: unused (applied to match EDMLoss signature)
        lead_time_label: lead time label or None if lead time embedding is not used
        augment_pipe: optional data augmentation pipe
        return_model_outputs: If True, will return the generated outputs
    Returns:
        out: loss function with shape [batch_size, target_channels, w, h]
        This should be averaged to get the mean loss for gradient descent.
    """
    if augment_pipe is None:
        y, augment_labels = images, None
    else:
        y, augment_labels = augment_pipe(images)
    extra_kwargs = {}
    if lead_time_label is not None:
        extra_kwargs["lead_time_label"] = lead_time_label
    prediction = net(x=condition, **extra_kwargs)
    # Per-element squared error; caller averages for the scalar loss.
    loss = (prediction - y) ** 2
    return (loss, prediction) if return_model_outputs else loss
def nested_to(
x: torch.Tensor | Mapping | list | tuple | Any, **kwargs
) -> torch.Tensor | dict | list | Any:
"""Move tensors in nested structures to a device/dtype.
Parameters
----------
x : torch.Tensor or Mapping or list or tuple
Input structure.
**kwargs
Keyword arguments forwarded to ``Tensor.to``.
Returns
-------
torch.Tensor or dict or list
Structure with tensors moved.
"""
if isinstance(x, Mapping):
return {k: nested_to(v, **kwargs) for (k, v) in x.items()}
elif isinstance(x, (list, tuple)):
return [nested_to(v, **kwargs) for v in x]
else:
if not isinstance(x, torch.Tensor):
return x
return x.to(**kwargs)