Skip to content

Commit 201aeac

Browse files
166 improve ddpm encoder (#223)
* 🔨 Improve DDPM input encoders
* 📝 Update DDPM docstrings
* 🐛 Fix evaluation bug due to newer PyTorch version
* ⏪ Remove the solution for the evaluation bug due to the newer PyTorch version, as it will be used locally only for now
* 🔨 Run pre-commit
1 parent dbdb28c commit 201aeac

1 file changed

Lines changed: 99 additions & 27 deletions

File tree

icenet_mp/models/ddpm.py

Lines changed: 99 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -112,20 +112,37 @@ def __init__( # noqa: PLR0913
112112
self.base_output_channels = self.output_space["channels"]
113113
else:
114114
self.base_output_channels = self.output_space.channels
115-
self.base_output_channels = self.output_space.channels
116115

117116
self.output_channels = self.n_forecast_steps * self.base_output_channels
118117
self.timesteps = timesteps
119118
self.cond_channels = 64
120119
self.input_channels = self.cond_channels
121120

121+
# "InstanceNorm" calculates the mean/std per batch, removing the need for offline preprocessing
122+
self.era5_norm = torch.nn.InstanceNorm3d(self.era5_space, affine=True)
123+
124+
# Reduces the many ERA5 channels down to 32 important ones using 1x1 Conv
125+
self.era5_compressed_channels = 32
126+
self.era5_projector = torch.nn.Sequential(
127+
torch.nn.Conv3d(
128+
self.era5_space, self.era5_compressed_channels, kernel_size=1
129+
),
130+
torch.nn.SiLU(),
131+
)
132+
122133
self.osisaf_encoder = SimpleEncoder2D(
123134
in_channels=self.n_history_steps,
124135
out_channels=self.cond_channels // 2,
125136
)
126137

138+
# (Compressed_Channels * Time_Steps), preserving time history
127139
self.era5_encoder = torch.nn.Sequential(
128-
torch.nn.Conv3d(self.era5_space, self.cond_channels // 2, 3, padding=1),
140+
torch.nn.Conv2d(
141+
in_channels=self.era5_compressed_channels * self.n_history_steps,
142+
out_channels=self.cond_channels // 2,
143+
kernel_size=3,
144+
padding=1,
145+
),
129146
torch.nn.GroupNorm(4, self.cond_channels // 2),
130147
torch.nn.SiLU(),
131148
)
@@ -231,6 +248,8 @@ def loss(
231248
def prepare_inputs(self, batch: dict[str, TensorNTCHW]) -> torch.Tensor:
232249
"""Encode OSISAF and ERA5 separately, then concatenate.
233250
251+
ERA5 -> Norm -> Project -> Resize -> Flatten Time -> Encode
252+
234253
Args:
235254
batch: Dictionary with
236255
'osisaf-south' [B, T, 1, H, W]
@@ -243,42 +262,65 @@ def prepare_inputs(self, batch: dict[str, TensorNTCHW]) -> torch.Tensor:
243262
osisaf = batch[self.osisaf_key] # [B, T, 1, H, W]
244263
era5 = batch["era5"] # [B, T, C, H2, W2]
245264

246-
# Squeeze OSISAF singleton channel
265+
# Handle OSISAF
247266
osisaf = osisaf.squeeze(2) # [B, T, H, W]
248-
249-
# Resize ERA5 spatially to match OSISAF resolution
250-
B, T, C, H2, W2 = era5.shape # noqa: N806
251267
H, W = osisaf.shape[-2:] # noqa: N806
252268

253-
# Flatten batch and time dimensions for spatial interpolation
254-
# [B, T, C, H2, W2] -> [B*T, C, H2, W2]
255-
era5_flat = era5.reshape(B * T, C, H2, W2)
269+
# Handle ERA5
270+
# Permute to [B, C, T, H2, W2] for 3D operations
271+
era5 = era5.permute(0, 2, 1, 3, 4)
272+
273+
# Normalize (On-the-fly standardization)
274+
era5 = self.era5_norm(era5)
275+
276+
# Project (Learnable Feature Selection)
277+
# [B, C, T, H2, W2] -> [B, 32, T, H2, W2]
278+
era5 = self.era5_projector(era5)
279+
280+
# Resize Spatially
281+
B, C_new, T, H2, W2 = era5.shape # noqa: N806
282+
# Flatten batch/channel/time for interpolation
283+
era5_flat = era5.reshape(B * C_new * T, 1, H2, W2)
284+
256285
era5_resized = F.interpolate(
257286
era5_flat,
258287
size=(H, W),
259288
mode="bilinear",
260289
align_corners=False,
261290
)
262-
# Reshape back to [B, T, C, H, W]
263-
era5 = era5_resized.reshape(B, T, C, H, W)
264291

265-
# Permute to [B, C, T, H, W] for Conv3d
266-
era5 = era5.permute(0, 2, 1, 3, 4) # [B, C, T, H, W]
292+
# Flatten Time into Channels
293+
# Reshape back to [B, C_new, T, H, W] then flatten T into C
294+
era5_features = era5_resized.reshape(B, C_new * T, H, W)
267295

268-
# Encode both inputs
296+
# Encode Both
269297
osisaf_features = self.osisaf_encoder(osisaf) # [B, cond//2, H, W]
270-
era5_features = self.era5_encoder(era5) # [B, cond//2, T, H, W]
271298

272-
# Pool over time dimension for ERA5
273-
era5_features = era5_features.mean(dim=2) # [B, cond//2, H, W]
299+
# era5_features enters the encoder as a 2D tensor with many channels
300+
era5_features = self.era5_encoder(era5_features) # [B, cond//2, H, W]
274301

275-
# Concatenate along channel dimension
276-
return torch.cat([osisaf_features, era5_features], dim=1)
302+
return torch.cat([osisaf_features, era5_features], dim=1) # [B, cond, H, W]
277303

278304
def training_step(
279305
self, batch: dict[str, TensorNTCHW], _batch_idx: int
280306
) -> torch.Tensor:
281-
"""One training step using DDPM loss (predicted noise vs. true noise)."""
307+
"""One training step using DDPM v-prediction loss.
308+
309+
During training, the clean target (SIC) is corrupted using the forward
310+
diffusion process by adding noise at a randomly sampled timestep.
311+
The model is trained to predict the corresponding v-target.
312+
313+
Args:
314+
batch (dict[str, TensorNTCHW]):
315+
Dictionary containing:
316+
- input tensors (used to prepare conditioning inputs)
317+
- "target": groundtruth SIC tensor
318+
319+
Returns:
320+
torch.Tensor:
321+
Scalar training loss (MSE between predicted v and target v).
322+
323+
"""
282324
# Prepare input tensor by combining osisaf-south and era5
283325
x = self.prepare_inputs(batch) # [B, T, C_combined, H, W]
284326

@@ -290,14 +332,14 @@ def training_step(
290332
0, self.timesteps, (x.shape[0],), device=self.device
291333
).long() # look into this
292334

293-
# Create noisy version using scaled target
335+
# Create noisy version
294336
noise = torch.randn_like(y)
295337
noisy_y = self.diffusion.q_sample(y, t, noise)
296338

297339
# Predict v
298340
pred_v = self.model(noisy_y, t, x)
299341

300-
# Compute target v using scaled data
342+
# Compute target v
301343
target_v = self.diffusion.calculate_v(y, noise, t)
302344

303345
# Compute loss
@@ -316,7 +358,25 @@ def training_step(
316358
def validation_step(
317359
self, batch: dict[str, TensorNTCHW], _batch_idx: int
318360
) -> torch.Tensor:
319-
"""One validation step using the specified loss function defined in the criterion."""
361+
"""One validation step using full diffusion sampling.
362+
363+
During validation, samples are generated by starting from noise and
364+
iteratively denoising conditioned on the inputs. The final prediction
365+
is compared to the groundtruth SIC using the configured evaluation loss.
366+
367+
Args:
368+
batch (dict[str, TensorNTCHW]):
369+
Dictionary containing:
370+
- input tensors (used to prepare conditioning inputs)
371+
- "target": groundtruth SIC tensor
372+
- optional "sample_weight": weighting tensor
373+
374+
Returns:
375+
torch.Tensor:
376+
Scalar validation loss computed between predicted SIC (y_hat)
377+
and groundtruth SIC (y).
378+
379+
"""
320380
# Prepare input tensor
321381
x = self.prepare_inputs(batch) # [B, T, C_combined, H, W]
322382

@@ -349,14 +409,26 @@ def test_step(
349409
batch: dict[str, TensorNTCHW],
350410
_batch_idx: int, # noqa: PT019
351411
) -> ModelTestOutput:
352-
"""One test step using the specified loss function and full metric evaluation.
412+
"""One test step using full diffusion sampling and metric evaluation.
413+
414+
During testing, predictions are generated by starting from noise
415+
and running the reverse diffusion process conditioned on the inputs.
416+
The final reconstructed SIC is compared to the groundtruth target
417+
using the configured loss and test metrics.
353418
354419
Args:
355-
batch (tuple): (x, y, sample_weight)
356-
batch_idx (int): Batch index.
420+
batch (dict[str, TensorNTCHW]):
421+
Dictionary containing:
422+
- input tensors (used to prepare conditioning inputs)
423+
- "target": groundtruth SIC tensor
424+
- optional "sample_weight": weighting tensor
357425
358426
Returns:
359-
torch.Tensor: Loss value.
427+
ModelTestOutput:
428+
Object containing:
429+
- prediction: reconstructed SIC (y_hat)
430+
- target: groundtruth SIC (y)
431+
- loss: test loss value
360432
361433
"""
362434
x = self.prepare_inputs(batch) # [B, T, C_combined, H, W]

0 commit comments

Comments (0)