Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 8 additions & 13 deletions openfold/utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,34 +342,29 @@ def supervised_chi_loss(

Args:
angles_sin_cos:
[*, N, 7, 2] predicted angles
[I, *, N, 7, 2] predicted angles, where I is the
StructureModule block axis and * are optional batch dimensions
unnormalized_angles_sin_cos:
The same angles, but unnormalized
aatype:
[*, N] residue indices
seq_mask:
[*, N] sequence mask
chi_mask:
[*, N, 7] angle mask
[*, N, 4] chi angle mask
chi_angles_sin_cos:
[*, N, 7, 2] ground truth angles
[*, N, 4, 2] ground truth chi angles
chi_weight:
Weight for the angle component of the loss
angle_norm_weight:
Weight for the normalization component of the loss
Returns:
[*] loss tensor
Scalar loss tensor
"""
pred_angles = angles_sin_cos[..., 3:, :]
residue_type_one_hot = torch.nn.functional.one_hot(
aatype,
residue_constants.restype_num + 1,
)
chi_pi_periodic = torch.einsum(
"...ij,jk->ik",
residue_type_one_hot.type(angles_sin_cos.dtype),
angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic),
)
chi_pi_periodic = angles_sin_cos.new_tensor(
residue_constants.chi_pi_periodic,
)[aatype, ...]

true_chi = chi_angles_sin_cos[None]

Expand Down
33 changes: 33 additions & 0 deletions tests/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,39 @@ def run_supervised_chi_loss(value, batch):

self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

def test_supervised_chi_loss_preserves_batched_chi_periodicity(self):
num_blocks, batch_size, n_res = 1, 2, 1
angles = torch.zeros(num_blocks, batch_size, n_res, 7, 2)
unnormalized_angles = torch.ones_like(angles)

chi_angles_sin_cos = torch.zeros(batch_size, n_res, 4, 2)
chi_angles_sin_cos[..., 1, :] = torch.tensor([0.0, 1.0])
angles[..., 3:, :] = chi_angles_sin_cos.unsqueeze(0)
angles[..., 3 + 1, :] = torch.tensor([0.0, -1.0])

aatype = torch.tensor(
[
[residue_constants.restype_order["D"]],
[residue_constants.restype_order["R"]],
]
)
seq_mask = torch.ones(batch_size, n_res)
chi_mask = torch.zeros(batch_size, n_res, 4)
chi_mask[..., 1] = 1.0

loss = supervised_chi_loss(
angles_sin_cos=angles,
unnormalized_angles_sin_cos=unnormalized_angles,
aatype=aatype,
seq_mask=seq_mask,
chi_mask=chi_mask,
chi_angles_sin_cos=chi_angles_sin_cos,
chi_weight=1.0,
angle_norm_weight=0.0,
)

self.assertTrue(torch.isclose(loss, torch.tensor(2.0), atol=1e-3))

@compare_utils.skip_unless_alphafold_installed()
def test_violation_loss(self):
config = compare_utils.get_alphafold_config()
Expand Down