@@ -112,20 +112,37 @@ def __init__( # noqa: PLR0913
112112 self .base_output_channels = self .output_space ["channels" ]
113113 else :
114114 self .base_output_channels = self .output_space .channels
115- self .base_output_channels = self .output_space .channels
116115
117116 self .output_channels = self .n_forecast_steps * self .base_output_channels
118117 self .timesteps = timesteps
119118 self .cond_channels = 64
120119 self .input_channels = self .cond_channels
121120
121+ # "InstanceNorm" calculates the mean/std per batch, removing the need for offline preprocessing
122+ self .era5_norm = torch .nn .InstanceNorm3d (self .era5_space , affine = True )
123+
124+ # Reduces the many ERA5 channels down to 32 important ones using 1x1 Conv
125+ self .era5_compressed_channels = 32
126+ self .era5_projector = torch .nn .Sequential (
127+ torch .nn .Conv3d (
128+ self .era5_space , self .era5_compressed_channels , kernel_size = 1
129+ ),
130+ torch .nn .SiLU (),
131+ )
132+
122133 self .osisaf_encoder = SimpleEncoder2D (
123134 in_channels = self .n_history_steps ,
124135 out_channels = self .cond_channels // 2 ,
125136 )
126137
138+ # (Compressed_Channels * Time_Steps), preserving time history
127139 self .era5_encoder = torch .nn .Sequential (
128- torch .nn .Conv3d (self .era5_space , self .cond_channels // 2 , 3 , padding = 1 ),
140+ torch .nn .Conv2d (
141+ in_channels = self .era5_compressed_channels * self .n_history_steps ,
142+ out_channels = self .cond_channels // 2 ,
143+ kernel_size = 3 ,
144+ padding = 1 ,
145+ ),
129146 torch .nn .GroupNorm (4 , self .cond_channels // 2 ),
130147 torch .nn .SiLU (),
131148 )
@@ -231,6 +248,8 @@ def loss(
231248 def prepare_inputs (self , batch : dict [str , TensorNTCHW ]) -> torch .Tensor :
232249 """Encode OSISAF and ERA5 separately, then concatenate.
233250
251+ ERA5 -> Norm -> Project -> Resize -> Flatten Time -> Encode
252+
234253 Args:
235254 batch: Dictionary with
236255 'osisaf-south' [B, T, 1, H, W]
@@ -243,42 +262,65 @@ def prepare_inputs(self, batch: dict[str, TensorNTCHW]) -> torch.Tensor:
243262 osisaf = batch [self .osisaf_key ] # [B, T, 1, H, W]
244263 era5 = batch ["era5" ] # [B, T, C, H2, W2]
245264
246- # Squeeze OSISAF singleton channel
265+ # Handle OSISAF
247266 osisaf = osisaf .squeeze (2 ) # [B, T, H, W]
248-
249- # Resize ERA5 spatially to match OSISAF resolution
250- B , T , C , H2 , W2 = era5 .shape # noqa: N806
251267 H , W = osisaf .shape [- 2 :] # noqa: N806
252268
253- # Flatten batch and time dimensions for spatial interpolation
254- # [B, T, C, H2, W2] -> [B*T, C, H2, W2]
255- era5_flat = era5 .reshape (B * T , C , H2 , W2 )
269+ # Handle ERA5
270+ # Permute to [B, C, T, H2, W2] for 3D operations
271+ era5 = era5 .permute (0 , 2 , 1 , 3 , 4 )
272+
273+ # Normalize (On-the-fly standardization)
274+ era5 = self .era5_norm (era5 )
275+
276+ # Project (Learnable Feature Selection)
277+ # [B, C, T, H2, W2] -> [B, 32, T, H2, W2]
278+ era5 = self .era5_projector (era5 )
279+
280+ # Resize Spatially
281+ B , C_new , T , H2 , W2 = era5 .shape # noqa: N806
282+ # Flatten batch/channel/time for interpolation
283+ era5_flat = era5 .reshape (B * C_new * T , 1 , H2 , W2 )
284+
256285 era5_resized = F .interpolate (
257286 era5_flat ,
258287 size = (H , W ),
259288 mode = "bilinear" ,
260289 align_corners = False ,
261290 )
262- # Reshape back to [B, T, C, H, W]
263- era5 = era5_resized .reshape (B , T , C , H , W )
264291
265- # Permute to [B, C, T, H, W] for Conv3d
266- era5 = era5 .permute (0 , 2 , 1 , 3 , 4 ) # [B, C, T, H, W]
292+ # Flatten Time into Channels
293+ # Reshape back to [B, C_new, T, H, W] then flatten T into C
294+ era5_features = era5_resized .reshape (B , C_new * T , H , W )
267295
268- # Encode both inputs
296+ # Encode Both
269297 osisaf_features = self .osisaf_encoder (osisaf ) # [B, cond//2, H, W]
270- era5_features = self .era5_encoder (era5 ) # [B, cond//2, T, H, W]
271298
272- # Pool over time dimension for ERA5
273- era5_features = era5_features . mean ( dim = 2 ) # [B, cond//2, H, W]
299+ # era5_features enters the encoder as a 2D tensor with many channels
300+ era5_features = self . era5_encoder ( era5_features ) # [B, cond//2, H, W]
274301
275- # Concatenate along channel dimension
276- return torch .cat ([osisaf_features , era5_features ], dim = 1 )
302+ return torch .cat ([osisaf_features , era5_features ], dim = 1 ) # [B, cond, H, W]
277303
278304 def training_step (
279305 self , batch : dict [str , TensorNTCHW ], _batch_idx : int
280306 ) -> torch .Tensor :
281- """One training step using DDPM loss (predicted noise vs. true noise)."""
307+ """One training step using DDPM v-prediction loss.
308+
309+ During training, the clean target (SIC) is corrupted using the forward
310+ diffusion process by adding noise at a randomly sampled timestep.
311+ The model is trained to predict the corresponding v-target.
312+
313+ Args:
314+ batch (dict[str, TensorNTCHW]):
315+ Dictionary containing:
316+ - input tensors (used to prepare conditioning inputs)
317+ - "target": groundtruth SIC tensor
318+
319+ Returns:
320+ torch.Tensor:
321+ Scalar training loss (MSE between predicted v and target v).
322+
323+ """
282324 # Prepare input tensor by combining osisaf-south and era5
283325 x = self .prepare_inputs (batch ) # [B, T, C_combined, H, W]
284326
@@ -290,14 +332,14 @@ def training_step(
290332 0 , self .timesteps , (x .shape [0 ],), device = self .device
291333 ).long () # look into this
292334
293- # Create noisy version using scaled target
335+ # Create noisy version
294336 noise = torch .randn_like (y )
295337 noisy_y = self .diffusion .q_sample (y , t , noise )
296338
297339 # Predict v
298340 pred_v = self .model (noisy_y , t , x )
299341
300- # Compute target v using scaled data
342+ # Compute target v
301343 target_v = self .diffusion .calculate_v (y , noise , t )
302344
303345 # Compute loss
@@ -316,7 +358,25 @@ def training_step(
316358 def validation_step (
317359 self , batch : dict [str , TensorNTCHW ], _batch_idx : int
318360 ) -> torch .Tensor :
319- """One validation step using the specified loss function defined in the criterion."""
361+ """One validation step using full diffusion sampling.
362+
363+ During validation, samples are generated by starting from noise and
364+ iteratively denoising conditioned on the inputs. The final prediction
365+ is compared to the groundtruth SIC using the configured evaluation loss.
366+
367+ Args:
368+ batch (dict[str, TensorNTCHW]):
369+ Dictionary containing:
370+ - input tensors (used to prepare conditioning inputs)
371+ - "target": groundtruth SIC tensor
372+ - optional "sample_weight": weighting tensor
373+
374+ Returns:
375+ torch.Tensor:
376+ Scalar validation loss computed between predicted SIC (y_hat)
377+ and groundtruth SIC (y).
378+
379+ """
320380 # Prepare input tensor
321381 x = self .prepare_inputs (batch ) # [B, T, C_combined, H, W]
322382
@@ -349,14 +409,26 @@ def test_step(
349409 batch : dict [str , TensorNTCHW ],
350410 _batch_idx : int , # noqa: PT019
351411 ) -> ModelTestOutput :
352- """One test step using the specified loss function and full metric evaluation.
412+ """One test step using full diffusion sampling and metric evaluation.
413+
414+ During testing, predictions are generated by starting from noise
415+ and running the reverse diffusion process conditioned on the inputs.
416+ The final reconstructed SIC is compared to the groundtruth target
417+ using the configured loss and test metrics.
353418
354419 Args:
355- batch (tuple): (x, y, sample_weight)
356- batch_idx (int): Batch index.
420+ batch (dict[str, TensorNTCHW]):
421+ Dictionary containing:
422+ - input tensors (used to prepare conditioning inputs)
423+ - "target": groundtruth SIC tensor
424+ - optional "sample_weight": weighting tensor
357425
358426 Returns:
359- torch.Tensor: Loss value.
427+ ModelTestOutput:
428+ Object containing:
429+ - prediction: reconstructed SIC (y_hat)
430+ - target: groundtruth SIC (y)
431+ - loss: test loss value
360432
361433 """
362434 x = self .prepare_inputs (batch ) # [B, T, C_combined, H, W]
0 commit comments