We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 492fa97 commit e15afefCopy full SHA for e15afef
1 file changed
models/wan/modules/model.py
@@ -1308,7 +1308,7 @@ def forward(
1308
kwargs["standin_phase"] = 2
1309
if (current_step == 0 or not standin_cache_enabled) and x_id == 0:
1310
standin_x = self.patch_embedding(standin_ref).to(modulation_dtype).flatten(2).transpose(1, 2)
1311
- standin_e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, torch.zeros_like(t)) )
+ standin_e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, torch.zeros_like(t)).to(modulation_dtype) )
1312
standin_e0 = self.time_projection(standin_e).unflatten(1, (6, self.dim)).to(e.dtype)
1313
standin_e = standin_ref = None
1314
0 commit comments