diff --git a/openfold/utils/loss.py b/openfold/utils/loss.py index 395e34753..6c2e05ffa 100644 --- a/openfold/utils/loss.py +++ b/openfold/utils/loss.py @@ -342,7 +342,8 @@ def supervised_chi_loss( Args: angles_sin_cos: - [*, N, 7, 2] predicted angles + [I, *, N, 7, 2] predicted angles, where I is the + StructureModule block axis and * are optional batch dimensions unnormalized_angles_sin_cos: The same angles, but unnormalized aatype: @@ -350,26 +351,20 @@ def supervised_chi_loss( seq_mask: [*, N] sequence mask chi_mask: - [*, N, 7] angle mask + [*, N, 4] chi angle mask chi_angles_sin_cos: - [*, N, 7, 2] ground truth angles + [*, N, 4, 2] ground truth chi angles chi_weight: Weight for the angle component of the loss angle_norm_weight: Weight for the normalization component of the loss Returns: - [*] loss tensor + Scalar loss tensor """ pred_angles = angles_sin_cos[..., 3:, :] - residue_type_one_hot = torch.nn.functional.one_hot( - aatype, - residue_constants.restype_num + 1, - ) - chi_pi_periodic = torch.einsum( - "...ij,jk->ik", - residue_type_one_hot.type(angles_sin_cos.dtype), - angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic), - ) + chi_pi_periodic = angles_sin_cos.new_tensor( + residue_constants.chi_pi_periodic, + )[aatype, ...] true_chi = chi_angles_sin_cos[None] diff --git a/tests/test_loss.py b/tests/test_loss.py index b52ea24fb..38edd9740 100644 --- a/tests/test_loss.py +++ b/tests/test_loss.py @@ -682,6 +682,39 @@ def run_supervised_chi_loss(value, batch): self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps) + def test_supervised_chi_loss_preserves_batched_chi_periodicity(self): + num_blocks, batch_size, n_res = 1, 2, 1 + angles = torch.zeros(num_blocks, batch_size, n_res, 7, 2) + unnormalized_angles = torch.ones_like(angles) + + chi_angles_sin_cos = torch.zeros(batch_size, n_res, 4, 2) + chi_angles_sin_cos[..., 1, :] = torch.tensor([0.0, 1.0]) + angles[..., 3:, :] = chi_angles_sin_cos.unsqueeze(0) + angles[..., 3 + 1, :] = torch.tensor([0.0, -1.0]) + + aatype = torch.tensor( + [ + [residue_constants.restype_order["D"]], + [residue_constants.restype_order["R"]], + ] + ) + seq_mask = torch.ones(batch_size, n_res) + chi_mask = torch.zeros(batch_size, n_res, 4) + chi_mask[..., 1] = 1.0 + + loss = supervised_chi_loss( + angles_sin_cos=angles, + unnormalized_angles_sin_cos=unnormalized_angles, + aatype=aatype, + seq_mask=seq_mask, + chi_mask=chi_mask, + chi_angles_sin_cos=chi_angles_sin_cos, + chi_weight=1.0, + angle_norm_weight=0.0, + ) + + self.assertTrue(torch.isclose(loss, torch.tensor(2.0), atol=1e-3)) + @compare_utils.skip_unless_alphafold_installed() def test_violation_loss(self): config = compare_utils.get_alphafold_config()