Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions comfy/ldm/wan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope1
from comfy.ldm.flux.math import apply_rope1, rope
import comfy.ldm.common_dit
import comfy.model_management
import comfy.patcher_extension
Expand Down Expand Up @@ -570,6 +570,14 @@ def forward_orig(
full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
x = torch.concat((full_ref, x), dim=1)

# In-context reference (Bernini)
context_latents = kwargs.get("context_latents", None)
main_len = x.shape[1]
if context_latents is not None:
for lat in context_latents:
cl = self.patch_embedding(lat.float().to(x.device)).to(x.dtype).flatten(2).transpose(1, 2)
x = torch.cat([x, cl], dim=1)

# context
context = self.text_embedding(context)

Expand Down Expand Up @@ -599,14 +607,17 @@ def block_wrap(args):
# head
x = self.head(x, e)

if context_latents is not None:
x = x[:, :main_len]

if full_ref is not None:
x = x[:, full_ref.shape[1]:]

# unpatchify
x = self.unpatchify(x, grid_sizes)
return x

def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}, source_id=0):
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
Expand Down Expand Up @@ -638,6 +649,13 @@ def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=No
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])

freqs = self.rope_embedder(img_ids).movedim(1, 2)

# In-context reference: a non-zero source_id composes an extra rotation into the spatial rope
if source_id:
d = self.dim // self.num_heads
pos = torch.tensor([[float(source_id)]], device=freqs.device, dtype=torch.float32)
id_rot = rope(pos, d, self.rope_embedder.theta).reshape(1, 1, 1, d // 2, 2, 2).to(freqs.dtype)
freqs = torch.einsum('...ij,...jk->...ik', freqs, id_rot)
return freqs

def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
Expand All @@ -661,6 +679,15 @@ def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, tr
t_len += 1

freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)

# In-context reference: one rope block per stream, each with its own source_id (1, 2, ...) to distinguish it from the target (id 0).
context_latents = kwargs.get("context_latents", None)
if context_latents is not None:
context_latents = [comfy.ldm.common_dit.pad_to_patch_size(lat, self.patch_size) for lat in context_latents]
for i, lat in enumerate(context_latents):
freqs = torch.cat([freqs, self.rope_encode(lat.shape[-3], lat.shape[-2], lat.shape[-1], device=x.device, dtype=x.dtype, transformer_options=transformer_options, source_id=i + 1)], dim=1)
kwargs = {**kwargs, "context_latents": context_latents}

return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]

def unpatchify(self, x, grid_sizes):
Expand Down
18 changes: 18 additions & 0 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1518,8 +1518,26 @@ def extra_conds(self, **kwargs):
if reference_latents is not None:
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0])

# In-context reference conditioning (Bernini)
context_latents = kwargs.get("context_latents", None)
if context_latents is not None:
out['context_latents'] = comfy.conds.CONDList([self.process_latent_in(l) for l in context_latents])

return out

