
fix: match rope batch dim with input#802

Open
ssitu wants to merge 2 commits into nunchaku-ai:dev from ssitu:fix/z_image_fuse_rope_batch

Conversation


@ssitu ssitu commented Feb 3, 2026

Motivation

When running z-image-turbo through nunchaku with a batch size > 1, the following occurred:

```
cache miss and created, cache_key: (13215415296, torch.Size([1, 224, 1, 64, 2, 2])), orig shape: torch.Size([1, 224, 1, 64, 2, 2]), packed shape: torch.Size([1, 256, 128])
x shape: torch.Size([2, 224, 3840]), freqs_cis shape: torch.Size([1, 256, 128])
Assertion failed: rotary_emb.shape[0] * rotary_emb.shape[1] == M, file C:\Users\nunchaku\Desktop\actions-runner\_work\nunchaku\nunchaku\src\kernels\zgemm\gemm_w4a4_launch_impl.cuh, line 353
```
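My reading of the failed kernel check (an assumption, not verified against the CUDA source) is that the packed rotary table must supply one row per padded input row, i.e. `rotary_emb.shape[0] * rotary_emb.shape[1]` must equal `M = batch * padded_seq_len`. A minimal sketch of the mismatch using the shapes from the log:

```python
import torch

# Shapes copied from the log above.
x = torch.zeros(2, 224, 3840)          # input at batch size 2
freqs_cis = torch.zeros(1, 256, 128)   # packed rotary table cached for batch size 1

# Assumed meaning of the kernel check: the packed table must cover
# batch * padded-sequence-length rows of the input.
M = x.shape[0] * 256                            # 512 rows expected
rows = freqs_cis.shape[0] * freqs_cis.shape[1]  # 256 rows provided
print(rows == M)  # False -> the assertion fires for batch > 1
```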

Otherwise, sampling single images is completely functional for me. The problem started after commit c43921f.

Environment:

```
pytorch version: 2.10.0+cu130
Set vram state to: LOW_VRAM
Using pytorch attention
Python version: 3.13.2 (tags/v3.13.2:4f8bb39, Feb  4 2025, 15:23:48) [MSC v.1942 64 bit (AMD64)]
ComfyUI version: 0.11.1
Nunchaku version: nunchaku-1.2.1+cu13.0torch2.10-cp313-cp313-win_amd64.whl
ComfyUI-nunchaku version: 1.2.1
platform is win32
```
This PR prevents the assertion error and makes sampling batches of images work again while still using the fused operation.

Should fix #778, which uses a batch size of 8.
Possibly related to #774, which discusses the same assertion error.

Modifications

In RopeFuseAttentionHook, when a cache miss occurs and a packed freqs_cis must be created, freqs_cis is now expanded to match the batch size of the input and its batch dimension is flattened into the sequence dimension before continuing on to be padded and packed:

```diff
  freqs_cis = freqs_cis[..., [1], :].squeeze(2)  # See comfy.ldm.flux.math#rope, #apply_rope
+ freqs_cis = freqs_cis.expand(x.shape[0], -1, -1, -1, -1)  # [b, s, 64, 1, 2]
+ freqs_cis = freqs_cis.flatten(0, 1)  # [b*s, 64, 1, 2]
+ freqs_cis = freqs_cis.unsqueeze(0)  # [1, b*s, 64, 1, 2]
  packed_freqs_cis = pack_rotemb(pad_tensor(freqs_cis, 256, 1))
  self.packed_freqs_cis_cache[cache_key] = packed_freqs_cis
```
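The shape progression of the added lines can be checked in isolation. This sketch assumes the `[1, s, 64, 1, 2]` layout from the comments above, with batch and sequence sizes taken from the log:

```python
import torch

b, s = 2, 224                              # batch and sequence length from the log
freqs_cis = torch.randn(1, s, 64, 1, 2)    # per-position table with a batch dim of 1

out = freqs_cis.expand(b, -1, -1, -1, -1)  # [b, s, 64, 1, 2], still a broadcast view
out = out.flatten(0, 1)                    # [b*s, 64, 1, 2], batch folded into seq
out = out.unsqueeze(0)                     # [1, b*s, 64, 1, 2], ready for pad/pack
print(out.shape)  # torch.Size([1, 448, 64, 1, 2])

# Every batch element gets the same rotary values, just repeated.
print(torch.equal(out[0, :s], out[0, s:]))  # True
```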

If we use expand by itself to match the batch size, the assert passes, but every output image in the batch after the first does not match the reference workflow and is very low quality. I do not know how to avoid copying data here.
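One plausible explanation (my assumption, not verified against pack_rotemb) is that expand only creates a stride-0 view, so any downstream code that walks the tensor's memory linearly sees one batch's worth of data; flatten(0, 1) cannot be expressed as a view of that layout, so PyTorch materializes a real copy:

```python
import torch

t = torch.randn(1, 4, 8)
e = t.expand(2, -1, -1)   # batch dim is broadcast: stride 0, no data copied
print(e.stride()[0])      # 0 -> both "batches" alias the same underlying memory

f = e.flatten(0, 1)       # cannot be a view across a stride-0 dim, so this copies
print(f.is_contiguous())             # True
print(f.data_ptr() == t.data_ptr())  # False: flatten allocated new storage
```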

A test case was added, adapted from the regular z-image-turbo workflow, with a batch size of 2 and a node that selects the second image in the batch.

Checklist

  • Code is formatted using Pre-Commit hooks (run pre-commit run --all-files).
  • Relevant unit tests are added in the tests/workflows directory following the guidance in the Contribution Guide.
  • Reference images are uploaded to PR comments and URLs are added to test_cases.json.
  • Additional test data (if needed) is registered in test_data/inputs.yaml.
  • Additional models (if needed) are registered in scripts/download_models.py and test_data/models.yaml.
  • Additional custom nodes (if needed) are added to .github/workflows/pr-test.yaml.
  • For reviewers: If you're only helping merge the main branch and haven't contributed code to this PR, please remove yourself as a co-author when merging.
  • Please feel free to join our Discord or WeChat to discuss your PR.

@ssitu ssitu changed the base branch from main to dev February 3, 2026 01:02

ssitu commented Feb 3, 2026

Reference image: z-image-batch_00001_
SVDQ image: z-image-batch_00015_

@ssitu ssitu marked this pull request as ready for review February 3, 2026 04:39
