Fix batched multimer template embedding shapes #575

Open
taivu1998 wants to merge 1 commit into aqlaboratory:main from taivu1998:tdv/issue-513-batch-template-shapes

Summary

Fixes #513.

This PR fixes batched multimer template embedding when batch_size > 1. The previous multimer template loop selected one template with index_select but kept the resulting singleton template axis. With batched inputs, that left tensors shaped like [B, 1, N, ...], so when they were multiplied by the [B, N, N] multichain mask, PyTorch's right-aligned broadcasting lined up the leftover template axis with the mask's batch axis.

Root Cause

The monomer template embedder already removes the selected template axis and restores it later with torch.stack. The multimer template embedder did not, so TemplatePairEmbedderMultimer received single-template features with an extra dimension. That directly caused the reported broadcast error:

output with shape [2, 1, 200, 200] doesn't match the broadcast shape [2, 2, 200, 200]
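
The shapes in that message follow directly from right-aligned broadcasting. The sketch below reimplements the broadcasting rule in plain Python to show where [2, 2, 200, 200] comes from; `broadcast_shape` is a hypothetical helper written for illustration, not an OpenFold or PyTorch function. PyTorch raises this particular message when an in-place operation's output would have to grow to the broadcast shape, so the failing multiply is presumably in-place.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Broadcast two shapes the way PyTorch/NumPy do: align axes from the right."""
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"cannot broadcast {a} with {b}")
        out.append(y if x == 1 else x)
    return tuple(reversed(out))


B, N = 2, 200
mask = (B, N, N)  # multichain mask

# Singleton template axis kept: it pairs with the mask's batch axis,
# and the tensor's own batch axis pairs with an implicit leading 1.
kept = broadcast_shape((B, 1, N, N), mask)

# Template axis squeezed away: batch axes line up as intended.
squeezed = broadcast_shape((B, N, N), mask)

print(kept)      # (2, 2, 200, 200)
print(squeezed)  # (2, 200, 200)
```

With batch_size = 1 both cases give a shape that still fits the output tensor, which is consistent with the bug only surfacing for batch_size > 1.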

Once the selected template axis is removed correctly, the pairwise unit-vector calculation also needs to make its point axis explicit (points[..., None, :]) so batched rigid frames broadcast over residue pairs as intended.
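
As a rough illustration of why the explicit axis matters, the sketch below uses a plain coordinate difference as a stand-in for the rigid-frame transform. It assumes each point component is stored as an array of shape [B, N] (one value per residue); that layout, the sizes, and the variable names are assumptions made for this example and are not taken from the PR diff.

```python
import torch

B, N = 2, 5              # hypothetical batch size and residue count
x = torch.randn(B, N)    # one coordinate component per residue (assumed layout)

frame = x[..., None]     # [B, N, 1]: the "frame" residue axis
point = x[..., None, :]  # [B, 1, N]: the point axis made explicit

# Pairwise result over residues: pair[b, i, j] = x[b, j] - x[b, i]
pair = point - frame     # [B, N, N]

# Unbatched, the implicit form happens to work: [N] against [N, 1] gives [N, N].
unbatched = x[0]
implicit = unbatched - unbatched[..., None]

# Batched, the implicit form right-aligns [B, N] against [B, N, 1],
# which pits N against B and fails whenever they are incompatible.
try:
    x - frame
    implicit_batched_ok = True
except RuntimeError:
    implicit_batched_ok = False

print(pair.shape, implicit.shape, implicit_batched_ok)
```

The unbatched case shows why the missing axis stayed hidden: with no batch dimension, the implicit and explicit forms produce the same [N, N] result.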

Changes

  • Squeeze templ_dim after selecting one multimer template.
  • Stack per-template multimer outputs back on templ_dim instead of concatenating tensors that no longer contain that axis.
  • Make the template unit-vector point axis explicit for batched inputs.
  • Add a focused batch_size=2, n_templ=2 regression test for TemplateEmbedderMultimer.
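
The first two changes amount to a select, squeeze, stack pattern. Here is a minimal sketch of that pattern on a toy tensor; the feature shape [B, T, N, N, C], the value of `templ_dim`, and the all-ones mask are assumptions for illustration, and the loop body stands in for the real per-template embedding.

```python
import torch

B, T, N, C = 2, 2, 4, 3  # hypothetical batch, template, residue, channel sizes
templ_dim = 1            # assumed position of the template axis
feats = torch.randn(B, T, N, N, C)
mask = torch.ones(B, N, N)  # stand-in for the multichain mask

outs = []
for i in range(T):
    idx = torch.tensor([i])
    single = torch.index_select(feats, templ_dim, idx)  # [B, 1, N, N, C]
    single = single.squeeze(templ_dim)                  # [B, N, N, C]
    single = single * mask[..., None]                   # batch axes now align
    outs.append(single)

# The per-template outputs no longer carry the template axis,
# so torch.stack (not torch.cat) is what restores it.
stacked = torch.stack(outs, dim=templ_dim)              # [B, T, N, N, C]
print(stacked.shape)
```

Using torch.cat on `outs` along `templ_dim` would instead merge the templates into the first residue axis, giving [B, T * N, N, C].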

Validation

  • python -m py_compile openfold/model/embedders.py tests/test_template.py
  • git diff --check
  • pytest tests/test_template.py -k batched_template_embedding_shape
  • pytest tests/test_template.py

Local note: the tests were run in a temporary validation environment because the base Python on this macOS host was missing several declared OpenFold dependencies. Temporary import shims were used only for unavailable local infrastructure modules (tree and the eager attn_core_inplace_cuda import); the focused multimer template path itself executed and passed.

@taivu1998 taivu1998 marked this pull request as ready for review May 11, 2026 03:44
Development

Successfully merging this pull request may close these issues.

Shape Mismatch with batch_size=2
