From 19475fdafc51804ea542bf5fb1dbd4e17664ffe4 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Mon, 1 Jun 2026 19:18:33 +0300 Subject: [PATCH 01/12] Initial commit --- comfy/ldm/wan/model.py | 125 +++++++++++++++++++++++++++++++++- comfy/model_base.py | 26 +++++++ comfy_extras/nodes_bernini.py | 95 ++++++++++++++++++++++++++ nodes.py | 1 + 4 files changed, 246 insertions(+), 1 deletion(-) create mode 100644 comfy_extras/nodes_bernini.py diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 70dfe7b16f05..070a00b39d99 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -8,7 +8,7 @@ from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.flux.layers import EmbedND -from comfy.ldm.flux.math import apply_rope1 +from comfy.ldm.flux.math import apply_rope1, rope import comfy.ldm.common_dit import comfy.model_management import comfy.patcher_extension @@ -1739,3 +1739,126 @@ def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, tr freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent) return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w] + + +class BerniniWanModel(WanModel): + """Wan2.2-A14B fine-tune (ByteDance Bernini-R) with in-context conditioning. + + Source video / reference image latents are patch-embedded with the same + ``patch_embedding`` as the noisy target and concatenated as extra tokens + along the sequence. Each conditioning stream carries a ``source_id`` (target + = 0, conditions = 1, 2, ...) realised as an extra multiplicative rotary + factor composed into the spatial RoPE: spatial coordinates overlap across + streams, only the source_id separates them. Self-attention is full over the + concatenated sequence; the target tokens are sliced back out afterwards. + + The condition latents arrive as kwargs (``bernini_video_latent``, + ``bernini_image_latents``) from ``WAN22_Bernini.extra_conds``. + """ + + def _source_id_freqs(self, freqs, source_id): + # Compose an extra rotation (by source_id, over the full head_dim) into + # the spatial rope. source_id == 0 -> identity (target unchanged). + if source_id == 0: + return freqs + d = self.dim // self.num_heads + pos = torch.tensor([[float(source_id)]], device=freqs.device, dtype=torch.float32) + id_rot = rope(pos, d, self.rope_embedder.theta).reshape(1, 1, 1, d // 2, 2, 2).to(freqs.dtype) + return torch.einsum('...ij,...jk->...ik', freqs, id_rot) + + def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}, source_id=0): + freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options) + return self._source_id_freqs(freqs, source_id) + + def _bernini_conditions(self, kwargs): + # Returns [(latent[B,C,T,H,W], source_id), ...] in concat order: + # source video first (source_id 1), then each reference image (2, 3, ...). + specs = [] + sid = 1 + video = kwargs.get("bernini_video_latent", None) + if video is not None: + specs.append((video, sid)) + sid += 1 + images = kwargs.get("bernini_image_latents", None) + if images is not None: + for i in range(images.shape[2]): + specs.append((images[:, :, i:i + 1], sid)) + sid += 1 + return specs + + def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs): + bs, c, t, h, w = x.shape + x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) + + t_len = t + if time_dim_concat is not None: + time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size) + x = torch.cat([x, time_dim_concat], dim=2) + t_len = x.shape[2] + + specs = [(comfy.ldm.common_dit.pad_to_patch_size(lat, self.patch_size), sid) + for lat, sid in self._bernini_conditions(kwargs)] + + # Target rope (source_id 0) first, then one block per condition stream. + freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, source_id=0) + for lat, sid in specs: + cf = self.rope_encode(lat.shape[-3], lat.shape[-2], lat.shape[-1], device=x.device, dtype=x.dtype, transformer_options=transformer_options, source_id=sid) + freqs = torch.cat([freqs, cf], dim=1) + + return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, bernini_cond_specs=specs, **kwargs)[:, :, :t, :h, :w] + + def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, bernini_cond_specs=None, **kwargs): + # embeddings + x = self.patch_embedding(x.float()).to(x.dtype) + grid_sizes = x.shape[2:] + transformer_options["grid_sizes"] = grid_sizes + x = x.flatten(2).transpose(1, 2) + target_len = x.shape[1] + + # in-context conditions: patch-embed and append (matching freqs order) + if bernini_cond_specs: + for lat, _ in bernini_cond_specs: + cond = self.patch_embedding(lat.float().to(x.device)).to(x.dtype) + x = torch.cat([x, cond.flatten(2).transpose(1, 2)], dim=1) + + # time embeddings + e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype)) + e = e.reshape(t.shape[0], -1, e.shape[-1]) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + + # context + context = self.text_embedding(context) + + context_img_len = None + if clip_fea is not None: + if self.img_emb is not None: + context_clip = self.img_emb(clip_fea) + context = torch.cat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" + for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) + + # head + x = self.head(x, e) + + # drop the appended condition tokens, keep the target + if bernini_cond_specs: + x = x[:, :target_len] + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return x diff --git a/comfy/model_base.py b/comfy/model_base.py index 2051789119c7..9afb80ff9f59 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1708,6 +1708,32 @@ def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image +class WAN22_Bernini(WAN22): + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.BerniniWanModel) + self.image_to_video = image_to_video + self.memory_usage_factor_conds = ("bernini_video_latent", "bernini_image_latents") + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + video = kwargs.get("bernini_video_latent", None) + if video is not None: + out["bernini_video_latent"] = comfy.conds.CONDRegular(self.process_latent_in(video)) + images = kwargs.get("bernini_image_latents", None) + if images is not None: + out["bernini_image_latents"] = comfy.conds.CONDRegular(self.process_latent_in(images)) + return out + + def extra_conds_shapes(self, **kwargs): + out = super().extra_conds_shapes(**kwargs) + video = kwargs.get("bernini_video_latent", None) + if video is not None: + out["bernini_video_latent"] = video.shape + images = kwargs.get("bernini_image_latents", None) + if images is not None: + out["bernini_image_latents"] = images.shape + return out + class WAN21_FlowRVS(WAN21): def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None): model_config.unet_config["model_type"] = "t2v" diff --git a/comfy_extras/nodes_bernini.py b/comfy_extras/nodes_bernini.py new file mode 100644 index 000000000000..4bfb05c43952 --- /dev/null +++ b/comfy_extras/nodes_bernini.py @@ -0,0 +1,95 @@ +import torch + +import comfy.ldm.wan.model +import comfy.model_base +import comfy.model_management +import comfy.utils +import node_helpers + + +def _patch_bernini(model): + """Flip a loaded Wan2.2-A14B model into Bernini-R mode. + + The Bernini checkpoint is architecturally identical to Wan2.2-A14B (no new + params), so we just swap the forward (BerniniWanModel) and the conditioning + plumbing (WAN22_Bernini) onto the already-loaded model. Idempotent. + """ + model.model.diffusion_model.__class__ = comfy.ldm.wan.model.BerniniWanModel + model.model.__class__ = comfy.model_base.WAN22_Bernini + model.model.memory_usage_factor_conds = ("bernini_video_latent", "bernini_image_latents") + return model + + +def _encode_frames(vae, image, width, height): + image = comfy.utils.common_upscale(image[:, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + return vae.encode(image[:, :, :, :3]) + + +class BerniniConditioning: + """Routes Bernini-R inputs and activates Bernini mode on the model(s). + + Attaches the VAE-encoded source video / reference images to BOTH the + positive and negative conditioning so stock CFG keeps the conditions fixed + and only varies the text -- giving Bernini's v2v / rv2v guidance form. For + cfg=1.0 (distill LoRA) the same setup is a single forward with the full + conditioning. t2v attaches nothing. + """ + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "positive": ("CONDITIONING",), + "negative": ("CONDITIONING",), + "vae": ("VAE",), + "task_type": (["t2v", "v2v", "rv2v"],), + "width": ("INT", {"default": 832, "min": 16, "max": 8192, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": 8192, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": 8192, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": { + "model_low": ("MODEL",), + "source_video": ("IMAGE",), + "reference_images": ("IMAGE",), + }, + } + + RETURN_TYPES = ("MODEL", "MODEL", "CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("model", "model_low", "positive", "negative", "latent") + FUNCTION = "build" + CATEGORY = "conditioning/video_models" + + def build(self, model, positive, negative, vae, task_type, width, height, length, batch_size, + model_low=None, source_video=None, reference_images=None): + model = _patch_bernini(model) + if model_low is not None: + model_low = _patch_bernini(model_low) + + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], + device=comfy.model_management.intermediate_device()) + + values = {} + if task_type in ("v2v", "rv2v") and source_video is not None: + values["bernini_video_latent"] = _encode_frames(vae, source_video[:length], width, height) + + if task_type == "rv2v" and reference_images is not None: + # each reference image is an independent single-frame stream (its own source_id) + refs = [_encode_frames(vae, reference_images[i:i + 1], width, height) for i in range(reference_images.shape[0])] + values["bernini_image_latents"] = torch.cat(refs, dim=2) + + if values: + positive = node_helpers.conditioning_set_values(positive, values) + negative = node_helpers.conditioning_set_values(negative, values) + + return (model, model_low, positive, negative, {"samples": latent}) + + +NODE_CLASS_MAPPINGS = { + "BerniniConditioning": BerniniConditioning, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "BerniniConditioning": "Bernini Conditioning", +} diff --git a/nodes.py b/nodes.py index 5678bc22d55c..61f4071ab295 100644 --- a/nodes.py +++ b/nodes.py @@ -2403,6 +2403,7 @@ async def init_builtin_extra_nodes(): "nodes_video.py", "nodes_lumina2.py", "nodes_wan.py", + "nodes_bernini.py", "nodes_lotus.py", "nodes_hunyuan3d.py", "nodes_primitive.py", From bb272ea09fdca6d432b35827a6dc16f26099aa12 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Mon, 1 Jun 2026 20:02:23 +0300 Subject: [PATCH 02/12] better --- comfy/ldm/wan/model.py | 159 ++++++++-------------------------- comfy/model_base.py | 33 ++----- comfy_extras/nodes_bernini.py | 97 ++++++++++----------- 3 files changed, 89 insertions(+), 200 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 070a00b39d99..15689a4289f6 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -570,6 +570,17 @@ def forward_orig( full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2) x = torch.concat((full_ref, x), dim=1) + # In-context reference streams (e.g. Bernini source video / ref images): + # patch-embed each clean condition latent and append as extra tokens (their + # rope, with per-stream source_id, was appended to `freqs` in _forward). + # Inert when no context_latents are supplied. + context_latents = kwargs.get("context_latents", None) + main_len = x.shape[1] + if context_latents is not None: + for lat in context_latents: + cl = self.patch_embedding(lat.float().to(x.device)).to(x.dtype).flatten(2).transpose(1, 2) + x = torch.cat([x, cl], dim=1) + # context context = self.text_embedding(context) @@ -599,6 +610,9 @@ def block_wrap(args): # head x = self.head(x, e) + if context_latents is not None: + x = x[:, :main_len] + if full_ref is not None: x = x[:, full_ref.shape[1]:] @@ -606,7 +620,7 @@ def block_wrap(args): x = self.unpatchify(x, grid_sizes) return x - def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}): + def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}, source_id=0): patch_size = self.patch_size t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) h_len = ((h + (patch_size[1] // 2)) // patch_size[1]) @@ -638,6 +652,16 @@ def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=No img_ids = img_ids.reshape(1, -1, img_ids.shape[-1]) freqs = self.rope_embedder(img_ids).movedim(1, 2) + + # In-context reference conditioning (e.g. Bernini): a non-zero source_id + # composes an extra rotation (over the full head_dim) into the spatial + # rope so streams sharing the same spatial coords stay distinct. source_id + # 0 is identity, so this is a no-op for all normal Wan usage. + if source_id: + d = self.dim // self.num_heads + pos = torch.tensor([[float(source_id)]], device=freqs.device, dtype=torch.float32) + id_rot = rope(pos, d, self.rope_embedder.theta).reshape(1, 1, 1, d // 2, 2, 2).to(freqs.dtype) + freqs = torch.einsum('...ij,...jk->...ik', freqs, id_rot) return freqs def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs): @@ -661,6 +685,16 @@ def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, tr t_len += 1 freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options) + + # In-context reference streams: one rope block per stream, each with its + # own source_id (1, 2, ...) so they stay distinct from the target (id 0). + context_latents = kwargs.get("context_latents", None) + if context_latents is not None: + context_latents = [comfy.ldm.common_dit.pad_to_patch_size(lat, self.patch_size) for lat in context_latents] + for i, lat in enumerate(context_latents): + freqs = torch.cat([freqs, self.rope_encode(lat.shape[-3], lat.shape[-2], lat.shape[-1], device=x.device, dtype=x.dtype, transformer_options=transformer_options, source_id=i + 1)], dim=1) + kwargs = {**kwargs, "context_latents": context_latents} + return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w] def unpatchify(self, x, grid_sizes): @@ -1739,126 +1773,3 @@ def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, tr freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent) return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w] - - -class BerniniWanModel(WanModel): - """Wan2.2-A14B fine-tune (ByteDance Bernini-R) with in-context conditioning. - - Source video / reference image latents are patch-embedded with the same - ``patch_embedding`` as the noisy target and concatenated as extra tokens - along the sequence. Each conditioning stream carries a ``source_id`` (target - = 0, conditions = 1, 2, ...) realised as an extra multiplicative rotary - factor composed into the spatial RoPE: spatial coordinates overlap across - streams, only the source_id separates them. Self-attention is full over the - concatenated sequence; the target tokens are sliced back out afterwards. - - The condition latents arrive as kwargs (``bernini_video_latent``, - ``bernini_image_latents``) from ``WAN22_Bernini.extra_conds``. - """ - - def _source_id_freqs(self, freqs, source_id): - # Compose an extra rotation (by source_id, over the full head_dim) into - # the spatial rope. source_id == 0 -> identity (target unchanged). - if source_id == 0: - return freqs - d = self.dim // self.num_heads - pos = torch.tensor([[float(source_id)]], device=freqs.device, dtype=torch.float32) - id_rot = rope(pos, d, self.rope_embedder.theta).reshape(1, 1, 1, d // 2, 2, 2).to(freqs.dtype) - return torch.einsum('...ij,...jk->...ik', freqs, id_rot) - - def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}, source_id=0): - freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options) - return self._source_id_freqs(freqs, source_id) - - def _bernini_conditions(self, kwargs): - # Returns [(latent[B,C,T,H,W], source_id), ...] in concat order: - # source video first (source_id 1), then each reference image (2, 3, ...). - specs = [] - sid = 1 - video = kwargs.get("bernini_video_latent", None) - if video is not None: - specs.append((video, sid)) - sid += 1 - images = kwargs.get("bernini_image_latents", None) - if images is not None: - for i in range(images.shape[2]): - specs.append((images[:, :, i:i + 1], sid)) - sid += 1 - return specs - - def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs): - bs, c, t, h, w = x.shape - x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) - - t_len = t - if time_dim_concat is not None: - time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size) - x = torch.cat([x, time_dim_concat], dim=2) - t_len = x.shape[2] - - specs = [(comfy.ldm.common_dit.pad_to_patch_size(lat, self.patch_size), sid) - for lat, sid in self._bernini_conditions(kwargs)] - - # Target rope (source_id 0) first, then one block per condition stream. - freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, source_id=0) - for lat, sid in specs: - cf = self.rope_encode(lat.shape[-3], lat.shape[-2], lat.shape[-1], device=x.device, dtype=x.dtype, transformer_options=transformer_options, source_id=sid) - freqs = torch.cat([freqs, cf], dim=1) - - return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, bernini_cond_specs=specs, **kwargs)[:, :, :t, :h, :w] - - def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, bernini_cond_specs=None, **kwargs): - # embeddings - x = self.patch_embedding(x.float()).to(x.dtype) - grid_sizes = x.shape[2:] - transformer_options["grid_sizes"] = grid_sizes - x = x.flatten(2).transpose(1, 2) - target_len = x.shape[1] - - # in-context conditions: patch-embed and append (matching freqs order) - if bernini_cond_specs: - for lat, _ in bernini_cond_specs: - cond = self.patch_embedding(lat.float().to(x.device)).to(x.dtype) - x = torch.cat([x, cond.flatten(2).transpose(1, 2)], dim=1) - - # time embeddings - e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype)) - e = e.reshape(t.shape[0], -1, e.shape[-1]) - e0 = self.time_projection(e).unflatten(2, (6, self.dim)) - - # context - context = self.text_embedding(context) - - context_img_len = None - if clip_fea is not None: - if self.img_emb is not None: - context_clip = self.img_emb(clip_fea) - context = torch.cat([context_clip, context], dim=1) - context_img_len = clip_fea.shape[-2] - - patches_replace = transformer_options.get("patches_replace", {}) - blocks_replace = patches_replace.get("dit", {}) - transformer_options["total_blocks"] = len(self.blocks) - transformer_options["block_type"] = "double" - for i, block in enumerate(self.blocks): - transformer_options["block_index"] = i - if ("double_block", i) in blocks_replace: - def block_wrap(args): - out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) - return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) - x = out["img"] - else: - x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) - - # head - x = self.head(x, e) - - # drop the appended condition tokens, keep the target - if bernini_cond_specs: - x = x[:, :target_len] - - # unpatchify - x = self.unpatchify(x, grid_sizes) - return x diff --git a/comfy/model_base.py b/comfy/model_base.py index 9afb80ff9f59..83680e1f6eea 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1516,6 +1516,13 @@ def extra_conds(self, **kwargs): if reference_latents is not None: out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0]) + # In-context reference conditioning (source video / reference images, + # e.g. Bernini): a list of clean latents appended as extra token streams + # with per-stream source_id rope. Inert when not supplied. + context_latents = kwargs.get("context_latents", None) + if context_latents is not None: + out['context_latents'] = comfy.conds.CONDList([self.process_latent_in(l) for l in context_latents]) + return out @@ -1708,32 +1715,6 @@ def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image -class WAN22_Bernini(WAN22): - def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): - super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.BerniniWanModel) - self.image_to_video = image_to_video - self.memory_usage_factor_conds = ("bernini_video_latent", "bernini_image_latents") - - def extra_conds(self, **kwargs): - out = super().extra_conds(**kwargs) - video = kwargs.get("bernini_video_latent", None) - if video is not None: - out["bernini_video_latent"] = comfy.conds.CONDRegular(self.process_latent_in(video)) - images = kwargs.get("bernini_image_latents", None) - if images is not None: - out["bernini_image_latents"] = comfy.conds.CONDRegular(self.process_latent_in(images)) - return out - - def extra_conds_shapes(self, **kwargs): - out = super().extra_conds_shapes(**kwargs) - video = kwargs.get("bernini_video_latent", None) - if video is not None: - out["bernini_video_latent"] = video.shape - images = kwargs.get("bernini_image_latents", None) - if images is not None: - out["bernini_image_latents"] = images.shape - return out - class WAN21_FlowRVS(WAN21): def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None): model_config.unet_config["model_type"] = "t2v" diff --git a/comfy_extras/nodes_bernini.py b/comfy_extras/nodes_bernini.py index 4bfb05c43952..34777f2d6bcf 100644 --- a/comfy_extras/nodes_bernini.py +++ b/comfy_extras/nodes_bernini.py @@ -1,89 +1,86 @@ import torch -import comfy.ldm.wan.model -import comfy.model_base import comfy.model_management import comfy.utils import node_helpers -def _patch_bernini(model): - """Flip a loaded Wan2.2-A14B model into Bernini-R mode. - - The Bernini checkpoint is architecturally identical to Wan2.2-A14B (no new - params), so we just swap the forward (BerniniWanModel) and the conditioning - plumbing (WAN22_Bernini) onto the already-loaded model. Idempotent. - """ - model.model.diffusion_model.__class__ = comfy.ldm.wan.model.BerniniWanModel - model.model.__class__ = comfy.model_base.WAN22_Bernini - model.model.memory_usage_factor_conds = ("bernini_video_latent", "bernini_image_latents") - return model - - -def _encode_frames(vae, image, width, height): - image = comfy.utils.common_upscale(image[:, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) - return vae.encode(image[:, :, :, :3]) +def _resize_long_edge(image, max_size, stride=16): + """Resize (preserve aspect) so the long edge <= max_size, snapped to `stride`.""" + h, w = image.shape[1], image.shape[2] + scale = min(max_size / max(h, w), 1.0) + nh = max(stride, round(h * scale / stride) * stride) + nw = max(stride, round(w * scale / stride) * stride) + return comfy.utils.common_upscale(image[:, :, :, :3].movedim(-1, 1), nw, nh, "bilinear", "disabled").movedim(1, -1) class BerniniConditioning: - """Routes Bernini-R inputs and activates Bernini mode on the model(s). - - Attaches the VAE-encoded source video / reference images to BOTH the - positive and negative conditioning so stock CFG keeps the conditions fixed - and only varies the text -- giving Bernini's v2v / rv2v guidance form. For - cfg=1.0 (distill LoRA) the same setup is a single forward with the full - conditioning. t2v attaches nothing. + """Bernini-R in-context conditioning for a Wan2.2-A14B model. + + Attaches the VAE-encoded source video / reference images to BOTH positive and + negative conditioning as ``context_latents`` -- an ordered list of clean + latent streams (source video first, then each reference image), which the Wan + model appends as extra tokens with per-stream source_id rope. With stock CFG + the conditions stay fixed and only the text varies; at cfg=1.0 (distill LoRA) + it's a single forward over the full conditioning. + + The task is inferred from which inputs are connected -- no model input and no + task selector needed; the model loads as a normal Wan2.2 checkpoint via the + stock UNETLoader: + (nothing) -> t2v + source_video -> v2v + source_video + ref images -> rv2v + ref images only -> r2v (each kept at native aspect) """ @classmethod def INPUT_TYPES(s): return { "required": { - "model": ("MODEL",), "positive": ("CONDITIONING",), "negative": ("CONDITIONING",), "vae": ("VAE",), - "task_type": (["t2v", "v2v", "rv2v"],), "width": ("INT", {"default": 832, "min": 16, "max": 8192, "step": 16}), "height": ("INT", {"default": 480, "min": 16, "max": 8192, "step": 16}), "length": ("INT", {"default": 81, "min": 1, "max": 8192, "step": 4}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), }, "optional": { - "model_low": ("MODEL",), "source_video": ("IMAGE",), "reference_images": ("IMAGE",), + "ref_max_size": ("INT", {"default": 848, "min": 16, "max": 8192, "step": 16}), }, } - RETURN_TYPES = ("MODEL", "MODEL", "CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("model", "model_low", "positive", "negative", "latent") + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") FUNCTION = "build" CATEGORY = "conditioning/video_models" - def build(self, model, positive, negative, vae, task_type, width, height, length, batch_size, - model_low=None, source_video=None, reference_images=None): - model = _patch_bernini(model) - if model_low is not None: - model_low = _patch_bernini(model_low) - + def build(self, positive, negative, vae, width, height, length, batch_size, + source_video=None, reference_images=None, ref_max_size=848): latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - values = {} - if task_type in ("v2v", "rv2v") and source_video is not None: - values["bernini_video_latent"] = _encode_frames(vae, source_video[:length], width, height) - - if task_type == "rv2v" and reference_images is not None: - # each reference image is an independent single-frame stream (its own source_id) - refs = [_encode_frames(vae, reference_images[i:i + 1], width, height) for i in range(reference_images.shape[0])] - values["bernini_image_latents"] = torch.cat(refs, dim=2) - - if values: - positive = node_helpers.conditioning_set_values(positive, values) - negative = node_helpers.conditioning_set_values(negative, values) - - return (model, model_low, positive, negative, {"samples": latent}) + # Ordered list of condition streams: source video (source_id 1) first, + # then each reference image (source_id 2, 3, ...). The model assigns the + # source_id from list order. The task (t2v/v2v/rv2v/r2v) is implied by + # which inputs are present. + context = [] + if source_video is not None: + vid = comfy.utils.common_upscale(source_video[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + context.append(vae.encode(vid[:, :, :, :3])) + + if reference_images is not None: + for i in range(reference_images.shape[0]): + img = _resize_long_edge(reference_images[i:i + 1], ref_max_size) # native aspect per ref + context.append(vae.encode(img[:, :, :, :3])) + + if context: + positive = node_helpers.conditioning_set_values(positive, {"context_latents": context}) + negative = node_helpers.conditioning_set_values(negative, {"context_latents": context}) + + return (positive, negative, {"samples": latent}) NODE_CLASS_MAPPINGS = { From 46ba987361aa656ee279d782a7d80b17c2dd2c4f Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Mon, 1 Jun 2026 21:44:57 +0300 Subject: [PATCH 03/12] Cleanup --- comfy/ldm/wan/model.py | 5 +- comfy/model_base.py | 4 +- comfy_extras/nodes_bernini.py | 98 ++++++++++++++++++----------------- 3 files changed, 52 insertions(+), 55 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 15689a4289f6..394b71d08f01 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -570,10 +570,7 @@ def forward_orig( full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2) x = torch.concat((full_ref, x), dim=1) - # In-context reference streams (e.g. Bernini source video / ref images): - # patch-embed each clean condition latent and append as extra tokens (their - # rope, with per-stream source_id, was appended to `freqs` in _forward). - # Inert when no context_latents are supplied. + # In-context reference streams (Bernini) context_latents = kwargs.get("context_latents", None) main_len = x.shape[1] if context_latents is not None: diff --git a/comfy/model_base.py b/comfy/model_base.py index 60c9055fcb84..88155b9ae752 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1517,9 +1517,7 @@ def extra_conds(self, **kwargs): if reference_latents is not None: out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0]) - # In-context reference conditioning (source video / reference images, - # e.g. Bernini): a list of clean latents appended as extra token streams - # with per-stream source_id rope. Inert when not supplied. + # In-context reference conditioning (Bernini) context_latents = kwargs.get("context_latents", None) if context_latents is not None: out['context_latents'] = comfy.conds.CONDList([self.process_latent_in(l) for l in context_latents]) diff --git a/comfy_extras/nodes_bernini.py b/comfy_extras/nodes_bernini.py index 34777f2d6bcf..aab4dbd1b0af 100644 --- a/comfy_extras/nodes_bernini.py +++ b/comfy_extras/nodes_bernini.py @@ -1,8 +1,10 @@ import torch +from typing_extensions import override import comfy.model_management import comfy.utils import node_helpers +from comfy_api.latest import ComfyExtension, io def _resize_long_edge(image, max_size, stride=16): @@ -11,22 +13,17 @@ def _resize_long_edge(image, max_size, stride=16): scale = min(max_size / max(h, w), 1.0) nh = max(stride, round(h * scale / stride) * stride) nw = max(stride, round(w * scale / stride) * stride) - return comfy.utils.common_upscale(image[:, :, :, :3].movedim(-1, 1), nw, nh, "bilinear", "disabled").movedim(1, -1) + return comfy.utils.common_upscale(image[:, :, :, :3].movedim(-1, 1), nw, nh, "area", "disabled").movedim(1, -1) -class BerniniConditioning: - """Bernini-R in-context conditioning for a Wan2.2-A14B model. +class BerniniConditioning(io.ComfyNode): + """Bernini in-context conditioning for a Wan2.2-A14B model. - Attaches the VAE-encoded source video / reference images to BOTH positive and - negative conditioning as ``context_latents`` -- an ordered list of clean - latent streams (source video first, then each reference image), which the Wan - model appends as extra tokens with per-stream source_id rope. With stock CFG - the conditions stay fixed and only the text varies; at cfg=1.0 (distill LoRA) - it's a single forward over the full conditioning. + Attaches the VAE-encoded source video / reference images to the conditioning + an ordered list of clean latents (source video first, then each reference image), + which the Wan model appends as extra tokens with per-stream source_id rope. - The task is inferred from which inputs are connected -- no model input and no - task selector needed; the model loads as a normal Wan2.2 checkpoint via the - stock UNETLoader: + The task is inferred from which inputs are connected: (nothing) -> t2v source_video -> v2v source_video + ref images -> rv2v @@ -34,41 +31,43 @@ class BerniniConditioning: """ @classmethod - def INPUT_TYPES(s): - return { - "required": { - "positive": ("CONDITIONING",), - "negative": ("CONDITIONING",), - "vae": ("VAE",), - "width": ("INT", {"default": 832, "min": 16, "max": 8192, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": 8192, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": 8192, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": { - "source_video": ("IMAGE",), - "reference_images": ("IMAGE",), - "ref_max_size": ("INT", {"default": 848, "min": 16, "max": 8192, "step": 16}), - }, - } - - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "build" - CATEGORY = "conditioning/video_models" - - def build(self, positive, negative, vae, width, height, length, batch_size, - source_video=None, reference_images=None, ref_max_size=848): + def define_schema(cls): + return io.Schema( + node_id="BerniniConditioning", + display_name="Bernini Conditioning", + category="conditioning/video_models", + description="Conditioning node for Bernini in-context video/image conditioning. Attach source video and/or reference images to the positive/negative conditioning, " + "which the Wan model will append as extra tokens with per-stream source_id rope.", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=8192, step=16), + io.Int.Input("height", default=480, min=16, max=8192, step=16), + io.Int.Input("length", default=81, min=1, max=8192, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("source_video", optional=True, tooltip="Source video to edit/restyle (original task v2v or rv2v). Resized to width/height and trimmed to length."), + io.Image.Input("reference_images", optional=True, tooltip="Reference image(s) injected as in-context tokens (task r2v or rv2v). Each is kept at its native aspect ratio, long edge capped at ref_max_size."), + io.Int.Input("ref_max_size", default=848, min=16, max=8192, step=16, optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, + source_video=None, reference_images=None, ref_max_size=848) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) # Ordered list of condition streams: source video (source_id 1) first, - # then each reference image (source_id 2, 3, ...). The model assigns the - # source_id from list order. The task (t2v/v2v/rv2v/r2v) is implied by - # which inputs are present. + # then each reference image (source_id 2, 3, ...), the model assigns the source_id from list order. context = [] if source_video is not None: - vid = comfy.utils.common_upscale(source_video[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + vid = comfy.utils.common_upscale(source_video[:length, :, :, :3].movedim(-1, 1), width, height, "area", "center").movedim(1, -1) context.append(vae.encode(vid[:, :, :, :3])) if reference_images is not None: @@ -80,13 +79,16 @@ def build(self, positive, negative, vae, width, height, length, batch_size, positive = node_helpers.conditioning_set_values(positive, {"context_latents": context}) negative = node_helpers.conditioning_set_values(negative, {"context_latents": context}) - return (positive, negative, {"samples": latent}) + return io.NodeOutput(positive, negative, {"samples": latent}) + +class BerniniExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + BerniniConditioning, + ] -NODE_CLASS_MAPPINGS = { - "BerniniConditioning": BerniniConditioning, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "BerniniConditioning": "Bernini Conditioning", -} +async def comfy_entrypoint() -> BerniniExtension: + return BerniniExtension() From f87432bafbeae8516c6b5b4e11c1d54a8a1da1fc Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 2 Jun 2026 00:17:43 +0300 Subject: [PATCH 04/12] Update nodes_bernini.py --- comfy_extras/nodes_bernini.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/comfy_extras/nodes_bernini.py b/comfy_extras/nodes_bernini.py index aab4dbd1b0af..c29af857ebd7 100644 --- a/comfy_extras/nodes_bernini.py +++ b/comfy_extras/nodes_bernini.py @@ -24,10 +24,17 @@ class BerniniConditioning(io.ComfyNode): which the Wan model appends as extra tokens with per-stream source_id rope. The task is inferred from which inputs are connected: - (nothing) -> t2v - source_video -> v2v - source_video + ref images -> rv2v - ref images only -> r2v (each kept at native aspect) + (nothing) -> t2v + source_video -> v2v + source_video + ref images -> rv2v + ref images only -> r2v (each kept at native aspect) + source_video + ref_video -> video insertion / "ads2v" + + source_video is the edit base / canvas (resized to width x height). + reference_video is moving content to composite in (e.g. a clip to play on a + screen), kept at its native aspect like the reference images. Streams are + ordered source_video, reference_video, then reference_images -> source_id + 1, 2, 3... matching the reference repo's [base, content, refs]. """ @classmethod @@ -46,7 +53,8 @@ def define_schema(cls): io.Int.Input("height", default=480, min=16, max=8192, step=16), io.Int.Input("length", default=81, min=1, max=8192, step=4), io.Int.Input("batch_size", default=1, min=1, max=4096), - io.Image.Input("source_video", optional=True, tooltip="Source video to edit/restyle (original task v2v or rv2v). Resized to width/height and trimmed to length."), + io.Image.Input("source_video", optional=True, tooltip="Source video to edit/restyle (task v2v or rv2v). Resized to width/height and trimmed to length. Acts as the edit base / canvas."), + io.Image.Input("reference_video", optional=True, tooltip="Moving content to composite into the source video (video insertion / ads2v), e.g. a clip to play on a screen. Kept at native aspect (long edge capped at ref_max_size), trimmed to length."), io.Image.Input("reference_images", optional=True, tooltip="Reference image(s) injected as in-context tokens (task r2v or rv2v). Each is kept at its native aspect ratio, long edge capped at ref_max_size."), io.Int.Input("ref_max_size", default=848, min=16, max=8192, step=16, optional=True), ], @@ -59,17 +67,21 @@ def define_schema(cls): @classmethod def execute(cls, positive, negative, vae, width, height, length, batch_size, - source_video=None, reference_images=None, ref_max_size=848) -> io.NodeOutput: + source_video=None, reference_video=None, reference_images=None, ref_max_size=848) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - # Ordered list of condition streams: source video (source_id 1) first, - # then each reference image (source_id 2, 3, ...), the model assigns the source_id from list order. + # Ordered list of condition streams -> source_id by list order: + # source_video (1), reference_video (2), reference_images (3, 4, ...). context = [] if source_video is not None: vid = comfy.utils.common_upscale(source_video[:length, :, :, :3].movedim(-1, 1), width, height, "area", "center").movedim(1, -1) context.append(vae.encode(vid[:, :, :, :3])) + if reference_video is not None: + ref_vid = _resize_long_edge(reference_video[:length], ref_max_size) # moving content, native aspect + context.append(vae.encode(ref_vid[:, :, :, :3])) + if reference_images is not None: for i in range(reference_images.shape[0]): img = _resize_long_edge(reference_images[i:i + 1], ref_max_size) # native aspect per ref From 2c7d2561af472f40879d686a67273fe6c5b62e6f Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 2 Jun 2026 01:44:44 +0300 Subject: [PATCH 05/12] Maybe fix context windows for v2v --- comfy/model_base.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/comfy/model_base.py b/comfy/model_base.py index 88155b9ae752..f5224a840d27 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1524,6 +1524,20 @@ def extra_conds(self, **kwargs): return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + # In-context streams slicing (Bernini) + if cond_key == "context_latents" and isinstance(getattr(cond_value, "cond", None), list): + dim = window.dim + out = [] + for lat in cond_value.cond: + if lat.ndim > dim and lat.shape[dim] > 1 and lat.shape[dim] == x_in.shape[dim]: + idx = tuple([slice(None)] * dim + [window.index_list]) + out.append(lat[idx].to(device)) + else: + out.append(lat.to(device)) + return cond_value._copy_with(out) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN21_CausalAR(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): From f3d0b070f2f9926eefe81bc5623c3ce20510d8fc Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 2 Jun 2026 10:35:58 +0300 Subject: [PATCH 06/12] Cleanup --- comfy/ldm/wan/model.py | 10 +++------- comfy/model_base.py | 2 +- comfy_extras/nodes_bernini.py | 15 ++++++++++----- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 394b71d08f01..8e2116a6cbed 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -570,7 +570,7 @@ def forward_orig( full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2) x = torch.concat((full_ref, x), dim=1) - # In-context reference streams (Bernini) + # In-context reference (Bernini) context_latents = kwargs.get("context_latents", None) main_len = x.shape[1] if context_latents is not None: @@ -650,10 +650,7 @@ def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=No freqs = self.rope_embedder(img_ids).movedim(1, 2) - # In-context reference conditioning (e.g. Bernini): a non-zero source_id - # composes an extra rotation (over the full head_dim) into the spatial - # rope so streams sharing the same spatial coords stay distinct. source_id - # 0 is identity, so this is a no-op for all normal Wan usage. + # In-context reference: a non-zero source_id composes an extra rotation into the spatial rope if source_id: d = self.dim // self.num_heads pos = torch.tensor([[float(source_id)]], device=freqs.device, dtype=torch.float32) @@ -683,8 +680,7 @@ def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, tr freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options) - # In-context reference streams: one rope block per stream, each with its - # own source_id (1, 2, ...) so they stay distinct from the target (id 0). + # In-context reference: one rope block per stream, each with it's own source_id (1, 2, ...) to distinguish from the target (id 0). context_latents = kwargs.get("context_latents", None) if context_latents is not None: context_latents = [comfy.ldm.common_dit.pad_to_patch_size(lat, self.patch_size) for lat in context_latents] diff --git a/comfy/model_base.py b/comfy/model_base.py index f5224a840d27..3742062ce4f5 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1525,7 +1525,7 @@ def extra_conds(self, **kwargs): return out def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): - # In-context streams slicing (Bernini) + # In-context cond slicing (Bernini) if cond_key == "context_latents" and isinstance(getattr(cond_value, "cond", None), list): dim = window.dim out = [] diff --git a/comfy_extras/nodes_bernini.py b/comfy_extras/nodes_bernini.py index c29af857ebd7..4de3460bb442 100644 --- a/comfy_extras/nodes_bernini.py +++ b/comfy_extras/nodes_bernini.py @@ -53,10 +53,15 @@ def define_schema(cls): io.Int.Input("height", default=480, min=16, max=8192, step=16), io.Int.Input("length", default=81, min=1, max=8192, step=4), io.Int.Input("batch_size", default=1, min=1, max=4096), - io.Image.Input("source_video", optional=True, tooltip="Source video to edit/restyle (task v2v or rv2v). Resized to width/height and trimmed to length. Acts as the edit base / canvas."), - io.Image.Input("reference_video", optional=True, tooltip="Moving content to composite into the source video (video insertion / ads2v), e.g. a clip to play on a screen. Kept at native aspect (long edge capped at ref_max_size), trimmed to length."), - io.Image.Input("reference_images", optional=True, tooltip="Reference image(s) injected as in-context tokens (task r2v or rv2v). Each is kept at its native aspect ratio, long edge capped at ref_max_size."), - io.Int.Input("ref_max_size", default=848, min=16, max=8192, step=16, optional=True), + io.Image.Input("source_video", optional=True, tooltip=( + "Source video to edit/restyle (task v2v or rv2v). Resized to width/height and trimmed to length. Acts as the edit base / canvas.")), + io.Image.Input("reference_video", optional=True, tooltip=( + "Moving content to composite into the source video (video insertion / ads2v)," + "e.g. a clip to play on a screen. Kept at native aspect (long edge capped at ref_max_size), trimmed to length.")), + io.Image.Input("reference_images", optional=True, tooltip=( + "Reference image(s) injected as in-context tokens (task r2v or rv2v). Each is kept at its native aspect ratio, long edge capped at ref_max_size.")), + io.Int.Input("ref_max_size", default=848, min=16, max=8192, step=16, optional=True, tooltip=( + "Max size for the long edge of reference_video and reference_images. Resized with preserved aspect ratio, snapped to 16px, and no upscaling.")), ], outputs=[ io.Conditioning.Output(display_name="positive"), @@ -72,7 +77,7 @@ def execute(cls, positive, negative, vae, width, height, length, batch_size, device=comfy.model_management.intermediate_device()) # Ordered list of condition streams -> source_id by list order: - # source_video (1), reference_video (2), reference_images (3, 4, ...). + # source_video (1), reference_video (2), reference_images (3, 4, ...). context = [] if source_video is not None: vid = comfy.utils.common_upscale(source_video[:length, :, :, :3].movedim(-1, 1), width, height, "area", "center").movedim(1, -1) From d9a28a9b3c08339d12fec77a704d45b62dbd7516 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 2 Jun 2026 18:52:46 +0300 Subject: [PATCH 07/12] Adjust context window --- comfy/model_base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 3742062ce4f5..36b3f0fc62cc 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1531,8 +1531,7 @@ def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, dev out = [] for lat in cond_value.cond: if lat.ndim > dim and lat.shape[dim] > 1 and lat.shape[dim] == x_in.shape[dim]: - idx = tuple([slice(None)] * dim + [window.index_list]) - out.append(lat[idx].to(device)) + out.append(window.get_tensor(lat, device, dim=dim, retain_index_list=retain_index_list)) else: out.append(lat.to(device)) return cond_value._copy_with(out) From 471a20ac8d5da7a2be2a64fef62a93b4b59efa8e Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 2 Jun 2026 19:10:42 +0300 Subject: [PATCH 08/12] Use separate reference image inputs instead Since sizes don't have to match --- comfy_extras/nodes_bernini.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/comfy_extras/nodes_bernini.py b/comfy_extras/nodes_bernini.py index 4de3460bb442..3eb95c10247e 100644 --- a/comfy_extras/nodes_bernini.py +++ b/comfy_extras/nodes_bernini.py @@ -8,7 +8,8 @@ def _resize_long_edge(image, max_size, stride=16): - """Resize (preserve aspect) so the long edge <= max_size, snapped to `stride`.""" + """Resize (preserve aspect) so the long edge <= max_size, then snap each side to `stride` + (snapping can nudge a side up/down by < stride, so it never scales up by more than that).""" h, w = image.shape[1], image.shape[2] scale = min(max_size / max(h, w), 1.0) nh = max(stride, round(h * scale / stride) * stride) @@ -58,10 +59,17 @@ def define_schema(cls): io.Image.Input("reference_video", optional=True, tooltip=( "Moving content to composite into the source video (video insertion / ads2v)," "e.g. a clip to play on a screen. Kept at native aspect (long edge capped at ref_max_size), trimmed to length.")), - io.Image.Input("reference_images", optional=True, tooltip=( - "Reference image(s) injected as in-context tokens (task r2v or rv2v). Each is kept at its native aspect ratio, long edge capped at ref_max_size.")), + io.Autogrow.Input("reference_images", optional=True, + template=io.Autogrow.TemplatePrefix( + input=io.Image.Input("reference_image", tooltip=( + "A reference image injected as an in-context token (task r2v or rv2v).")), + prefix="reference_image_", min=0, max=8), + tooltip=( + "Reference image(s) injected as in-context tokens (task r2v or rv2v). Each slot is " + "encoded independently at its own native aspect ratio (long edge capped at " + "ref_max_size), so connect mixed-size references to separate slots.")), io.Int.Input("ref_max_size", default=848, min=16, max=8192, step=16, optional=True, tooltip=( - "Max size for the long edge of reference_video and reference_images. Resized with preserved aspect ratio, snapped to 16px, and no upscaling.")), + "Max size for the long edge of reference_video and reference_images. Resized with preserved aspect ratio and snapped to 16px (snapping may nudge a side by <16px).")), ], outputs=[ io.Conditioning.Output(display_name="positive"), @@ -87,10 +95,16 @@ def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_vid = _resize_long_edge(reference_video[:length], ref_max_size) # moving content, native aspect context.append(vae.encode(ref_vid[:, :, :, :3])) - if reference_images is not None: - for i in range(reference_images.shape[0]): - img = _resize_long_edge(reference_images[i:i + 1], ref_max_size) # native aspect per ref - context.append(vae.encode(img[:, :, :, :3])) + # reference_images is an autogrow dict {reference_image_0: IMAGE, ...}; each slot is a + # separate stream at its own native aspect (a multi-image batch in one slot -> one stream per frame). + if reference_images: + for name in sorted(reference_images): + imgs = reference_images[name] + if imgs is None: + continue + for i in range(imgs.shape[0]): + img = _resize_long_edge(imgs[i:i + 1], ref_max_size) # native aspect per ref + context.append(vae.encode(img[:, :, :, :3])) if context: positive = node_helpers.conditioning_set_values(positive, {"context_latents": context}) From 4a6119ba9cffe0b2b0c988eb905dc4f1c4d93b7c Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 10 Jun 2026 02:25:42 +0300 Subject: [PATCH 09/12] Adjust docstrins and tooltips --- comfy_extras/nodes_bernini.py | 36 ++++++++++++++--------------------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/comfy_extras/nodes_bernini.py b/comfy_extras/nodes_bernini.py index 3eb95c10247e..164da476bb73 100644 --- a/comfy_extras/nodes_bernini.py +++ b/comfy_extras/nodes_bernini.py @@ -8,8 +8,7 @@ def _resize_long_edge(image, max_size, stride=16): - """Resize (preserve aspect) so the long edge <= max_size, then snap each side to `stride` - (snapping can nudge a side up/down by < stride, so it never scales up by more than that).""" + """Resize (preserve aspect) so the long edge <= max_size, then snap each side to `stride`""" h, w = image.shape[1], image.shape[2] scale = min(max_size / max(h, w), 1.0) nh = max(stride, round(h * scale / stride) * stride) @@ -21,21 +20,18 @@ class BerniniConditioning(io.ComfyNode): """Bernini in-context conditioning for a Wan2.2-A14B model. Attaches the VAE-encoded source video / reference images to the conditioning - an ordered list of clean latents (source video first, then each reference image), - which the Wan model appends as extra tokens with per-stream source_id rope. + source video first, then each reference image The task is inferred from which inputs are connected: - (nothing) -> t2v - source_video -> v2v - source_video + ref images -> rv2v - ref images only -> r2v (each kept at native aspect) - source_video + ref_video -> video insertion / "ads2v" + (nothing) -> t2v (text-to-video) + source_video -> v2v (video-to-video) + source_video + ref_images -> rv2v (reference-guided video editing) + ref_images only -> r2v (reference-to-video) + source_video + ref_video -> ads2v (insert image/video into video) source_video is the edit base / canvas (resized to width x height). - reference_video is moving content to composite in (e.g. a clip to play on a - screen), kept at its native aspect like the reference images. Streams are - ordered source_video, reference_video, then reference_images -> source_id - 1, 2, 3... matching the reference repo's [base, content, refs]. + reference_video is moving content to composite in. + Streams are ordered source_video, reference_video, then reference_images -> source_id (1, 2, 3, ...). """ @classmethod @@ -44,8 +40,7 @@ def define_schema(cls): node_id="BerniniConditioning", display_name="Bernini Conditioning", category="conditioning/video_models", - description="Conditioning node for Bernini in-context video/image conditioning. Attach source video and/or reference images to the positive/negative conditioning, " - "which the Wan model will append as extra tokens with per-stream source_id rope.", + description="Conditioning node for Bernini in-context video/image conditioning. It can be used for the following tasks: t2v (text-to-video), v2v (video-to-video), rv2v (reference-guided video editing), r2v (reference-to-video), ads2v (insert image/video into video).", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -55,10 +50,9 @@ def define_schema(cls): io.Int.Input("length", default=81, min=1, max=8192, step=4), io.Int.Input("batch_size", default=1, min=1, max=4096), io.Image.Input("source_video", optional=True, tooltip=( - "Source video to edit/restyle (task v2v or rv2v). Resized to width/height and trimmed to length. Acts as the edit base / canvas.")), + "Source video to edit or restyle (v2v, rv2v). Resized to width/height and trimmed to length.")), io.Image.Input("reference_video", optional=True, tooltip=( - "Moving content to composite into the source video (video insertion / ads2v)," - "e.g. a clip to play on a screen. Kept at native aspect (long edge capped at ref_max_size), trimmed to length.")), + "Video to insert into the source video (ads2v).")), io.Autogrow.Input("reference_images", optional=True, template=io.Autogrow.TemplatePrefix( input=io.Image.Input("reference_image", tooltip=( @@ -66,10 +60,9 @@ def define_schema(cls): prefix="reference_image_", min=0, max=8), tooltip=( "Reference image(s) injected as in-context tokens (task r2v or rv2v). Each slot is " - "encoded independently at its own native aspect ratio (long edge capped at " - "ref_max_size), so connect mixed-size references to separate slots.")), + "encoded independently at its own native aspect ratio (long edge capped at ref_max_size)")), io.Int.Input("ref_max_size", default=848, min=16, max=8192, step=16, optional=True, tooltip=( - "Max size for the long edge of reference_video and reference_images. Resized with preserved aspect ratio and snapped to 16px (snapping may nudge a side by <16px).")), + "Max size for the long edge of reference_video and reference_images. Resized with preserved aspect ratio and snapped to 16px.")), ], outputs=[ io.Conditioning.Output(display_name="positive"), @@ -84,7 +77,6 @@ def execute(cls, positive, negative, vae, width, height, length, batch_size, latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - # Ordered list of condition streams -> source_id by list order: # source_video (1), reference_video (2), reference_images (3, 4, ...). context = [] if source_video is not None: From 04752b82e25ac60f863311df6a0978d38b98dfaf Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 10 Jun 2026 02:32:35 +0300 Subject: [PATCH 10/12] Update nodes_bernini.py --- comfy_extras/nodes_bernini.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/comfy_extras/nodes_bernini.py b/comfy_extras/nodes_bernini.py index 164da476bb73..628e5ee986c3 100644 --- a/comfy_extras/nodes_bernini.py +++ b/comfy_extras/nodes_bernini.py @@ -40,7 +40,8 @@ def define_schema(cls): node_id="BerniniConditioning", display_name="Bernini Conditioning", category="conditioning/video_models", - description="Conditioning node for Bernini in-context video/image conditioning. It can be used for the following tasks: t2v (text-to-video), v2v (video-to-video), rv2v (reference-guided video editing), r2v (reference-to-video), ads2v (insert image/video into video).", + description="Conditioning node for Bernini in-context video/image conditioning. It can be used for the following tasks: t2v (text-to-video), v2v (video-to-video), rv2v (reference-guided video editing), r2v (reference-to-video), ads2v (insert image/video into video)." + "Reference image(s) injected as in-context tokens (task r2v or rv2v), encoded independently at its own native aspect ratio (long edge capped at ref_max_size)", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -57,10 +58,7 @@ def define_schema(cls): template=io.Autogrow.TemplatePrefix( input=io.Image.Input("reference_image", tooltip=( "A reference image injected as an in-context token (task r2v or rv2v).")), - prefix="reference_image_", min=0, max=8), - tooltip=( - "Reference image(s) injected as in-context tokens (task r2v or rv2v). Each slot is " - "encoded independently at its own native aspect ratio (long edge capped at ref_max_size)")), + prefix="reference_image_", min=0, max=8)), io.Int.Input("ref_max_size", default=848, min=16, max=8192, step=16, optional=True, tooltip=( "Max size for the long edge of reference_video and reference_images. Resized with preserved aspect ratio and snapped to 16px.")), ], From 01bb7f49ec653ea3eb015a98937f06240b091a7b Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Wed, 10 Jun 2026 07:41:26 +0800 Subject: [PATCH 11/12] Apply suggestions from code review Co-authored-by: Alexis Rolland --- comfy_extras/nodes_bernini.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_bernini.py b/comfy_extras/nodes_bernini.py index 628e5ee986c3..0cc21612ccd3 100644 --- a/comfy_extras/nodes_bernini.py +++ b/comfy_extras/nodes_bernini.py @@ -41,7 +41,7 @@ def define_schema(cls): display_name="Bernini Conditioning", category="conditioning/video_models", description="Conditioning node for Bernini in-context video/image conditioning. It can be used for the following tasks: t2v (text-to-video), v2v (video-to-video), rv2v (reference-guided video editing), r2v (reference-to-video), ads2v (insert image/video into video)." - "Reference image(s) injected as in-context tokens (task r2v or rv2v), encoded independently at its own native aspect ratio (long edge capped at ref_max_size)", + "Reference images injected as in-context tokens (r2v, rv2v) are encoded independently at their own native aspect ratio (long edge capped at ref_max_size)", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), From e7a6411aa82f61b4ffb4a4122ed030c430e1cba5 Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Wed, 10 Jun 2026 07:41:51 +0800 Subject: [PATCH 12/12] Update comfy_extras/nodes_bernini.py --- comfy_extras/nodes_bernini.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_bernini.py b/comfy_extras/nodes_bernini.py index 0cc21612ccd3..227fa5753bdd 100644 --- a/comfy_extras/nodes_bernini.py +++ b/comfy_extras/nodes_bernini.py @@ -57,7 +57,7 @@ def define_schema(cls): io.Autogrow.Input("reference_images", optional=True, template=io.Autogrow.TemplatePrefix( input=io.Image.Input("reference_image", tooltip=( - "A reference image injected as an in-context token (task r2v or rv2v).")), + "Reference image injected as an in-context token (r2v, rv2v).")), prefix="reference_image_", min=0, max=8)), io.Int.Input("ref_max_size", default=848, min=16, max=8192, step=16, optional=True, tooltip=( "Max size for the long edge of reference_video and reference_images. Resized with preserved aspect ratio and snapped to 16px.")),