@@ -2494,6 +2494,42 @@ def get_guider(self, model, cfg, positive, negative, start_percent, end_percent)
24942494 return (guider , )
24952495
24962496
2497+ class ApplyRifleXRoPE_WanVideo :
2498+ @classmethod
2499+ def INPUT_TYPES (s ):
2500+ return {
2501+ "required" : {
2502+ "model" : ("MODEL" ,),
2503+ "latent" : ("LATENT" , {"tooltip" : "Only used to get the latent count" }),
2504+ "k" : ("INT" , {"default" : 4 , "min" : 1 , "max" : 100 , "step" : 1 , "tooltip" : "Index of intrinsic frequency" }),
2505+ }
2506+ }
2507+
2508+ RETURN_TYPES = ("MODEL" ,)
2509+ FUNCTION = "patch"
2510+ CATEGORY = "KJNodes/experimental"
2511+ EXPERIMENTAL = True
2512+ DESCRIPTION = "Extends the potential frame count of HunyuanVideo using this method: https://github.com/thu-ml/RIFLEx"
2513+
2514+ def patch (self , model , latent , k ):
2515+ model_class = model .model .diffusion_model
2516+
2517+ model_clone = model .clone ()
2518+ num_frames = latent ["samples" ].shape [2 ]
2519+ d = model_class .dim // model_class .num_heads
2520+
2521+ rope_embedder = EmbedND_RifleX (
2522+ d ,
2523+ 10000.0 ,
2524+ [d - 4 * (d // 6 ), 2 * (d // 6 ), 2 * (d // 6 )],
2525+ num_frames ,
2526+ k
2527+ )
2528+
2529+ model_clone .add_object_patch (f"diffusion_model.rope_embedder" , rope_embedder )
2530+
2531+ return (model_clone , )
2532+
24972533class ApplyRifleXRoPE_HunuyanVideo :
24982534 @classmethod
24992535 def INPUT_TYPES (s ):
@@ -2513,7 +2549,6 @@ def INPUT_TYPES(s):
25132549
25142550 def patch (self , model , latent , k ):
25152551 model_class = model .model .diffusion_model
2516- print (model_class .pe_embedder )
25172552
25182553 model_clone = model .clone ()
25192554 num_frames = latent ["samples" ].shape [2 ]
@@ -2549,6 +2584,23 @@ def rope_riflex(pos, dim, theta, L_test, k):
25492584 out = rearrange (out , "b n d (i j) -> b n d i j" , i = 2 , j = 2 )
25502585 return out .to (dtype = torch .float32 , device = pos .device )
25512586
2587+ class EmbedND_RifleX (nn .Module ):
2588+ def __init__ (self , dim , theta , axes_dim , num_frames , k ):
2589+ super ().__init__ ()
2590+ self .dim = dim
2591+ self .theta = theta
2592+ self .axes_dim = axes_dim
2593+ self .num_frames = num_frames
2594+ self .k = k
2595+
2596+ def forward (self , ids ):
2597+ n_axes = ids .shape [- 1 ]
2598+ emb = torch .cat (
2599+ [rope_riflex (ids [..., i ], self .axes_dim [i ], self .theta , self .num_frames , self .k ) for i in range (n_axes )],
2600+ dim = - 3 ,
2601+ )
2602+ return emb .unsqueeze (1 )
2603+
25522604class EmbedND_RifleX (nn .Module ):
25532605 def __init__ (self , dim , theta , axes_dim , num_frames , k ):
25542606 super ().__init__ ()
0 commit comments