diff --git a/openfold/data/data_pipeline.py b/openfold/data/data_pipeline.py index 393c1cef3..0273b3de8 100644 --- a/openfold/data/data_pipeline.py +++ b/openfold/data/data_pipeline.py @@ -1327,12 +1327,14 @@ def process_mmcif( sequence_features = {} is_homomer_or_monomer = len(set(list(mmcif.chain_to_seqres.values()))) == 1 for chain_id, seq in mmcif.chain_to_seqres.items(): - desc= "_".join([mmcif.file_id, chain_id]) + desc = "_".join([mmcif.file_id, chain_id]) if seq in sequence_features: - all_chain_features[desc] = copy.deepcopy( + chain_features = copy.deepcopy( sequence_features[seq] ) + chain_features["auth_chain_id"] = np.asarray(desc, dtype=object) + all_chain_features[desc] = chain_features continue if alignment_index is not None: @@ -1356,11 +1358,31 @@ def process_mmcif( chain_id=desc ) - mmcif_feats = self.get_mmcif_features(mmcif, chain_id) - chain_features.update(mmcif_feats) all_chain_features[desc] = chain_features sequence_features[seq] = chain_features + for chain_id, seq in mmcif.chain_to_seqres.items(): + desc = "_".join([mmcif.file_id, chain_id]) + + mmcif_feats = self.get_mmcif_features(mmcif, chain_id) + num_res = len(seq) + if mmcif_feats["all_atom_positions"].shape[0] != num_res: + raise ValueError( + f"mmCIF atom positions for {desc} have length " + f"{mmcif_feats['all_atom_positions'].shape[0]}, " + f"expected {num_res}" + ) + if mmcif_feats["all_atom_mask"].shape[0] != num_res: + raise ValueError( + f"mmCIF atom mask for {desc} has length " + f"{mmcif_feats['all_atom_mask'].shape[0]}, " + f"expected {num_res}" + ) + all_chain_features[desc] = { + **all_chain_features[desc], + **mmcif_feats, + } + all_chain_features = add_assembly_features(all_chain_features) np_example = feature_processing_multimer.pair_and_merge( diff --git a/tests/test_data_pipeline_multimer.py b/tests/test_data_pipeline_multimer.py new file mode 100644 index 000000000..052e3380b --- /dev/null +++ b/tests/test_data_pipeline_multimer.py @@ -0,0 +1,126 @@ +# Copyright 2026 AlQuraishi Laboratory +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from types import SimpleNamespace +import unittest +from unittest import mock + +import numpy as np + +from openfold.data import data_pipeline + + +class TestDataPipelineMultimer(unittest.TestCase): + def test_duplicate_sequence_mmcif_chains_get_distinct_structural_features(self): + pipeline = data_pipeline.DataPipelineMultimer( + monomer_data_pipeline=object() + ) + mmcif = SimpleNamespace( + file_id="fake", + chain_to_seqres={"A": "AC", "B": "AC"}, + ) + processed_chain_ids = [] + + def fake_process_single_chain(chain_id, sequence, description, **_): + processed_chain_ids.append(chain_id) + num_res = len(sequence) + aatype = np.zeros((num_res, 21), dtype=np.float32) + aatype[np.arange(num_res), np.arange(num_res) % 21] = 1.0 + return { + "aatype": aatype, + "sequence": np.array([sequence.encode("utf-8")], dtype=object), + "domain_name": np.array( + [description.encode("utf-8")], dtype=object + ), + "num_alignments": np.array( + [1] * num_res, dtype=np.int32 + ), + "seq_length": np.array([num_res] * num_res, dtype=np.int32), + } + + atom_positions = { + "A": np.full((2, 37, 3), 1.0, dtype=np.float32), + "B": np.full((2, 37, 3), 2.0, dtype=np.float32), + } + atom_masks = { + "A": np.ones((2, 37), dtype=np.float32), + "B": np.zeros((2, 37), dtype=np.float32), + } + requested_chain_ids = [] + + def fake_get_mmcif_features(mmcif_object, chain_id): + requested_chain_ids.append(chain_id) + return { + "all_atom_positions": atom_positions[chain_id], + "all_atom_mask": atom_masks[chain_id], + "resolution": np.array(1.0, dtype=np.float32), + "release_date": np.array([b"2026-05-10"], dtype=object), + "is_distillation": np.array(0.0, dtype=np.float32), + } + + captured = {} + + def fake_pair_and_merge(all_chain_features): + captured["all_chain_features"] = copy.deepcopy(all_chain_features) + return {"msa": np.zeros((1, 1), dtype=np.int32)} + + pipeline._process_single_chain = fake_process_single_chain + pipeline.get_mmcif_features = fake_get_mmcif_features + + with mock.patch.object( + data_pipeline.feature_processing_multimer, + "pair_and_merge", + side_effect=fake_pair_and_merge, + ), mock.patch.object( + data_pipeline, + "pad_msa", + side_effect=lambda np_example, _: np_example, + ): + pipeline.process_mmcif(mmcif=mmcif, alignment_dir="/unused") + + self.assertEqual(processed_chain_ids, ["fake_A"]) + self.assertEqual(requested_chain_ids, ["A", "B"]) + + all_chain_features = captured["all_chain_features"] + np.testing.assert_array_equal( + all_chain_features["A_1"]["all_atom_positions"], + atom_positions["A"], + ) + np.testing.assert_array_equal( + all_chain_features["A_2"]["all_atom_positions"], + atom_positions["B"], + ) + np.testing.assert_array_equal( + all_chain_features["A_1"]["all_atom_mask"], + atom_masks["A"], + ) + np.testing.assert_array_equal( + all_chain_features["A_2"]["all_atom_mask"], + atom_masks["B"], + ) + self.assertFalse( + np.array_equal( + all_chain_features["A_1"]["all_atom_positions"], + all_chain_features["A_2"]["all_atom_positions"], + ) + ) + self.assertEqual( + all_chain_features["A_2"]["auth_chain_id"].item(), + "fake_B", + ) + + +if __name__ == "__main__": + unittest.main()