
Commit dcb426f

📝 Add some information about how diffusion models could be implemented in the multistage framework
1 parent ad4d70e commit dcb426f

5 files changed

Lines changed: 119 additions & 5 deletions

docs/scripts/generate_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
HEADER = """\
1818
# Reference configuration
1919
20-
<!-- This file is auto-generated by `docs/generate_config.py` — do not edit by hand. -->
20+
<!-- This file is auto-generated by `docs/generate_config.py`. Do not edit by hand. -->
2121
2222
The full default configuration composed from [`icenet_mp/config/base.yaml`](https://github.com/alan-turing-institute/icenet-mp/blob/main/icenet_mp/config/base.yaml) and all its sub-configs.
2323
This is the configuration used when you run any command without overrides.

docs/src/how-to/add-a-model.md

Lines changed: 3 additions & 3 deletions
@@ -1,4 +1,4 @@
1-
# Add a new model
1+
# Add a model
22

33
## Tensor format
44

@@ -14,7 +14,7 @@ All IceNet-MP models operate on tensors in `NTCHW` format:
1414

1515
`N` and `T` are the same across all inputs, but `C`, `H`, and `W` may differ per dataset.
1616

17-
For example, with batch size `N=2`, 3 history steps, and 4 forecast steps, each of the `k` inputs each have shape `(2, 3, C_k, H_k, W_k)` and the output has shape `(2, 4, C_out, H_out, W_out)`.
17+
For example, with 3 history steps and 4 forecast steps, each of the `k` inputs has shape `(N, 3, C_k, H_k, W_k)` and the output has shape `(N, 4, C_out, H_out, W_out)`.
1818

1919
## Standalone models
2020

@@ -36,7 +36,7 @@ You define a latent space `(H_latent, W_latent)` and the framework automatically
3636
1. Each dataset-specific **encoder** maps input `(N, T_history, C_k, H_k, W_k)` to `(N, T_history, C_k_latent, H_latent, W_latent)`.
3737
2. The `k` encoded tensors are concatenated to `(N, T_history, C_latent, H_latent, W_latent)`.
3838
3. The **processor** maps `(N, T_history, C_latent, H_latent, W_latent)` to `(N, T_forecast, C_latent, H_latent, W_latent)`.
39-
4. Each output-specific **decoder** maps the processor output `(N, T_forecast, C_out, H_out, W_out)`.
39+
4. Each output-specific **decoder** maps the processor output, `(N, T_forecast, C_latent, H_latent, W_latent)`, to `(N, T_forecast, C_out, H_out, W_out)`.
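
The four steps above can be traced as pure shape bookkeeping. The sketch below is not IceNet-MP code: the helper names, channel counts, and dataset sizes are made up for illustration, and it assumes every encoder targets the same latent grid.

```python
Shape = tuple[int, int, int, int, int]  # (N, T, C, H, W)


def encode(shape: Shape, c_latent: int, latent_hw: tuple[int, int]) -> Shape:
    """Step 1: one dataset-specific encoder maps onto the shared latent grid."""
    n, t, _, _, _ = shape
    return (n, t, c_latent, *latent_hw)


def concatenate(shapes: list[Shape]) -> Shape:
    """Step 2: encoded tensors are joined along the channel axis."""
    n, t, _, h, w = shapes[0]
    return (n, t, sum(s[2] for s in shapes), h, w)


def process(shape: Shape, t_forecast: int) -> Shape:
    """Step 3: the processor changes only the time axis."""
    n, _, c, h, w = shape
    return (n, t_forecast, c, h, w)


def decode(shape: Shape, c_out: int, out_hw: tuple[int, int]) -> Shape:
    """Step 4: one output-specific decoder maps back to data space."""
    n, t, _, _, _ = shape
    return (n, t, c_out, *out_hw)


# Hypothetical inputs: N=2, 3 history steps, two datasets of different sizes.
inputs = [(2, 3, 8, 181, 360), (2, 3, 1, 432, 432)]
encoded = [encode(s, c_latent=16, latent_hw=(144, 144)) for s in inputs]
latent = concatenate(encoded)
forecast = process(latent, t_forecast=4)
output = decode(forecast, c_out=1, out_hw=(432, 432))

print(latent)    # (2, 3, 32, 144, 144)
print(forecast)  # (2, 4, 32, 144, 144)
print(output)    # (2, 4, 1, 432, 432)
```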
4040

4141
![Encode-process-decode pipeline diagram](../assets/pipeline-encode-process-decode.png)
4242

docs/src/how-to/add-a-processor.md

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
1+
# Add a processor
2+
3+
A processor sits between the encoders and decoder in the encode-process-decode pipeline.
4+
It receives the concatenated latent representations of all inputs and produces a latent forecast.
5+
6+
## The processor interface
7+
8+
All IceNet-MP processors extend `BaseProcessor` from `icenet_mp.models.processors`.
9+
They operate on tensors in `NTCHW` format, taking in a tensor with a number of history steps and returning a tensor with a number of forecast steps.
10+
For example, with 3 history steps and 4 forecast steps, a processor converts a tensor of shape `(N, 3, C, H, W)` to one of shape `(N, 4, C, H, W)`.
11+
12+
The base class exposes two entry points, and you only need to implement one:
13+
14+
| Method | Signature | When to override |
15+
|--------|-----------|-----------------|
16+
| `forward` | `(x: TensorNCHW) -> TensorNCHW` | Stateless single-timestep transforms |
17+
| `rollout` | `(x: TensorNTCHW, y: TensorNTCHW \| None) -> ModelStepOutput` | Any model that needs access to the full temporal history, or that behaves differently during training vs. inference |
18+
19+
The default `rollout` implementation calls `forward` once per forecast step, passing each prediction back as the next input.
20+
If your architecture works on one timestep at a time and uses the same logic during training and inference, overriding `forward` alone is sufficient.
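
The feedback loop described above can be sketched without the framework. This is an illustration, not the `BaseProcessor` implementation: `autoregressive_rollout` and `double` are hypothetical names, frames are plain lists standing in for `(N, C, H, W)` tensors, and starting from the most recent history step is an assumption.

```python
from typing import Callable

Frame = list[float]  # stand-in for one (N, C, H, W) latent frame


def autoregressive_rollout(
    forward: Callable[[Frame], Frame],
    history: list[Frame],
    n_forecast_steps: int,
) -> list[Frame]:
    """Call `forward` once per forecast step, feeding each prediction back in."""
    current = history[-1]  # assumption: seed with the latest history step
    predictions = []
    for _ in range(n_forecast_steps):
        current = forward(current)
        predictions.append(current)
    return predictions


def double(frame: Frame) -> Frame:
    """Toy single-timestep transform used in place of a real network."""
    return [2.0 * value for value in frame]


history = [[0.25], [0.5], [1.0]]  # 3 history steps
forecast = autoregressive_rollout(double, history, n_forecast_steps=4)
print(forecast)  # [[2.0], [4.0], [8.0], [16.0]]
```

Because each step consumes the previous prediction, errors compound over the forecast horizon, which is one reason a model may prefer to override `rollout` instead.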
21+
22+
## Simple processor: override `forward`
23+
24+
```python
from typing import Any

from icenet_mp.models.processors import BaseProcessor
from icenet_mp.types import TensorNCHW


class MyProcessor(BaseProcessor):
    def __init__(self, *, hidden_dim: int = 128, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Channel count of the processor's data space; use it to size your layers.
        in_channels = self.data_space.channels
        self.model = ...  # your nn.Module here

    def forward(self, x: TensorNCHW) -> TensorNCHW:
        # x holds a single timestep in NCHW format.
        return self.model(x)
```
39+
40+
This model can be trained in either single-stage or multistage mode.
41+
42+
43+
## Training vs. inference: override `rollout`
44+
45+
Some architectures fundamentally differ between training and inference.
46+
The canonical example is a diffusion model: during training you corrupt the target and predict noise; during inference you run the full reverse diffusion chain from pure noise.
47+
48+
In the multistage training flow, the encoders and decoder can be pretrained independently before the processor is trained on their fixed latent space.
49+
Because the processor is then trained on its own, its `rollout` method can behave differently during training and inference.
50+
51+
The `rollout` signature allows the processor to handle both training and inference without direct knowledge of which step is being run:
52+
53+
- if `y`, the latent-space-encoded target, is provided, this is **training**
54+
- if `y` is `None` then this is **inference**
55+
56+
```python
from icenet_mp.models.processors import BaseProcessor
from icenet_mp.types import ModelStepOutput, TensorNTCHW


class MyDiffusionProcessor(BaseProcessor):

    def rollout(
        self, x: TensorNTCHW, y: TensorNTCHW | None = None
    ) -> ModelStepOutput:
        # x: (N, T_history, C, H, W) - encoded inputs
        # y: (N, T_forecast, C, H, W) - encoded targets, or None at inference
        if y is not None:
            # --- Training path ---
            # _training is a method you implement; it returns the prediction and your loss.
            prediction, loss = self._training(x, y)
            return ModelStepOutput(prediction=prediction, target=y, loss=loss)
        else:
            # --- Inference path ---
            # _inference is a method you implement; it returns the latent forecast.
            prediction = self._inference(x)
            return ModelStepOutput(prediction=prediction, target=None, loss=None)
```
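
The `_training` and `_inference` helpers called above are left for you to write. As a rough guide to what they might contain, here is a framework-free sketch of the two DDPM-style paths. It rests on several assumptions: a linear noise schedule, flat lists in place of latent tensors, and a hypothetical `predict_noise` stub where a learned, context-conditioned network would go.

```python
import math
import random

T = 50  # number of diffusion steps (illustrative; configs often use ~1000)
betas = [1e-4 + (0.02 - 1e-4) * t / (T - 1) for t in range(T)]
alphas = [1.0 - beta for beta in betas]
alpha_bars = []  # cumulative products of alphas
running = 1.0
for alpha in alphas:
    running *= alpha
    alpha_bars.append(running)


def predict_noise(noisy, t, context):
    """Hypothetical stand-in for a learned noise-prediction network."""
    return [0.0 for _ in noisy]


def training_step(context, target, rng):
    """Training path: corrupt the target at a random step, score the noise guess."""
    t = rng.randrange(T)
    noise = [rng.gauss(0.0, 1.0) for _ in target]
    signal = math.sqrt(alpha_bars[t])
    sigma = math.sqrt(1.0 - alpha_bars[t])
    noisy = [signal * y + sigma * e for y, e in zip(target, noise)]
    predicted = predict_noise(noisy, t, context)
    # Mean squared error between predicted and true noise.
    return sum((p - e) ** 2 for p, e in zip(predicted, noise)) / len(noise)


def inference(context, size, rng):
    """Inference path: run the reverse chain from pure noise down to step 0."""
    x = [rng.gauss(0.0, 1.0) for _ in range(size)]
    for t in reversed(range(T)):
        eps = predict_noise(x, t, context)
        coef = betas[t] / math.sqrt(1.0 - alpha_bars[t])
        x = [(xi - coef * ei) / math.sqrt(alphas[t]) for xi, ei in zip(x, eps)]
        if t > 0:  # no fresh noise is added on the final step
            x = [xi + math.sqrt(betas[t]) * rng.gauss(0.0, 1.0) for xi in x]
    return x


rng = random.Random(0)
loss = training_step(context=[0.0], target=[1.0, 2.0], rng=rng)
sample = inference(context=[0.0], size=3, rng=rng)
```

In a real processor the history `x` would play the role of `context`, and the training path would return both a prediction and the loss so that `rollout` can pack them into `ModelStepOutput`.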
78+
79+
Returning a valid `loss` tensor tells `ProcessorStage` to skip its own loss computation and use yours instead.
80+
The decoded prediction is still computed and logged, but gradients flow through your custom loss.
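
The rule for choosing between the two losses can be pictured as follows. This is a guess at the behaviour described above, not the `ProcessorStage` source: `StepOutput`, `select_loss`, and the mean-squared-error default are hypothetical stand-ins.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class StepOutput:
    """Hypothetical stand-in for ModelStepOutput, using plain Python values."""

    prediction: list[float]
    target: Optional[list[float]]
    loss: Optional[float]


def mean_squared_error(prediction: list[float], target: list[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(target)


def select_loss(output: StepOutput) -> float:
    """Prefer the processor's own loss; otherwise fall back to the default."""
    if output.loss is not None:
        return output.loss
    return mean_squared_error(output.prediction, output.target)


custom = StepOutput(prediction=[1.0, 2.0], target=[1.0, 4.0], loss=0.5)
fallback = StepOutput(prediction=[1.0, 2.0], target=[1.0, 4.0], loss=None)
print(select_loss(custom))    # 0.5
print(select_loss(fallback))  # 2.0
```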
81+
82+
## Register the processor in config
83+
84+
Add a model config under `icenet_mp/config/model/` that points `processor._target_` at your class:
85+
86+
```yaml
# icenet_mp/config/model/cnn_mydiffusion_cnn.yaml
_target_: icenet_mp.models.EncodeProcessDecode

name: cnn-ddpm-cnn

encoders:
  latent_space: [144, 144]
  era5:
    _target_: icenet_mp.models.encoders.CNNEncoder
  sic-icenet:
    _target_: icenet_mp.models.encoders.CNNEncoder

processor:
  _target_: icenet_mp.models.processors.MyDiffusionProcessor
  timesteps: 1000

decoder:
  _target_: icenet_mp.models.decoders.CNNDecoder
  bounded: false
```
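
The `_target_` key follows the convention Hydra uses for instantiating objects from config. Assuming IceNet-MP builds its models that way, the remaining keys in a block (such as `timesteps`) reach the class as keyword arguments. The sketch below mimics that mechanism with the standard library only; `instantiate` here is a simplified stand-in, not Hydra's function, and `fractions.Fraction` is used purely as an importable example class.

```python
import importlib
from typing import Any


def instantiate(config: dict[str, Any]) -> Any:
    """Import the class named by `_target_` and call it with the other keys."""
    kwargs = dict(config)  # copy so the caller's config is left untouched
    module_path, _, class_name = kwargs.pop("_target_").rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    return cls(**kwargs)


config = {"_target_": "fractions.Fraction", "numerator": 3, "denominator": 4}
value = instantiate(config)
print(value)  # 3/4
```

This is why any extra key you add under `processor:` must match a keyword argument accepted by your processor's `__init__`.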
107+
108+
Then run training with:
109+
110+
```bash
uv run imp train model=cnn_mydiffusion_cnn
```

docs/src/how-to/index.md

Lines changed: 2 additions & 1 deletion
@@ -2,6 +2,7 @@
22

33
Step-by-step guides for common tasks.
44

5-
- [Add a new model](add-a-model.md) - implement a custom architecture
5+
- [Add a model](add-a-model.md) - implement a custom architecture
6+
- [Add a processor](add-a-processor.md) - implement a processor, including models with different training and inference behaviour
67
- [Train a model](train.md) - run single-stage end-to-end training
78
- [Train in stages](train-multistage.md) - pretrain each component separately before finetuning

zensical.toml

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ nav = [
1717
{ "How-to" = [
1818
{ "Overview" = "how-to/index.md" },
1919
{ "Add a model" = "how-to/add-a-model.md" },
20+
{ "Add a processor" = "how-to/add-a-processor.md" },
2021
{ "Train a model" = "how-to/train.md" },
2122
{ "Run multistage training" = "how-to/train-multistage.md" },
2223
] },
