@@ -32,6 +32,8 @@ def __init__(
3232 loss_func : nn .Module | None = None ,
3333 n_noise_channels : int | None = None ,
3434 n_noise_input_channels : int | None = None ,
35+ global_cond_channels : int | None = None ,
36+ include_global_cond : bool = False ,
3537 checkpointing : bool = False ,
3638 ):
3739 super ().__init__ ()
@@ -42,6 +44,8 @@ def __init__(
4244
4345 self .n_noise_channels = n_noise_channels
4446 self .n_noise_input_channels = n_noise_input_channels or n_noise_channels
47+ self .global_cond_channels = global_cond_channels
48+ self .include_global_cond = include_global_cond
4549
4650 if self .n_noise_channels is None and n_noise_input_channels is not None :
4751 msg = (
@@ -50,6 +54,15 @@ def __init__(
5054 )
5155 raise ValueError (msg )
5256
57+ if self .include_global_cond and (
58+ self .global_cond_channels is None or self .global_cond_channels <= 0
59+ ):
60+ msg = (
61+ "include_global_cond=True requires global_cond_channels to be "
62+ "set to a positive integer."
63+ )
64+ raise ValueError (msg )
65+
5366 self .modulation_proj = None
5467 if (
5568 self .n_noise_channels
@@ -68,8 +81,8 @@ def __init__(
6881 n_steps_output = 1 ,
6982 n_steps_input = 1 ,
7083 mod_features = n_noise_channels or 256 ,
71- global_cond_channels = None ,
72- include_global_cond = False ,
84+ global_cond_channels = global_cond_channels ,
85+ include_global_cond = include_global_cond ,
7386 hid_channels = hidden_dim ,
7487 hid_blocks = n_layers ,
7588 attention_heads = num_heads ,
@@ -82,12 +95,19 @@ def __init__(
8295 use_precomputed_modulation = True ,
8396 )
8497
85- def forward (self , x : Tensor , x_noise : Tensor | None = None ) -> Tensor :
98+ def forward (
99+ self ,
100+ x : Tensor ,
101+ x_noise : Tensor | None = None ,
102+ global_cond : Tensor | None = None ,
103+ ) -> Tensor :
86104 """Run TemporalViT with channel-first inputs and outputs.
87105
88106 Args:
89107 x: Input tensor with shape (B, C, H, W).
90108 x_noise: Optional noise/modulation tensor.
109+ global_cond: Optional global conditioning tensor with shape
110+ (B, C_global). Used only when include_global_cond=True.
91111
92112 Returns
93113 -------
@@ -107,22 +127,49 @@ def forward(self, x: Tensor, x_noise: Tensor | None = None) -> Tensor:
107127 )
108128 raise ValueError (msg )
109129
130+ if (
131+ not self .n_noise_channels
132+ and x_noise is not None
133+ and x_noise .shape [- 1 ] != self .model .mod_features
134+ ):
135+ msg = (
136+ f"Expected x_noise with last dim { self .model .mod_features } , "
137+ f"got { x_noise .shape [- 1 ]} ."
138+ )
139+ raise ValueError (msg )
140+
141+ model_global_cond = None
142+ if self .include_global_cond :
143+ if global_cond is None :
144+ msg = "global_cond must be provided when include_global_cond=True."
145+ raise ValueError (msg )
146+ if global_cond .shape [- 1 ] != self .global_cond_channels :
147+ msg = (
148+ f"Expected global_cond with last dim "
149+ f"{ self .global_cond_channels } , got "
150+ f"{ global_cond .shape [- 1 ]} ."
151+ )
152+ raise ValueError (msg )
153+ model_global_cond = global_cond
154+
110155 x_in = rearrange (x , "b c h w -> b 1 h w c" ).contiguous ()
111- y = self .model (x_in , t = x_noise , cond = None , global_cond = None )
156+ y = self .model (x_in , t = x_noise , cond = None , global_cond = model_global_cond )
112157 return rearrange (y , "b 1 h w c -> b c h w" ).contiguous ()
113158
def map(self, x: Tensor, global_cond: Tensor | None = None) -> Tensor:
    """Build the modulation input for a batch and run the forward pass.

    Args:
        x: Input tensor with shape (B, C, H, W).
        global_cond: Optional global conditioning tensor, passed through
            to ``forward`` (used only when include_global_cond=True).

    Returns:
        The forward-pass output for ``x`` with the generated modulation
        vector supplied as ``x_noise``.
    """
    # Width of the per-sample modulation vector; when neither noise
    # setting was configured, fall back to the wrapped model's
    # mod_features so the zero vector still has the expected width.
    width = self.n_noise_input_channels or self.n_noise_channels
    if width is None:
        width = self.model.mod_features

    # Gaussian noise when noise conditioning is enabled, zeros otherwise;
    # both branches produce a (B, width) tensor matching x's dtype/device.
    sample = torch.randn if self.n_noise_channels else torch.zeros
    noise = sample(x.shape[0], width, dtype=x.dtype, device=x.device)
    return self(x, noise, global_cond=global_cond)
126173
127174 def loss (self , batch : EncodedBatch ) -> Tensor :
128175 pred = self .map (batch .encoded_inputs , batch .global_cond )
0 commit comments