diff --git a/openfold/utils/loss.py b/openfold/utils/loss.py index 395e34753..9741cdce0 100644 --- a/openfold/utils/loss.py +++ b/openfold/utils/loss.py @@ -1649,7 +1649,8 @@ def chain_center_of_mass_loss( clamp_distance: Cutoff above which distance errors are disregarded weight: - Weight for loss + Accepted for config/backward compatibility. The top-level loss + weight is applied by AlphaFoldLoss. eps: Small value used to regularize denominators Returns: @@ -1675,7 +1676,7 @@ def get_chain_center_of_mass(pos): pred_dists = euclidean_distance(pred_centers[..., None, :], pred_centers[..., :, None], epsilon=eps) true_dists = euclidean_distance(true_centers[..., None, :], true_centers[..., :, None], epsilon=eps) - losses = torch.clamp((weight * (pred_dists - true_dists - clamp_distance)), max=0) ** 2 + losses = torch.clamp(pred_dists - true_dists - clamp_distance, max=0) ** 2 loss_mask = chain_exists[..., :, None] * chain_exists[..., None, :] loss = masked_mean(loss_mask, losses, dim=(-1, -2)) diff --git a/tests/test_loss.py b/tests/test_loss.py index b52ea24fb..17b4a9fe3 100644 --- a/tests/test_loss.py +++ b/tests/test_loss.py @@ -1112,6 +1112,67 @@ def run_tm_loss(representations, batch, value): self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps) + def _make_chain_center_of_mass_inputs(self, requires_grad=False): + ca_pos = residue_constants.atom_order["CA"] + all_atom_positions = torch.zeros((1, 2, 37, 3), dtype=torch.float32) + all_atom_positions[0, 0, ca_pos] = torch.tensor([0.0, 0.0, 0.0]) + all_atom_positions[0, 1, ca_pos] = torch.tensor([10.0, 0.0, 0.0]) + + all_atom_pred_pos = torch.zeros_like(all_atom_positions) + all_atom_pred_pos[0, 0, ca_pos] = torch.tensor([0.0, 0.0, 0.0]) + all_atom_pred_pos[0, 1, ca_pos] = torch.tensor([2.0, 0.0, 0.0]) + if requires_grad: + all_atom_pred_pos.requires_grad_() + + all_atom_mask = torch.zeros((1, 2, 37), dtype=torch.float32) + all_atom_mask[:, :, ca_pos] = 1.0 + asym_id = torch.tensor([[1, 2]], dtype=torch.float32) + + return all_atom_pred_pos, all_atom_positions, all_atom_mask, asym_id + + def test_chain_center_of_mass_loss_is_unweighted(self): + inputs = self._make_chain_center_of_mass_inputs() + loss = chain_center_of_mass_loss( + all_atom_pred_pos=inputs[0], + all_atom_positions=inputs[1], + all_atom_mask=inputs[2], + asym_id=inputs[3], + clamp_distance=-4.0, + weight=0.05, + ) + larger_weight_loss = chain_center_of_mass_loss( + all_atom_pred_pos=inputs[0], + all_atom_positions=inputs[1], + all_atom_mask=inputs[2], + asym_id=inputs[3], + clamp_distance=-4.0, + weight=0.5, + ) + + expected = torch.tensor([32.0 / (4.0 + 1e-4)], dtype=loss.dtype) + self.assertTrue(torch.allclose(loss, expected, rtol=1e-5, atol=1e-5)) + self.assertTrue( + torch.allclose(loss, larger_weight_loss, rtol=1e-6, atol=1e-6) + ) + + def test_chain_center_of_mass_loss_backpropagates(self): + inputs = self._make_chain_center_of_mass_inputs(requires_grad=True) + loss = chain_center_of_mass_loss( + all_atom_pred_pos=inputs[0], + all_atom_positions=inputs[1], + all_atom_mask=inputs[2], + asym_id=inputs[3], + clamp_distance=-4.0, + weight=0.05, + ) + + loss.sum().backward() + + ca_pos = residue_constants.atom_order["CA"] + ca_grad = inputs[0].grad[:, :, ca_pos, :] + self.assertTrue(torch.all(torch.isfinite(ca_grad))) + self.assertGreater(torch.norm(ca_grad).item(), 0.0) + @compare_utils.skip_unless_alphafold_installed() def test_chain_center_of_mass_loss(self): batch_size = consts.batch_size @@ -1139,6 +1200,9 @@ def test_chain_center_of_mass_loss(self): ) out_repro = out_repro.cpu() + self.assertTrue(torch.all(torch.isfinite(out_repro))) + self.assertTrue(torch.all(out_repro >= 0)) + if __name__ == "__main__": unittest.main()