Skip to content

Commit b56e647

Browse files
committed
Add global_cond support to AzulaViTProcessor and related configs
1 parent 2b9a145 commit b56e647

7 files changed

Lines changed: 67 additions & 10 deletions

File tree

local_hydra/local_experiment/processor/advection_diffusion/crps_vit_azula_large.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ model:
3434
n_layers: 12
3535
patch_size: 1
3636
n_noise_channels: 1024
37+
global_cond_channels: auto
38+
include_global_cond: true
3739
loss_func:
3840
_target_: autocast.losses.ensemble.AlphaFairCRPSLoss
3941
train_metrics:

local_hydra/local_experiment/processor/conditioned_navier_stokes/crps_vit_azula_large.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ model:
3434
n_layers: 12
3535
patch_size: 1
3636
n_noise_channels: 1024
37+
global_cond_channels: auto
38+
include_global_cond: true
3739
loss_func:
3840
_target_: autocast.losses.ensemble.AlphaFairCRPSLoss
3941
train_metrics:

local_hydra/local_experiment/processor/gpe_laser_wake_only/crps_vit_azula_large.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ model:
3434
n_layers: 12
3535
patch_size: 1
3636
n_noise_channels: 1024
37+
global_cond_channels: auto
38+
include_global_cond: true
3739
loss_func:
3840
_target_: autocast.losses.ensemble.AlphaFairCRPSLoss
3941
train_metrics:

local_hydra/local_experiment/processor/gray_scott/crps_vit_azula_large.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ model:
3434
n_layers: 12
3535
patch_size: 1
3636
n_noise_channels: 1024
37+
global_cond_channels: auto
38+
include_global_cond: true
3739
loss_func:
3840
_target_: autocast.losses.ensemble.AlphaFairCRPSLoss
3941
train_metrics:

src/autocast/configs/processor/vit_azula_large.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ n_layers: 10
99
patch_size: 4
1010
temporal_method: none
1111
n_noise_channels: 256
12+
include_global_cond: false

src/autocast/configs/processor/vit_azula_small.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ n_layers: 10
99
patch_size: 4
1010
temporal_method: none
1111
n_noise_channels: 256
12+
include_global_cond: false

src/autocast/processors/azula_vit.py

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def __init__(
3232
loss_func: nn.Module | None = None,
3333
n_noise_channels: int | None = None,
3434
n_noise_input_channels: int | None = None,
35+
global_cond_channels: int | None = None,
36+
include_global_cond: bool = False,
3537
checkpointing: bool = False,
3638
):
3739
super().__init__()
@@ -42,6 +44,8 @@ def __init__(
4244

4345
self.n_noise_channels = n_noise_channels
4446
self.n_noise_input_channels = n_noise_input_channels or n_noise_channels
47+
self.global_cond_channels = global_cond_channels
48+
self.include_global_cond = include_global_cond
4549

4650
if self.n_noise_channels is None and n_noise_input_channels is not None:
4751
msg = (
@@ -50,6 +54,15 @@ def __init__(
5054
)
5155
raise ValueError(msg)
5256

57+
if self.include_global_cond and (
58+
self.global_cond_channels is None or self.global_cond_channels <= 0
59+
):
60+
msg = (
61+
"include_global_cond=True requires global_cond_channels to be "
62+
"set to a positive integer."
63+
)
64+
raise ValueError(msg)
65+
5366
self.modulation_proj = None
5467
if (
5568
self.n_noise_channels
@@ -68,8 +81,8 @@ def __init__(
6881
n_steps_output=1,
6982
n_steps_input=1,
7083
mod_features=n_noise_channels or 256,
71-
global_cond_channels=None,
72-
include_global_cond=False,
84+
global_cond_channels=global_cond_channels,
85+
include_global_cond=include_global_cond,
7386
hid_channels=hidden_dim,
7487
hid_blocks=n_layers,
7588
attention_heads=num_heads,
@@ -82,12 +95,19 @@ def __init__(
8295
use_precomputed_modulation=True,
8396
)
8497

85-
def forward(self, x: Tensor, x_noise: Tensor | None = None) -> Tensor:
98+
def forward(
99+
self,
100+
x: Tensor,
101+
x_noise: Tensor | None = None,
102+
global_cond: Tensor | None = None,
103+
) -> Tensor:
86104
"""Run TemporalViT with channel-first inputs and outputs.
87105
88106
Args:
89107
x: Input tensor with shape (B, C, H, W).
90108
x_noise: Optional noise/modulation tensor.
109+
global_cond: Optional global conditioning tensor with shape
110+
(B, C_global). Used only when include_global_cond=True.
91111
92112
Returns
93113
-------
@@ -107,22 +127,49 @@ def forward(self, x: Tensor, x_noise: Tensor | None = None) -> Tensor:
107127
)
108128
raise ValueError(msg)
109129

130+
if (
131+
not self.n_noise_channels
132+
and x_noise is not None
133+
and x_noise.shape[-1] != self.model.mod_features
134+
):
135+
msg = (
136+
f"Expected x_noise with last dim {self.model.mod_features}, "
137+
f"got {x_noise.shape[-1]}."
138+
)
139+
raise ValueError(msg)
140+
141+
model_global_cond = None
142+
if self.include_global_cond:
143+
if global_cond is None:
144+
msg = "global_cond must be provided when include_global_cond=True."
145+
raise ValueError(msg)
146+
if global_cond.shape[-1] != self.global_cond_channels:
147+
msg = (
148+
f"Expected global_cond with last dim "
149+
f"{self.global_cond_channels}, got "
150+
f"{global_cond.shape[-1]}."
151+
)
152+
raise ValueError(msg)
153+
model_global_cond = global_cond
154+
110155
x_in = rearrange(x, "b c h w -> b 1 h w c").contiguous()
111-
y = self.model(x_in, t=x_noise, cond=None, global_cond=None)
156+
y = self.model(x_in, t=x_noise, cond=None, global_cond=model_global_cond)
112157
return rearrange(y, "b 1 h w c -> b c h w").contiguous()
113158

114-
def map(self, x: Tensor, global_cond: Tensor | None = None) -> Tensor: # noqa: ARG002
159+
def map(self, x: Tensor, global_cond: Tensor | None = None) -> Tensor:
115160
noise_channels = self.n_noise_input_channels or self.n_noise_channels
161+
if noise_channels is None:
162+
noise_channels = self.model.mod_features
163+
116164
if self.n_noise_channels:
117-
if noise_channels is None:
118-
msg = "n_noise_channels is set but no noise input width is available."
119-
raise ValueError(msg)
120165
noise = torch.randn(
121166
x.shape[0], noise_channels, dtype=x.dtype, device=x.device
122167
)
123168
else:
124-
noise = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
125-
return self(x, noise)
169+
noise = torch.zeros(
170+
x.shape[0], noise_channels, dtype=x.dtype, device=x.device
171+
)
172+
return self(x, noise, global_cond=global_cond)
126173

127174
def loss(self, batch: EncodedBatch) -> Tensor:
128175
pred = self.map(batch.encoded_inputs, batch.global_cond)

0 commit comments

Comments (0)