Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions openfold/model/embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,7 @@ def forward(self,
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
lambda t: torch.index_select(t, templ_dim, idx).squeeze(templ_dim),
batch,
)

Expand Down Expand Up @@ -944,7 +944,9 @@ def forward(self,
single_template_feats["template_aatype"],
)
points = rigid.translation
rigid_vec = rigid[..., None].inverse().apply_to_point(points)
rigid_vec = rigid[..., None].inverse().apply_to_point(
points[..., None, :]
)
unit_vector = rigid_vec.normalized()

pair_act = self.template_pair_embedder(
Expand All @@ -968,7 +970,7 @@ def forward(self,
template_embeds.append(single_template_embeds)

template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim),
partial(torch.stack, dim=templ_dim),
template_embeds,
)

Expand Down
97 changes: 97 additions & 0 deletions tests/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
import torch
import numpy as np
import unittest
from openfold.config import model_config
from openfold.model.embedders import TemplateEmbedderMultimer
from openfold.model.template import (
TemplatePointwiseAttention,
TemplatePairStack,
)
from openfold.np import residue_constants
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_template_feats
Expand Down Expand Up @@ -194,6 +197,100 @@ def template_iteration_fn(x):
compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)


class TestTemplateEmbedderMultimer(unittest.TestCase):
def test_batched_template_embedding_shape(self):
torch.manual_seed(0)
batch_size = 2
n_templ = 2
n_res = 8

config = model_config("model_1_multimer_v3")
config.model.template.template_pair_stack.no_blocks = 1
config.model.template.template_pair_stack.blocks_per_ckpt = None

embedder = TemplateEmbedderMultimer(config.model.template)
embedder.eval()

aatype = torch.full(
(batch_size, n_templ, n_res),
residue_constants.restype_order["A"],
dtype=torch.long,
)
atom_positions = torch.zeros(
batch_size,
n_templ,
n_res,
residue_constants.atom_type_num,
3,
)
atom_mask = torch.zeros(
batch_size,
n_templ,
n_res,
residue_constants.atom_type_num,
)

n_idx = residue_constants.atom_order["N"]
ca_idx = residue_constants.atom_order["CA"]
c_idx = residue_constants.atom_order["C"]
cb_idx = residue_constants.atom_order["CB"]
residue_positions = torch.arange(n_res, dtype=torch.float32)

atom_positions[..., n_idx, 0] = residue_positions
atom_positions[..., ca_idx, 0] = residue_positions
atom_positions[..., ca_idx, 1] = 1.0
atom_positions[..., c_idx, 0] = residue_positions
atom_positions[..., c_idx, 2] = 1.0
atom_positions[..., cb_idx, 0] = residue_positions
atom_positions[..., cb_idx, 1] = 1.0
atom_positions[..., cb_idx, 2] = 1.0
atom_mask[..., [n_idx, ca_idx, c_idx, cb_idx]] = 1.0

batch = {
"template_aatype": aatype,
"template_all_atom_positions": atom_positions,
"template_all_atom_mask": atom_mask,
}
z = torch.randn(batch_size, n_res, n_res, config.globals.c_z)
padding_mask_2d = torch.ones(batch_size, n_res, n_res)
asym_id = torch.tensor(
[[1] * (n_res // 2) + [2] * (n_res - n_res // 2)]
* batch_size
)
multichain_mask_2d = asym_id[..., None] == asym_id[..., None, :]

with torch.no_grad():
out = embedder(
batch,
z,
padding_mask_2d,
templ_dim=1,
chunk_size=None,
multichain_mask_2d=multichain_mask_2d,
use_lma=False,
inplace_safe=False,
)

self.assertEqual(
out["template_pair_embedding"].shape,
(batch_size, n_res, n_res, config.globals.c_z),
)
self.assertEqual(
out["template_single_embedding"].shape,
(batch_size, n_templ, n_res, config.globals.c_m),
)
self.assertEqual(
out["template_mask"].shape,
(batch_size, n_templ, n_res),
)
self.assertTrue(
torch.isfinite(out["template_pair_embedding"]).all().item()
)
self.assertTrue(
torch.isfinite(out["template_single_embedding"]).all().item()
)


class Template(unittest.TestCase):
@classmethod
def setUpClass(cls):
Expand Down