Skip to content

Commit 97d20e2

Browse files
committed
ApplyRifleXRoPE_WanVideo
experimental
1 parent 1a4259f commit 97d20e2

File tree

2 files changed

+54
-1
lines changed

2 files changed

+54
-1
lines changed

__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@
182182
"VAELoaderKJ": {"class": VAELoaderKJ, "name": "VAELoader KJ"},
183183
"ScheduledCFGGuidance": {"class": ScheduledCFGGuidance, "name": "Scheduled CFG Guidance"},
184184
"ApplyRifleXRoPE_HunuyanVideo": {"class": ApplyRifleXRoPE_HunuyanVideo, "name": "Apply RifleXRoPE HunuyanVideo"},
185+
"ApplyRifleXRoPE_WanVideo": {"class": ApplyRifleXRoPE_WanVideo, "name": "Apply RifleXRoPE WanVideo"},
185186

186187
#instance diffusion
187188
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},

nodes/nodes.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2494,6 +2494,42 @@ def get_guider(self, model, cfg, positive, negative, start_percent, end_percent)
24942494
return (guider, )
24952495

24962496

2497+
class ApplyRifleXRoPE_WanVideo:
    """ComfyUI node: patch a WanVideo model's RoPE embedder with RIFLEx.

    Replaces the diffusion model's ``rope_embedder`` on a cloned model with
    an ``EmbedND_RifleX`` instance, extending the usable frame count via the
    RIFLEx method (https://github.com/thu-ml/RIFLEx).
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "latent": ("LATENT", {"tooltip": "Only used to get the latent count"}),
                "k": ("INT", {"default": 4, "min": 1, "max": 100, "step": 1, "tooltip": "Index of intrinsic frequency"}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "KJNodes/experimental"
    EXPERIMENTAL = True
    # Fixed copy-paste defect: the description previously said "HunyuanVideo",
    # but this node targets WanVideo (cf. the sibling HunyuanVideo node).
    DESCRIPTION = "Extends the potential frame count of WanVideo using this method: https://github.com/thu-ml/RIFLEx"

    def patch(self, model, latent, k):
        """Return a cloned MODEL whose rope_embedder uses RIFLEx frequencies.

        model:  ComfyUI MODEL patcher wrapping the diffusion model.
        latent: LATENT dict; only ``samples.shape[2]`` (frame count) is read.
        k:      index of the intrinsic temporal frequency to modify.
        """
        model_class = model.model.diffusion_model

        # Patch a clone so the original model object is left untouched.
        model_clone = model.clone()
        # Frame count from the latent's third dimension
        # (assumes latent layout (b, c, t, h, w) — TODO confirm).
        num_frames = latent["samples"].shape[2]
        # Per-attention-head channel dimension, used to size the rotary axes.
        d = model_class.dim // model_class.num_heads

        rope_embedder = EmbedND_RifleX(
            d,
            10000.0,
            # Axis split: remainder goes to the temporal axis, 2*(d//6)
            # to each of the two spatial axes.
            [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)],
            num_frames,
            k
        )

        # Object patch applies only to the clone at execution time.
        model_clone.add_object_patch("diffusion_model.rope_embedder", rope_embedder)

        return (model_clone, )
2532+
24972533
class ApplyRifleXRoPE_HunuyanVideo:
24982534
@classmethod
24992535
def INPUT_TYPES(s):
@@ -2513,7 +2549,6 @@ def INPUT_TYPES(s):
25132549

25142550
def patch(self, model, latent, k):
25152551
model_class = model.model.diffusion_model
2516-
print(model_class.pe_embedder)
25172552

25182553
model_clone = model.clone()
25192554
num_frames = latent["samples"].shape[2]
@@ -2549,6 +2584,23 @@ def rope_riflex(pos, dim, theta, L_test, k):
25492584
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
25502585
return out.to(dtype=torch.float32, device=pos.device)
25512586

2587+
class EmbedND_RifleX(nn.Module):
    """N-dimensional rotary positional embedder using RIFLEx frequencies.

    Builds one rotary embedding per trailing axis of the position ids via
    ``rope_riflex`` and concatenates them.
    NOTE(review): the diff shows an identical ``EmbedND_RifleX`` class
    defined again later in this file; the later definition shadows this
    one — confirm and deduplicate.
    """

    def __init__(self, dim, theta, axes_dim, num_frames, k):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim
        self.num_frames = num_frames
        self.k = k

    def forward(self, ids):
        # One embedding per position axis, joined along dim=-3.
        per_axis = [
            rope_riflex(ids[..., axis], self.axes_dim[axis], self.theta, self.num_frames, self.k)
            for axis in range(ids.shape[-1])
        ]
        emb = torch.cat(per_axis, dim=-3)
        # Insert a broadcastable dimension after the batch axis.
        return emb.unsqueeze(1)
2603+
25522604
class EmbedND_RifleX(nn.Module):
25532605
def __init__(self, dim, theta, axes_dim, num_frames, k):
25542606
super().__init__()

0 commit comments

Comments
 (0)