Skip to content

Commit d9a8f00

Browse files
authored
Merge pull request #330 from alan-turing-institute/update-support-for-global-cond-vit
Update support for passing `global_cond` to ViT
2 parents b103316 + a8bd096 commit d9a8f00

7 files changed

Lines changed: 58 additions & 10 deletions

File tree

local_hydra/local_experiment/processor/advection_diffusion/crps_vit_azula_large.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ model:
3434
n_layers: 12
3535
patch_size: 1
3636
n_noise_channels: 1024
37+
global_cond_channels: auto
38+
include_global_cond: true
3739
loss_func:
3840
_target_: autocast.losses.ensemble.AlphaFairCRPSLoss
3941
train_metrics:

local_hydra/local_experiment/processor/conditioned_navier_stokes/crps_vit_azula_large.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ model:
3434
n_layers: 12
3535
patch_size: 1
3636
n_noise_channels: 1024
37+
global_cond_channels: auto
38+
include_global_cond: true
3739
loss_func:
3840
_target_: autocast.losses.ensemble.AlphaFairCRPSLoss
3941
train_metrics:

local_hydra/local_experiment/processor/gpe_laser_wake_only/crps_vit_azula_large.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ model:
3434
n_layers: 12
3535
patch_size: 1
3636
n_noise_channels: 1024
37+
global_cond_channels: auto
38+
include_global_cond: true
3739
loss_func:
3840
_target_: autocast.losses.ensemble.AlphaFairCRPSLoss
3941
train_metrics:

local_hydra/local_experiment/processor/gray_scott/crps_vit_azula_large.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ model:
3434
n_layers: 12
3535
patch_size: 1
3636
n_noise_channels: 1024
37+
global_cond_channels: auto
38+
include_global_cond: true
3739
loss_func:
3840
_target_: autocast.losses.ensemble.AlphaFairCRPSLoss
3941
train_metrics:

src/autocast/configs/processor/vit_azula_large.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ n_layers: 10
99
patch_size: 4
1010
temporal_method: none
1111
n_noise_channels: 256
12+
include_global_cond: false

src/autocast/configs/processor/vit_azula_small.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ n_layers: 10
99
patch_size: 4
1010
temporal_method: none
1111
n_noise_channels: 256
12+
include_global_cond: false

src/autocast/processors/azula_vit.py

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def __init__(
3232
loss_func: nn.Module | None = None,
3333
n_noise_channels: int | None = None,
3434
n_noise_input_channels: int | None = None,
35+
global_cond_channels: int | None = None,
36+
include_global_cond: bool = False,
3537
checkpointing: bool = False,
3638
):
3739
super().__init__()
@@ -42,6 +44,8 @@ def __init__(
4244

4345
self.n_noise_channels = n_noise_channels
4446
self.n_noise_input_channels = n_noise_input_channels or n_noise_channels
47+
self.global_cond_channels = global_cond_channels
48+
self.include_global_cond = include_global_cond
4549

4650
if self.n_noise_channels is None and n_noise_input_channels is not None:
4751
msg = (
@@ -68,8 +72,8 @@ def __init__(
6872
n_steps_output=1,
6973
n_steps_input=1,
7074
mod_features=n_noise_channels or 256,
71-
global_cond_channels=None,
72-
include_global_cond=False,
75+
global_cond_channels=global_cond_channels,
76+
include_global_cond=include_global_cond,
7377
hid_channels=hidden_dim,
7478
hid_blocks=n_layers,
7579
attention_heads=num_heads,
@@ -82,12 +86,19 @@ def __init__(
8286
use_precomputed_modulation=True,
8387
)
8488

85-
def forward(self, x: Tensor, x_noise: Tensor | None = None) -> Tensor:
89+
def forward(
90+
self,
91+
x: Tensor,
92+
x_noise: Tensor | None = None,
93+
global_cond: Tensor | None = None,
94+
) -> Tensor:
8695
"""Run TemporalViT with channel-first inputs and outputs.
8796
8897
Args:
8998
x: Input tensor with shape (B, C, H, W).
9099
x_noise: Optional noise/modulation tensor.
100+
global_cond: Optional global conditioning tensor with shape
101+
(B, C_global). Used only when include_global_cond=True.
91102
92103
Returns
93104
-------
@@ -107,22 +118,49 @@ def forward(self, x: Tensor, x_noise: Tensor | None = None) -> Tensor:
107118
)
108119
raise ValueError(msg)
109120

121+
if (
122+
not self.n_noise_channels
123+
and x_noise is not None
124+
and x_noise.shape[-1] != self.model.mod_features
125+
):
126+
msg = (
127+
f"Expected x_noise with last dim {self.model.mod_features}, "
128+
f"got {x_noise.shape[-1]}."
129+
)
130+
raise ValueError(msg)
131+
132+
model_global_cond = None
133+
if self.include_global_cond:
134+
if global_cond is None:
135+
msg = "global_cond must be provided when include_global_cond=True."
136+
raise ValueError(msg)
137+
if global_cond.shape[-1] != self.global_cond_channels:
138+
msg = (
139+
f"Expected global_cond with last dim "
140+
f"{self.global_cond_channels}, got "
141+
f"{global_cond.shape[-1]}."
142+
)
143+
raise ValueError(msg)
144+
model_global_cond = global_cond
145+
110146
x_in = rearrange(x, "b c h w -> b 1 h w c").contiguous()
111-
y = self.model(x_in, t=x_noise, cond=None, global_cond=None)
147+
y = self.model(x_in, t=x_noise, cond=None, global_cond=model_global_cond)
112148
return rearrange(y, "b 1 h w c -> b c h w").contiguous()
113149

114-
def map(self, x: Tensor, global_cond: Tensor | None = None) -> Tensor: # noqa: ARG002
150+
def map(self, x: Tensor, global_cond: Tensor | None = None) -> Tensor:
115151
noise_channels = self.n_noise_input_channels or self.n_noise_channels
152+
if noise_channels is None:
153+
noise_channels = self.model.mod_features
154+
116155
if self.n_noise_channels:
117-
if noise_channels is None:
118-
msg = "n_noise_channels is set but no noise input width is available."
119-
raise ValueError(msg)
120156
noise = torch.randn(
121157
x.shape[0], noise_channels, dtype=x.dtype, device=x.device
122158
)
123159
else:
124-
noise = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
125-
return self(x, noise)
160+
noise = torch.zeros(
161+
x.shape[0], noise_channels, dtype=x.dtype, device=x.device
162+
)
163+
return self(x, noise, global_cond=global_cond)
126164

127165
def loss(self, batch: EncodedBatch) -> Tensor:
128166
pred = self.map(batch.encoded_inputs, batch.global_cond)

0 commit comments

Comments
 (0)