def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Slice the in-context (Bernini) latents for one context window.

    Only the ``"context_latents"`` cond whose ``cond`` attribute is a list is
    handled here; every other cond is delegated unchanged to the parent class.

    Each latent in the list is windowed along ``window.dim`` only when it has
    that dimension, is longer than 1 along it, and matches ``x_in``'s size
    along it. Any other latent is just moved to ``device`` untouched.

    Returns the result of ``cond_value._copy_with`` on the new list.
    """
    # Anything that is not a list-valued "context_latents" cond goes to the parent.
    if cond_key != "context_latents" or not isinstance(getattr(cond_value, "cond", None), list):
        return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)

    axis = window.dim

    def _fit(latent):
        # A latent follows the window only if it spans the same extent as x_in on `axis`.
        spans_window = latent.ndim > axis and latent.shape[axis] > 1 and latent.shape[axis] == x_in.shape[axis]
        if spans_window:
            return window.get_tensor(latent, device, dim=axis, retain_index_list=retain_index_list)
        return latent.to(device)

    return cond_value._copy_with([_fit(latent) for latent in cond_value.cond])


class WAN21_CausalAR(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
Expand Down
115 changes: 115 additions & 0 deletions comfy_extras/nodes_bernini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import torch
from typing_extensions import override

import comfy.model_management
import comfy.utils
import node_helpers
from comfy_api.latest import ComfyExtension, io


def _resize_long_edge(image, max_size, stride=16):
    """Shrink `image` (aspect preserved) so its long edge fits `max_size`, then snap to `stride`.

    The scale factor is capped at 1.0, so the image is never enlarged by the
    scaling step itself. Each side is then rounded with ``round()`` to a
    multiple of `stride`, never going below one `stride`. Height and width are
    read from axes 1 and 2, and only the first three entries of the last axis
    are kept before resizing.
    """
    src_h = image.shape[1]
    src_w = image.shape[2]
    factor = min(max_size / max(src_h, src_w), 1.0)

    def _snap(side):
        # Round to a stride multiple, with a floor of a single stride.
        return max(stride, round(side * factor / stride) * stride)

    out_h = _snap(src_h)
    out_w = _snap(src_w)

    # common_upscale is fed channels-first data, so move the last axis forward and back.
    channels_first = image[:, :, :, :3].movedim(-1, 1)
    resized = comfy.utils.common_upscale(channels_first, out_w, out_h, "area", "disabled")
    return resized.movedim(1, -1)


class BerniniConditioning(io.ComfyNode):
    """Bernini in-context conditioning for a Wan2.2-A14B model.

    VAE-encodes whichever of the source video, reference video and reference
    images are connected and attaches the resulting latents to both the
    positive and the negative conditioning under the ``context_latents`` key:
    source video first, then the reference video, then each reference image.

    The task is inferred from which inputs are connected:
      (nothing)                  -> t2v   (text-to-video)
      source_video               -> v2v   (video-to-video)
      source_video + ref_images  -> rv2v  (reference-guided video editing)
      ref_images only            -> r2v   (reference-to-video)
      source_video + ref_video   -> ads2v (insert image/video into video)

    source_video is the edit base / canvas (resized to width x height).
    reference_video is moving content to composite in.
    Streams are ordered source_video, reference_video, then reference_images -> source_id (1, 2, 3, ...).
    """

    @classmethod
    def define_schema(cls):
        """Declare the node's id, category, inputs and outputs."""
        return io.Schema(
            node_id="BerniniConditioning",
            display_name="Bernini Conditioning",
            category="conditioning/video_models",
            # Adjacent literals are concatenated verbatim, so the second one starts
            # with a space to keep the two sentences apart.
            description="Conditioning node for Bernini in-context video/image conditioning. It can be used for the following tasks: t2v (text-to-video), v2v (video-to-video), rv2v (reference-guided video editing), r2v (reference-to-video), ads2v (insert image/video into video)."
                        " Reference images injected as in-context tokens (r2v, rv2v) are encoded independently at their own native aspect ratio (long edge capped at ref_max_size)",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Vae.Input("vae"),
                io.Int.Input("width", default=832, min=16, max=8192, step=16),
                io.Int.Input("height", default=480, min=16, max=8192, step=16),
                io.Int.Input("length", default=81, min=1, max=8192, step=4),
                io.Int.Input("batch_size", default=1, min=1, max=4096),
                io.Image.Input("source_video", optional=True, tooltip=(
                    "Source video to edit or restyle (v2v, rv2v). Resized to width/height and trimmed to length.")),
                io.Image.Input("reference_video", optional=True, tooltip=(
                    "Video to insert into the source video (ads2v).")),
                io.Autogrow.Input("reference_images", optional=True,
                                  template=io.Autogrow.TemplatePrefix(
                                      input=io.Image.Input("reference_image", tooltip=(
                                          "A reference image injected as an in-context token (task r2v or rv2v).")),
                                      prefix="reference_image_", min=0, max=8)),
                io.Int.Input("ref_max_size", default=848, min=16, max=8192, step=16, optional=True, tooltip=(
                    "Max size for the long edge of reference_video and reference_images. Resized with preserved aspect ratio and snapped to 16px.")),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
    def execute(cls, positive, negative, vae, width, height, length, batch_size,
                source_video=None, reference_video=None, reference_images=None, ref_max_size=848) -> io.NodeOutput:
        """Build the empty target latent and attach the encoded context streams.

        Args:
            positive, negative: conditionings that receive ``context_latents``.
            vae: object whose ``encode`` turns image batches into latents.
            width, height, length: target size; also used to resize / trim ``source_video``.
            batch_size: batch dimension of the returned empty latent.
            source_video: optional canvas video, trimmed to ``length`` frames and
                resized (center crop) to ``width`` x ``height``.
            reference_video: optional video, trimmed to ``length`` frames and
                resized by its long edge to at most ``ref_max_size``.
            reference_images: optional mapping of slot name -> image batch;
                every frame of every slot becomes its own stream.
            ref_max_size: long-edge cap applied to reference video / images.

        Returns:
            ``io.NodeOutput(positive, negative, {"samples": latent})`` where the
            conditionings are only modified when at least one stream was encoded.
        """
        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
                             device=comfy.model_management.intermediate_device())

        # Stream order: source_video, reference_video, then reference_images.
        # NOTE(review): source ids appear to be assigned from list position by the
        # model, so a missing input shifts the ids of later streams -- confirm.
        context = []
        if source_video is not None:
            vid = comfy.utils.common_upscale(source_video[:length, :, :, :3].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
            context.append(vae.encode(vid[:, :, :, :3]))

        if reference_video is not None:
            ref_vid = _resize_long_edge(reference_video[:length], ref_max_size)  # moving content, native aspect
            context.append(vae.encode(ref_vid[:, :, :, :3]))

        # reference_images is an autogrow dict {reference_image_0: IMAGE, ...}; each slot is a
        # separate stream at its own native aspect (a multi-image batch in one slot -> one stream per frame).
        # NOTE(review): sorted() orders slot names lexicographically, which matches numeric
        # order only while the suffixes are single-digit (max=8 above) -- revisit if max grows.
        if reference_images:
            for name in sorted(reference_images):
                imgs = reference_images[name]
                if imgs is None:
                    continue
                for i in range(imgs.shape[0]):
                    img = _resize_long_edge(imgs[i:i + 1], ref_max_size)  # native aspect per ref
                    context.append(vae.encode(img[:, :, :, :3]))

        if context:
            positive = node_helpers.conditioning_set_values(positive, {"context_latents": context})
            negative = node_helpers.conditioning_set_values(negative, {"context_latents": context})

        return io.NodeOutput(positive, negative, {"samples": latent})


class BerniniExtension(ComfyExtension):
    """Extension exposing the Bernini nodes defined in this module."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes this extension provides."""
        node_classes = [BerniniConditioning]
        return node_classes


async def comfy_entrypoint() -> BerniniExtension:
    """Create and return a fresh ``BerniniExtension`` instance."""
    extension = BerniniExtension()
    return extension
1 change: 1 addition & 0 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2404,6 +2404,7 @@ async def init_builtin_extra_nodes():
"nodes_video.py",
"nodes_lumina2.py",
"nodes_wan.py",
"nodes_bernini.py",
"nodes_lotus.py",
"nodes_hunyuan3d.py",
"nodes_primitive.py",
Expand Down
Loading