Commit 4a0481e

add tests

1 parent 2846939
1 file changed: +71 -0 lines changed

tests/models/transformers/test_models_transformer_hunyuan_video.py

@@ -80,6 +80,7 @@ def prepare_init_args_and_inputs_for_common(self):
             "text_embed_dim": 16,
             "pooled_projection_dim": 8,
             "rope_axes_dim": (2, 4, 4),
+            "image_condition_type": None,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
@@ -144,6 +145,7 @@ def prepare_init_args_and_inputs_for_common(self):
             "text_embed_dim": 16,
             "pooled_projection_dim": 8,
             "rope_axes_dim": (2, 4, 4),
+            "image_condition_type": None,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
@@ -209,6 +211,75 @@ def prepare_init_args_and_inputs_for_common(self):
             "text_embed_dim": 16,
             "pooled_projection_dim": 8,
             "rope_axes_dim": (2, 4, 4),
+            "image_condition_type": "latent_concat",
+        }
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    def test_output(self):
+        super().test_output(expected_output_shape=(1, *self.output_shape))
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"HunyuanVideoTransformer3DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+    model_class = HunyuanVideoTransformer3DModel
+    main_input_name = "hidden_states"
+    uses_custom_attn_processor = True
+
+    @property
+    def dummy_input(self):
+        batch_size = 1
+        num_channels = 2
+        num_frames = 1
+        height = 16
+        width = 16
+        text_encoder_embedding_dim = 16
+        pooled_projection_dim = 8
+        sequence_length = 12
+
+        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
+        pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
+        encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
+        guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
+
+        return {
+            "hidden_states": hidden_states,
+            "timestep": timestep,
+            "encoder_hidden_states": encoder_hidden_states,
+            "pooled_projections": pooled_projections,
+            "encoder_attention_mask": encoder_attention_mask,
+            "guidance": guidance,
+        }
+
+    @property
+    def input_shape(self):
+        return (8, 1, 16, 16)
+
+    @property
+    def output_shape(self):
+        return (4, 1, 16, 16)
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            "in_channels": 2,
+            "out_channels": 4,
+            "num_attention_heads": 2,
+            "attention_head_dim": 10,
+            "num_layers": 1,
+            "num_single_layers": 1,
+            "num_refiner_layers": 1,
+            "patch_size": 1,
+            "patch_size_t": 1,
+            "guidance_embeds": True,
+            "text_embed_dim": 16,
+            "pooled_projection_dim": 8,
+            "rope_axes_dim": (2, 4, 4),
+            "image_condition_type": "token_replace",
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
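
The commit threads a new `image_condition_type` init argument through the existing HunyuanVideo test configs (`None` for the text-to-video configs, `"latent_concat"` for the existing image-to-video class) and adds a dedicated test class for the `"token_replace"` variant. The sketch below is not part of the commit; it only shows how the tiny `token_replace` configuration could be exercised by hand. The init kwargs and tensor shapes are copied from the diff, while the top-level `diffusers` import and the `.sample` attribute on the returned output are assumptions about the library rather than anything the diff states.

# Minimal sketch (not from the commit): build the tiny "token_replace" test
# config above and run one forward pass with random inputs shaped like the
# new dummy_input property. Assumes HunyuanVideoTransformer3DModel is
# importable from the top-level diffusers package and that the returned
# output exposes the result tensor as `.sample`.
import torch
from diffusers import HunyuanVideoTransformer3DModel

device = "cuda" if torch.cuda.is_available() else "cpu"

model = HunyuanVideoTransformer3DModel(
    in_channels=2,
    out_channels=4,
    num_attention_heads=2,
    attention_head_dim=10,
    num_layers=1,
    num_single_layers=1,
    num_refiner_layers=1,
    patch_size=1,
    patch_size_t=1,
    guidance_embeds=True,
    text_embed_dim=16,
    pooled_projection_dim=8,
    rope_axes_dim=(2, 4, 4),
    image_condition_type="token_replace",
).to(device)

# Random inputs with the same shapes as the new dummy_input property.
batch_size, num_channels, num_frames, height, width = 1, 2, 1, 16, 16
sequence_length, text_embed_dim, pooled_projection_dim = 12, 16, 8

inputs = {
    "hidden_states": torch.randn(batch_size, num_channels, num_frames, height, width, device=device),
    "timestep": torch.randint(0, 1000, (batch_size,), device=device),
    "encoder_hidden_states": torch.randn(batch_size, sequence_length, text_embed_dim, device=device),
    "pooled_projections": torch.randn(batch_size, pooled_projection_dim, device=device),
    "encoder_attention_mask": torch.ones(batch_size, sequence_length, device=device),
    "guidance": torch.randint(0, 1000, (batch_size,), device=device).float(),
}

with torch.no_grad():
    output = model(**inputs)

# The ModelTesterMixin-based test_output check expects (1, *output_shape),
# i.e. (1, 4, 1, 16, 16) for this configuration.
print(output.sample.shape)

To run only the new class, something like `pytest tests/models/transformers/test_models_transformer_hunyuan_video.py -k TokenReplace` should work, though the exact invocation depends on the repository's test setup.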
