@@ -80,6 +80,7 @@ def prepare_init_args_and_inputs_for_common(self):
             "text_embed_dim": 16,
             "pooled_projection_dim": 8,
             "rope_axes_dim": (2, 4, 4),
+            "image_condition_type": None,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
@@ -144,6 +145,7 @@ def prepare_init_args_and_inputs_for_common(self):
             "text_embed_dim": 16,
             "pooled_projection_dim": 8,
             "rope_axes_dim": (2, 4, 4),
+            "image_condition_type": None,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
@@ -209,6 +211,75 @@ def prepare_init_args_and_inputs_for_common(self):
             "text_embed_dim": 16,
             "pooled_projection_dim": 8,
             "rope_axes_dim": (2, 4, 4),
+            "image_condition_type": "latent_concat",
+        }
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    def test_output(self):
+        super().test_output(expected_output_shape=(1, *self.output_shape))
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"HunyuanVideoTransformer3DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+    model_class = HunyuanVideoTransformer3DModel
+    main_input_name = "hidden_states"
+    uses_custom_attn_processor = True
+
+    @property
+    def dummy_input(self):
+        batch_size = 1
+        num_channels = 2
+        num_frames = 1
+        height = 16
+        width = 16
+        text_encoder_embedding_dim = 16
+        pooled_projection_dim = 8
+        sequence_length = 12
+
+        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
+        pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
+        encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
+        guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
+
+        return {
+            "hidden_states": hidden_states,
+            "timestep": timestep,
+            "encoder_hidden_states": encoder_hidden_states,
+            "pooled_projections": pooled_projections,
+            "encoder_attention_mask": encoder_attention_mask,
+            "guidance": guidance,
+        }
+
+    @property
+    def input_shape(self):
+        return (8, 1, 16, 16)
+
+    @property
+    def output_shape(self):
+        return (4, 1, 16, 16)
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            "in_channels": 2,
+            "out_channels": 4,
+            "num_attention_heads": 2,
+            "attention_head_dim": 10,
+            "num_layers": 1,
+            "num_single_layers": 1,
+            "num_refiner_layers": 1,
+            "patch_size": 1,
+            "patch_size_t": 1,
+            "guidance_embeds": True,
+            "text_embed_dim": 16,
+            "pooled_projection_dim": 8,
+            "rope_axes_dim": (2, 4, 4),
+            "image_condition_type": "token_replace",
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
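
For reference, a minimal sketch of what the new token-replace test exercises: building the transformer at the tiny test config and pushing the dummy inputs through a forward pass. This assumes a diffusers build that already includes the `image_condition_type` argument introduced by this diff; it is an illustration, not part of the change itself.

import torch
from diffusers import HunyuanVideoTransformer3DModel

# Tiny config mirroring HunyuanVideoTokenReplaceImageToVideoTransformer3DTests above.
model = HunyuanVideoTransformer3DModel(
    in_channels=2,
    out_channels=4,
    num_attention_heads=2,
    attention_head_dim=10,
    num_layers=1,
    num_single_layers=1,
    num_refiner_layers=1,
    patch_size=1,
    patch_size_t=1,
    guidance_embeds=True,
    text_embed_dim=16,
    pooled_projection_dim=8,
    rope_axes_dim=(2, 4, 4),
    image_condition_type="token_replace",
)

# Random tensors shaped like the test's dummy_input.
out = model(
    hidden_states=torch.randn(1, 2, 1, 16, 16),
    timestep=torch.randint(0, 1000, (1,)),
    encoder_hidden_states=torch.randn(1, 12, 16),
    pooled_projections=torch.randn(1, 8),
    encoder_attention_mask=torch.ones(1, 12),
    guidance=torch.randint(0, 1000, (1,)).float(),
)
# out.sample should match the test's output_shape plus a batch dim: (1, 4, 1, 16, 16).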