@@ -32,6 +32,8 @@ def __init__(
3232 loss_func : nn .Module | None = None ,
3333 n_noise_channels : int | None = None ,
3434 n_noise_input_channels : int | None = None ,
35+ global_cond_channels : int | None = None ,
36+ include_global_cond : bool = False ,
3537 checkpointing : bool = False ,
3638 ):
3739 super ().__init__ ()
@@ -42,6 +44,8 @@ def __init__(
4244
4345 self .n_noise_channels = n_noise_channels
4446 self .n_noise_input_channels = n_noise_input_channels or n_noise_channels
47+ self .global_cond_channels = global_cond_channels
48+ self .include_global_cond = include_global_cond
4549
4650 if self .n_noise_channels is None and n_noise_input_channels is not None :
4751 msg = (
@@ -68,8 +72,8 @@ def __init__(
6872 n_steps_output = 1 ,
6973 n_steps_input = 1 ,
7074 mod_features = n_noise_channels or 256 ,
71- global_cond_channels = None ,
72- include_global_cond = False ,
75+ global_cond_channels = global_cond_channels ,
76+ include_global_cond = include_global_cond ,
7377 hid_channels = hidden_dim ,
7478 hid_blocks = n_layers ,
7579 attention_heads = num_heads ,
@@ -82,12 +86,19 @@ def __init__(
8286 use_precomputed_modulation = True ,
8387 )
8488
85- def forward (self , x : Tensor , x_noise : Tensor | None = None ) -> Tensor :
89+ def forward (
90+ self ,
91+ x : Tensor ,
92+ x_noise : Tensor | None = None ,
93+ global_cond : Tensor | None = None ,
94+ ) -> Tensor :
8695 """Run TemporalViT with channel-first inputs and outputs.
8796
8897 Args:
8998 x: Input tensor with shape (B, C, H, W).
9099 x_noise: Optional noise/modulation tensor.
100+ global_cond: Optional global conditioning tensor with shape
101+ (B, C_global). Used only when include_global_cond=True.
91102
92103 Returns
93104 -------
@@ -107,22 +118,49 @@ def forward(self, x: Tensor, x_noise: Tensor | None = None) -> Tensor:
107118 )
108119 raise ValueError (msg )
109120
121+ if (
122+ not self .n_noise_channels
123+ and x_noise is not None
124+ and x_noise .shape [- 1 ] != self .model .mod_features
125+ ):
126+ msg = (
127+ f"Expected x_noise with last dim { self .model .mod_features } , "
128+ f"got { x_noise .shape [- 1 ]} ."
129+ )
130+ raise ValueError (msg )
131+
132+ model_global_cond = None
133+ if self .include_global_cond :
134+ if global_cond is None :
135+ msg = "global_cond must be provided when include_global_cond=True."
136+ raise ValueError (msg )
137+ if global_cond .shape [- 1 ] != self .global_cond_channels :
138+ msg = (
139+ f"Expected global_cond with last dim "
140+ f"{ self .global_cond_channels } , got "
141+ f"{ global_cond .shape [- 1 ]} ."
142+ )
143+ raise ValueError (msg )
144+ model_global_cond = global_cond
145+
110146 x_in = rearrange (x , "b c h w -> b 1 h w c" ).contiguous ()
111- y = self .model (x_in , t = x_noise , cond = None , global_cond = None )
147+ y = self .model (x_in , t = x_noise , cond = None , global_cond = model_global_cond )
112148 return rearrange (y , "b 1 h w c -> b c h w" ).contiguous ()
113149
114- def map (self , x : Tensor , global_cond : Tensor | None = None ) -> Tensor : # noqa: ARG002
150+ def map (self , x : Tensor , global_cond : Tensor | None = None ) -> Tensor :
115151 noise_channels = self .n_noise_input_channels or self .n_noise_channels
152+ if noise_channels is None :
153+ noise_channels = self .model .mod_features
154+
116155 if self .n_noise_channels :
117- if noise_channels is None :
118- msg = "n_noise_channels is set but no noise input width is available."
119- raise ValueError (msg )
120156 noise = torch .randn (
121157 x .shape [0 ], noise_channels , dtype = x .dtype , device = x .device
122158 )
123159 else :
124- noise = torch .zeros (x .shape [0 ], dtype = x .dtype , device = x .device )
125- return self (x , noise )
160+ noise = torch .zeros (
161+ x .shape [0 ], noise_channels , dtype = x .dtype , device = x .device
162+ )
163+ return self (x , noise , global_cond = global_cond )
126164
127165 def loss (self , batch : EncodedBatch ) -> Tensor :
128166 pred = self .map (batch .encoded_inputs , batch .global_cond )
0 commit comments