diff --git a/openfold/utils/loss.py b/openfold/utils/loss.py index 395e34753..1f15038cf 100644 --- a/openfold/utils/loss.py +++ b/openfold/utils/loss.py @@ -55,6 +55,89 @@ def sigmoid_cross_entropy(logits, labels): return loss +def _shape_numel(shape): + numel = 1 + for dim in shape: + numel *= dim + return numel + + +def _reduce_loss(loss: torch.Tensor, reduction: str) -> torch.Tensor: + if reduction == "none": + return loss + if reduction == "mean": + return torch.mean(loss) + raise ValueError(f"Unsupported loss reduction: {reduction}") + + +def _mean_non_batch_dims( + loss: torch.Tensor, + batch_shape: torch.Size, +) -> torch.Tensor: + if len(batch_shape) == 0: + return torch.mean(loss) + + if loss.shape[-len(batch_shape):] != batch_shape: + raise ValueError( + f"loss shape {tuple(loss.shape)} does not end with batch shape " + f"{tuple(batch_shape)}" + ) + + extra_dims = loss.ndim - len(batch_shape) + if extra_dims <= 0: + return loss + + return torch.mean(loss, dim=tuple(range(extra_dims))) + + +def _ensure_loss_shape( + loss: torch.Tensor, + batch_shape: torch.Size, + loss_name: str, +) -> torch.Tensor: + if loss.shape == batch_shape: + return loss + + if len(batch_shape) == 0 and loss.ndim == 0: + return loss + + raise ValueError( + f"{loss_name} loss has shape {tuple(loss.shape)}, expected " + f"{tuple(batch_shape)}" + ) + + +def _normalize_length_shape( + seq_length: torch.Tensor, + batch_shape: torch.Size, +) -> torch.Tensor: + if seq_length.shape == batch_shape: + return seq_length + + if seq_length.numel() == _shape_numel(batch_shape): + return seq_length.reshape(batch_shape) + + raise ValueError( + f"seq_length has shape {tuple(seq_length.shape)}, expected " + f"{tuple(batch_shape)}" + ) + + +def _apply_length_scale( + per_example_loss: torch.Tensor, + seq_length: torch.Tensor, + crop_len: int, +) -> torch.Tensor: + seq_length = seq_length.to( + device=per_example_loss.device, + dtype=per_example_loss.dtype, + ) + seq_length = _normalize_length_shape(seq_length, per_example_loss.shape) + crop_len = seq_length.new_tensor(crop_len) + scale = torch.sqrt(torch.minimum(seq_length, crop_len)) + return per_example_loss * scale + + def torsion_angle_loss( a, # [*, N, 7, 2] a_gt, # [*, N, 7, 2] @@ -175,6 +258,7 @@ def backbone_loss( clamp_distance: float = 10.0, loss_unit_distance: float = 10.0, eps: float = 1e-4, + reduction: str = "mean", **kwargs, ) -> torch.Tensor: ### need to check if the traj belongs to 4*4 matrix or a tensor_7 @@ -226,10 +310,10 @@ def backbone_loss( 1 - use_clamped_fape ) - # Average over the batch dimension - fape_loss = torch.mean(fape_loss) + batch_shape = backbone_rigid_mask.shape[:-1] + fape_loss = _mean_non_batch_dims(fape_loss, batch_shape) - return fape_loss + return _reduce_loss(fape_loss, reduction) def sidechain_loss( @@ -287,6 +371,7 @@ def fape_loss( out: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor], config: ml_collections.ConfigDict, + reduction: str = "mean", ) -> torch.Tensor: traj = out["sm"]["frames"] asym_id = batch.get("asym_id") @@ -295,11 +380,13 @@ def fape_loss( intra_chain_bb_loss = backbone_loss( traj=traj, pair_mask=intra_chain_mask, + reduction="none", **{**batch, **config.intra_chain_backbone}, ) interface_bb_loss = backbone_loss( traj=traj, pair_mask=1. - intra_chain_mask, + reduction="none", **{**batch, **config.interface_backbone}, ) weighted_bb_loss = (intra_chain_bb_loss * config.intra_chain_backbone.weight @@ -307,6 +394,7 @@ def fape_loss( else: bb_loss = backbone_loss( traj=traj, + reduction="none", **{**batch, **config.backbone}, ) weighted_bb_loss = bb_loss * config.backbone.weight @@ -319,10 +407,7 @@ def fape_loss( loss = weighted_bb_loss + config.sidechain.weight * sc_loss - # Average over the batch dimension - loss = torch.mean(loss) - - return loss + return _reduce_loss(loss, reduction) def supervised_chi_loss( @@ -335,6 +420,7 @@ def supervised_chi_loss( chi_weight: float, angle_norm_weight: float, eps=1e-6, + reduction: str = "mean", **kwargs, ) -> torch.Tensor: """ @@ -405,10 +491,7 @@ def supervised_chi_loss( loss = loss + angle_norm_weight * angle_norm_loss - # Average over the batch dimension - loss = torch.mean(loss) - - return loss + return _reduce_loss(loss, reduction) def compute_plddt(logits: torch.Tensor) -> torch.Tensor: @@ -515,6 +598,7 @@ def lddt_loss( min_resolution: float = 0.1, max_resolution: float = 3.0, eps: float = 1e-10, + reduction: str = "mean", **kwargs, ) -> torch.Tensor: n = all_atom_mask.shape[-2] @@ -553,10 +637,7 @@ def lddt_loss( (resolution >= min_resolution) & (resolution <= max_resolution) ) - # Average over the batch dimension - loss = torch.mean(loss) - - return loss + return _reduce_loss(loss, reduction) def distogram_loss( @@ -567,6 +648,7 @@ def distogram_loss( max_bin=21.6875, no_bins=64, eps=1e-6, + reduction: str = "mean", **kwargs, ): boundaries = torch.linspace( @@ -601,10 +683,7 @@ def distogram_loss( mean = mean / denom[..., None] mean = torch.sum(mean, dim=-1) - # Average over the batch dimensions - mean = torch.mean(mean) - - return mean + return _reduce_loss(mean, reduction) def _calculate_bin_centers(boundaries: torch.Tensor): @@ -729,6 +808,7 @@ def tm_loss( min_resolution: float = 0.1, max_resolution: float = 3.0, eps=1e-8, + reduction: str = "mean", **kwargs, ): # first check whether this is a tensor_7 or tensor_4*4 @@ -773,10 +853,7 @@ def _points(affine): (resolution >= min_resolution) & (resolution <= max_resolution) ) - # Average over the batch dimension - loss = torch.mean(loss) - - return loss + return _reduce_loss(loss, reduction) def between_residue_bond_loss( @@ -1434,9 +1511,10 @@ def violation_loss( atom14_atom_exists: torch.Tensor, average_clashes: bool = False, eps=1e-6, + reduction: str = "mean", **kwargs, ) -> torch.Tensor: - num_atoms = torch.sum(atom14_atom_exists) + num_atoms = torch.sum(atom14_atom_exists, dim=(-1, -2)) per_atom_clash = (violations["between_residues"]["clashes_per_atom_loss_sum"] + violations["within_residues"]["per_atom_loss_sum"]) @@ -1446,7 +1524,7 @@ def violation_loss( violations["within_residues"]["per_atom_num_clash"]) per_atom_clash = per_atom_clash / (num_clash + eps) - l_clash = torch.sum(per_atom_clash) / (eps + num_atoms) + l_clash = torch.sum(per_atom_clash, dim=(-1, -2)) / (eps + num_atoms) loss = ( violations["between_residues"]["bonds_c_n_loss_mean"] + violations["between_residues"]["angles_ca_c_n_loss_mean"] @@ -1454,10 +1532,7 @@ def violation_loss( + l_clash ) - # Average over the batch dimension - mean = torch.mean(loss) - - return mean + return _reduce_loss(loss, reduction) def compute_renamed_ground_truth( @@ -1576,6 +1651,7 @@ def experimentally_resolved_loss( min_resolution: float, max_resolution: float, eps: float = 1e-8, + reduction: str = "mean", **kwargs, ) -> torch.Tensor: errors = sigmoid_cross_entropy(logits, all_atom_mask) @@ -1587,12 +1663,18 @@ def experimentally_resolved_loss( (resolution >= min_resolution) & (resolution <= max_resolution) ) - loss = torch.mean(loss) - - return loss + return _reduce_loss(loss, reduction) -def masked_msa_loss(logits, true_msa, bert_mask, num_classes, eps=1e-8, **kwargs): +def masked_msa_loss( + logits, + true_msa, + bert_mask, + num_classes, + eps=1e-8, + reduction: str = "mean", + **kwargs, +): """ Computes BERT-style masked MSA loss. Implements subsection 1.9.9. @@ -1620,9 +1702,7 @@ def masked_msa_loss(logits, true_msa, bert_mask, num_classes, eps=1e-8, **kwargs loss = torch.sum(loss, dim=-1) loss = loss * scale - loss = torch.mean(loss) - - return loss + return _reduce_loss(loss, reduction) def chain_center_of_mass_loss( @@ -1712,33 +1792,40 @@ def loss(self, out, batch, _return_breakdown=False): loss_fns = { "distogram": lambda: distogram_loss( logits=out["distogram_logits"], + reduction="none", **{**batch, **self.config.distogram}, ), "experimentally_resolved": lambda: experimentally_resolved_loss( logits=out["experimentally_resolved_logits"], + reduction="none", **{**batch, **self.config.experimentally_resolved}, ), "fape": lambda: fape_loss( out, batch, self.config.fape, + reduction="none", ), "plddt_loss": lambda: lddt_loss( logits=out["lddt_logits"], all_atom_pred_pos=out["final_atom_positions"], + reduction="none", **{**batch, **self.config.plddt_loss}, ), "masked_msa": lambda: masked_msa_loss( logits=out["masked_msa_logits"], + reduction="none", **{**batch, **self.config.masked_msa}, ), "supervised_chi": lambda: supervised_chi_loss( out["sm"]["angles"], out["sm"]["unnormalized_angles"], + reduction="none", **{**batch, **self.config.supervised_chi}, ), "violation": lambda: violation_loss( out["violation"], + reduction="none", **{**batch, **self.config.violation}, ), } @@ -1746,6 +1833,7 @@ def loss(self, out, batch, _return_breakdown=False): if self.config.tm.enabled: loss_fns["tm"] = lambda: tm_loss( logits=out["tm_logits"], + reduction="none", **{**batch, **out, **self.config.tm}, ) @@ -1755,27 +1843,31 @@ def loss(self, out, batch, _return_breakdown=False): **{**batch, **self.config.chain_center_of_mass}, ) - cum_loss = 0. + batch_shape = batch["aatype"].shape[:-1] + cum_loss = None losses = {} for loss_name, loss_fn in loss_fns.items(): weight = self.config[loss_name].weight loss = loss_fn() - if torch.isnan(loss) or torch.isinf(loss): + loss = _ensure_loss_shape(loss, batch_shape, loss_name) + if not torch.isfinite(loss).all(): # for k,v in batch.items(): # if torch.any(torch.isnan(v)) or torch.any(torch.isinf(v)): # logging.warning(f"{k}: is nan") # logging.warning(f"{loss_name}: {loss}") - logging.warning(f"{loss_name} loss is NaN. Skipping...") - loss = loss.new_tensor(0., requires_grad=True) - cum_loss = cum_loss + weight * loss - losses[loss_name] = loss.detach().clone() - losses["unscaled_loss"] = cum_loss.detach().clone() - - # Scale the loss by the square root of the minimum of the crop size and - # the (average) sequence length. See subsection 1.9. - seq_len = torch.mean(batch["seq_length"].float()) + logging.warning(f"{loss_name} loss is NaN or inf. Skipping...") + loss = torch.zeros_like(loss, requires_grad=True) + weighted_loss = weight * loss + cum_loss = weighted_loss if cum_loss is None else cum_loss + weighted_loss + losses[loss_name] = torch.mean(loss).detach().clone() + + losses["unscaled_loss"] = torch.mean(cum_loss).detach().clone() + + # Keep losses per example until this scale is applied; otherwise + # mixed-length local batches use the wrong scale. crop_len = batch["aatype"].shape[-1] - cum_loss = cum_loss * torch.sqrt(min(seq_len, crop_len)) + cum_loss = _apply_length_scale(cum_loss, batch["seq_length"], crop_len) + cum_loss = torch.mean(cum_loss) losses["loss"] = cum_loss.detach().clone() diff --git a/tests/test_loss.py b/tests/test_loss.py index b52ea24fb..9cb6346bd 100644 --- a/tests/test_loss.py +++ b/tests/test_loss.py @@ -17,10 +17,12 @@ import numpy as np from pathlib import Path import unittest +from unittest import mock import ml_collections as mlc from openfold.data import data_transforms from openfold.np import residue_constants +import openfold.utils.loss as loss_module from openfold.utils.rigid_utils import ( Rotation, Rigid, @@ -44,7 +46,9 @@ tm_loss, compute_plddt, compute_tm, - chain_center_of_mass_loss + chain_center_of_mass_loss, + _apply_length_scale, + _ensure_loss_shape, ) from openfold.utils.tensor_utils import ( tree_map, @@ -91,6 +95,376 @@ def setUpClass(cls): cls.am_modules = alphafold.model.modules cls.am_rigid = alphafold.model.r3 + def test_apply_length_scale_per_example(self): + per_example = torch.tensor([1.0, 4.0]) + seq_length = torch.tensor([100, 400]) + + scaled = _apply_length_scale(per_example, seq_length, crop_len=256) + expected = torch.tensor([10.0, 64.0]) + + torch.testing.assert_close(scaled, expected) + + old_scaled_mean = per_example.mean() * torch.sqrt( + torch.minimum(seq_length.float().mean(), torch.tensor(256.0)) + ) + self.assertFalse(torch.allclose(scaled.mean(), old_scaled_mean)) + + def test_ensure_loss_shape(self): + loss = torch.ones(2) + + self.assertIs( + _ensure_loss_shape(loss, torch.Size([2]), "example"), + loss, + ) + + scalar_loss = torch.tensor(1.0) + self.assertIs( + _ensure_loss_shape(scalar_loss, torch.Size([]), "example"), + scalar_loss, + ) + + with self.assertRaises(ValueError): + _ensure_loss_shape(scalar_loss, torch.Size([2]), "example") + + def test_alphafold_loss_applies_length_scale_per_example(self): + loss_names = [ + "distogram", + "experimentally_resolved", + "fape", + "plddt_loss", + "masked_msa", + "supervised_chi", + "violation", + ] + config = mlc.ConfigDict( + {name: {"weight": 0.0} for name in loss_names} + ) + config.distogram.weight = 1.0 + config.tm = mlc.ConfigDict({"enabled": False, "weight": 0.0}) + config.chain_center_of_mass = mlc.ConfigDict( + {"enabled": False, "weight": 0.0} + ) + + per_example_loss = torch.tensor([1.0, 4.0]) + zero_loss = torch.zeros_like(per_example_loss) + batch = { + "aatype": torch.zeros(2, 256, dtype=torch.long), + "seq_length": torch.tensor([100, 400]), + } + out = { + "distogram_logits": torch.empty(0), + "experimentally_resolved_logits": torch.empty(0), + "lddt_logits": torch.empty(0), + "masked_msa_logits": torch.empty(0), + "final_atom_positions": torch.empty(0), + "renamed_atom14_gt_positions": torch.empty(0), + "violation": {}, + "sm": { + "angles": torch.empty(0), + "frames": torch.empty(0), + "positions": torch.empty(0), + "sidechain_frames": torch.empty(0), + "unnormalized_angles": torch.empty(0), + }, + } + + with mock.patch.multiple( + loss_module, + distogram_loss=mock.Mock(return_value=per_example_loss), + experimentally_resolved_loss=mock.Mock( + return_value=zero_loss + ), + fape_loss=mock.Mock(return_value=zero_loss), + lddt_loss=mock.Mock(return_value=zero_loss), + masked_msa_loss=mock.Mock(return_value=zero_loss), + supervised_chi_loss=mock.Mock(return_value=zero_loss), + violation_loss=mock.Mock(return_value=zero_loss), + ): + loss, losses = loss_module.AlphaFoldLoss(config)( + out, + batch, + _return_breakdown=True, + ) + + expected = torch.tensor([10.0, 64.0]).mean() + old_scaled_loss = per_example_loss.mean() * torch.sqrt( + torch.minimum( + batch["seq_length"].float().mean(), + torch.tensor(256.0), + ) + ) + + torch.testing.assert_close(loss, expected) + torch.testing.assert_close( + losses["unscaled_loss"], + per_example_loss.mean(), + ) + self.assertFalse(torch.allclose(loss, old_scaled_loss)) + + def test_distogram_loss_reduction(self): + batch_size = 2 + n_res = 4 + no_bins = 5 + + logits = torch.randn(batch_size, n_res, n_res, no_bins) + pseudo_beta = torch.randn(batch_size, n_res, 3) + pseudo_beta_mask = torch.ones(batch_size, n_res) + + loss = distogram_loss( + logits, + pseudo_beta, + pseudo_beta_mask, + no_bins=no_bins, + reduction="none", + ) + mean_loss = distogram_loss( + logits, + pseudo_beta, + pseudo_beta_mask, + no_bins=no_bins, + reduction="mean", + ) + + self.assertEqual(loss.shape, (batch_size,)) + torch.testing.assert_close(mean_loss, loss.mean()) + + def test_masked_msa_loss_reduction(self): + batch_size = 2 + n_seq = 3 + n_res = 4 + num_classes = 7 + + logits = torch.randn(batch_size, n_seq, n_res, num_classes) + true_msa = torch.randint(0, num_classes, (batch_size, n_seq, n_res)) + bert_mask = torch.ones(batch_size, n_seq, n_res) + + loss = masked_msa_loss( + logits, + true_msa, + bert_mask, + num_classes, + reduction="none", + ) + mean_loss = masked_msa_loss( + logits, + true_msa, + bert_mask, + num_classes, + reduction="mean", + ) + + self.assertEqual(loss.shape, (batch_size,)) + torch.testing.assert_close(mean_loss, loss.mean()) + + def test_experimentally_resolved_loss_reduction(self): + batch_size = 2 + n_res = 4 + + logits = torch.randn(batch_size, n_res, 37) + atom37_atom_exists = torch.ones(batch_size, n_res, 37) + all_atom_mask = torch.randint(0, 2, (batch_size, n_res, 37)).float() + resolution = torch.ones(batch_size) + + loss = experimentally_resolved_loss( + logits, + atom37_atom_exists, + all_atom_mask, + resolution, + min_resolution=0.1, + max_resolution=3.0, + reduction="none", + ) + mean_loss = experimentally_resolved_loss( + logits, + atom37_atom_exists, + all_atom_mask, + resolution, + min_resolution=0.1, + max_resolution=3.0, + reduction="mean", + ) + + self.assertEqual(loss.shape, (batch_size,)) + torch.testing.assert_close(mean_loss, loss.mean()) + + def test_lddt_loss_reduction(self): + batch_size = 2 + n_res = 4 + no_bins = 6 + + logits = torch.randn(batch_size, n_res, no_bins) + all_atom_pred_pos = torch.randn(batch_size, n_res, 37, 3) + all_atom_positions = torch.randn(batch_size, n_res, 37, 3) + all_atom_mask = torch.ones(batch_size, n_res, 37) + resolution = torch.ones(batch_size) + + loss = lddt_loss( + logits, + all_atom_pred_pos, + all_atom_positions, + all_atom_mask, + resolution, + no_bins=no_bins, + reduction="none", + ) + mean_loss = lddt_loss( + logits, + all_atom_pred_pos, + all_atom_positions, + all_atom_mask, + resolution, + no_bins=no_bins, + reduction="mean", + ) + + self.assertEqual(loss.shape, (batch_size,)) + torch.testing.assert_close(mean_loss, loss.mean()) + + def test_supervised_chi_loss_reduction(self): + batch_size = 2 + n_layer = 3 + n_res = 4 + + angles = torch.randn(n_layer, batch_size, n_res, 7, 2) + unnormalized_angles = torch.randn(n_layer, batch_size, n_res, 7, 2) + aatype = torch.randint( + 0, + residue_constants.restype_num, + (batch_size, n_res), + ) + seq_mask = torch.ones(batch_size, n_res) + chi_mask = torch.ones(batch_size, n_res, 4) + chi_angles = torch.randn(batch_size, n_res, 4, 2) + + loss = supervised_chi_loss( + angles, + unnormalized_angles, + aatype, + seq_mask, + chi_mask, + chi_angles, + chi_weight=0.5, + angle_norm_weight=0.01, + reduction="none", + ) + mean_loss = supervised_chi_loss( + angles, + unnormalized_angles, + aatype, + seq_mask, + chi_mask, + chi_angles, + chi_weight=0.5, + angle_norm_weight=0.01, + reduction="mean", + ) + + self.assertEqual(loss.shape, (batch_size,)) + torch.testing.assert_close(mean_loss, loss.mean()) + + def test_backbone_loss_reduction_preserves_batch(self): + batch_size = 2 + n_layer = 3 + n_res = 4 + + traj = torch.eye(4).repeat(n_layer, batch_size, n_res, 1, 1) + traj[..., :3, 3] = torch.randn(n_layer, batch_size, n_res, 3) + backbone_rigid_tensor = torch.eye(4).repeat(batch_size, n_res, 1, 1) + backbone_rigid_tensor[..., :3, 3] = torch.randn(batch_size, n_res, 3) + backbone_rigid_mask = torch.ones(batch_size, n_res) + + loss = backbone_loss( + backbone_rigid_tensor=backbone_rigid_tensor, + backbone_rigid_mask=backbone_rigid_mask, + traj=traj, + reduction="none", + ) + mean_loss = backbone_loss( + backbone_rigid_tensor=backbone_rigid_tensor, + backbone_rigid_mask=backbone_rigid_mask, + traj=traj, + reduction="mean", + ) + + self.assertEqual(loss.shape, (batch_size,)) + torch.testing.assert_close(mean_loss, loss.mean()) + + def test_tm_loss_reduction(self): + batch_size = 2 + n_res = 4 + no_bins = 6 + + logits = torch.randn(batch_size, n_res, n_res, no_bins) + final_affine_tensor = torch.eye(4).repeat(batch_size, n_res, 1, 1) + final_affine_tensor[..., :3, 3] = torch.randn(batch_size, n_res, 3) + backbone_rigid_tensor = torch.eye(4).repeat(batch_size, n_res, 1, 1) + backbone_rigid_tensor[..., :3, 3] = torch.randn(batch_size, n_res, 3) + backbone_rigid_mask = torch.ones(batch_size, n_res) + resolution = torch.ones(batch_size) + + loss = tm_loss( + logits, + final_affine_tensor, + backbone_rigid_tensor, + backbone_rigid_mask, + resolution, + no_bins=no_bins, + reduction="none", + ) + mean_loss = tm_loss( + logits, + final_affine_tensor, + backbone_rigid_tensor, + backbone_rigid_mask, + resolution, + no_bins=no_bins, + reduction="mean", + ) + + self.assertEqual(loss.shape, (batch_size,)) + torch.testing.assert_close(mean_loss, loss.mean()) + + def test_violation_loss_reduction_uses_per_example_atom_count(self): + batch_size = 2 + n_res = 3 + + atom14_atom_exists = torch.zeros(batch_size, n_res, 14) + atom14_atom_exists[0] = 1.0 + atom14_atom_exists[1, 0, :7] = 1.0 + + between_clashes = torch.zeros_like(atom14_atom_exists) + between_clashes[0] = atom14_atom_exists[0] + between_clashes[1] = 2.0 * atom14_atom_exists[1] + + zeros = torch.zeros(batch_size) + violations = { + "between_residues": { + "bonds_c_n_loss_mean": zeros, + "angles_ca_c_n_loss_mean": zeros, + "angles_c_n_ca_loss_mean": zeros, + "clashes_per_atom_loss_sum": between_clashes, + }, + "within_residues": { + "per_atom_loss_sum": torch.zeros_like(atom14_atom_exists), + }, + } + + loss = violation_loss( + violations, + atom14_atom_exists, + reduction="none", + ) + mean_loss = violation_loss( + violations, + atom14_atom_exists, + reduction="mean", + ) + + expected = torch.tensor([1.0, 2.0]) + self.assertEqual(loss.shape, (batch_size,)) + torch.testing.assert_close(loss, expected, rtol=1e-5, atol=1e-5) + torch.testing.assert_close(mean_loss, loss.mean()) + def test_run_torsion_angle_loss(self): batch_size = consts.batch_size n_res = consts.n_res