
Potential bug in chain_center_of_mass_loss? #540

@abhinavb22


```python
import torch

# residue_constants, Vec3Array, euclidean_distance and masked_mean are helpers
# from the surrounding codebase.


def chain_center_of_mass_loss(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    asym_id: torch.Tensor,
    clamp_distance: float = -4.0,
    weight: float = 0.05,
    eps: float = 1e-10, **kwargs
) -> torch.Tensor:
    # Keep only the CA atom of each residue
    ca_pos = residue_constants.atom_order["CA"]
    all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
    all_atom_positions = all_atom_positions[..., ca_pos, :]
    all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)]

    # Per-chain residue mask, [..., NC, N]
    one_hot = torch.nn.functional.one_hot(asym_id.long()).to(dtype=all_atom_mask.dtype)
    one_hot = one_hot * all_atom_mask
    chain_pos_mask = one_hot.transpose(-2, -1)
    chain_exists = torch.any(chain_pos_mask, dim=-1).to(dtype=all_atom_positions.dtype)

    def get_chain_center_of_mass(pos):
        center_sum = (chain_pos_mask[..., None] * pos[..., None, :, :]).sum(dim=-2)
        centers = center_sum / (torch.sum(chain_pos_mask, dim=-1, keepdim=True) + eps)
        return Vec3Array.from_array(centers)

    pred_centers = get_chain_center_of_mass(all_atom_pred_pos)  # [B, NC, 3]
    true_centers = get_chain_center_of_mass(all_atom_positions)  # [B, NC, 3]

    pred_dists = euclidean_distance(pred_centers[..., None, :], pred_centers[..., :, None], epsilon=eps)
    true_dists = euclidean_distance(true_centers[..., None, :], true_centers[..., :, None], epsilon=eps)
    # Note: `weight` is applied here, inside the clamp and the square
    losses = torch.clamp((weight * (pred_dists - true_dists - clamp_distance)), max=0) ** 2
    loss_mask = chain_exists[..., :, None] * chain_exists[..., None, :]

    loss = masked_mean(loss_mask, losses, dim=(-1, -2))
    return loss
```
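For readers without the surrounding codebase, here is a minimal, unbatched sketch of what the function computes, in plain PyTorch. It is not the project's code: `chain_centers`, `pairwise_dist` and `chain_com_loss_sketch` are hypothetical stand-ins for the CA selection, `Vec3Array`/`euclidean_distance` and `masked_mean`; the inputs are assumed to already be CA coordinates, and epsilon handling may differ slightly from the real helpers.

```python
import torch
import torch.nn.functional as F


def chain_centers(pos, mask, asym_id, eps=1e-10):
    """Masked mean CA position of each chain (hypothetical, unbatched).

    pos: [N, 3] CA coordinates, mask: [N] with 1.0 for resolved residues,
    asym_id: [N] chain indices 0..NC-1. Returns ([NC, 3] centers, [NC] flags).
    """
    one_hot = F.one_hot(asym_id.long()).to(pos.dtype)   # [N, NC]
    chain_pos_mask = (one_hot * mask[:, None]).T        # [NC, N]
    counts = chain_pos_mask.sum(dim=-1, keepdim=True)   # [NC, 1]
    centers = (chain_pos_mask @ pos) / (counts + eps)   # [NC, 3]
    exists = (counts.squeeze(-1) > 0).to(pos.dtype)     # [NC]
    return centers, exists


def pairwise_dist(centers, eps=1e-10):
    # Euclidean distance between every pair of chain centers, [NC, NC].
    diff = centers[:, None, :] - centers[None, :, :]
    return torch.sqrt((diff ** 2).sum(dim=-1) + eps)


def chain_com_loss_sketch(pred, true_pos, mask, asym_id,
                          clamp_distance=-4.0, weight=0.05, eps=1e-10):
    pred_c, exists = chain_centers(pred, mask, asym_id, eps)
    true_c, _ = chain_centers(true_pos, mask, asym_id, eps)
    pred_d = pairwise_dist(pred_c, eps)
    true_d = pairwise_dist(true_c, eps)
    # Non-zero only where pred_d < true_d + clamp_distance, i.e. predicted
    # chain centers sit more than |clamp_distance| Å closer than in the truth.
    losses = torch.clamp(weight * (pred_d - true_d - clamp_distance), max=0) ** 2
    loss_mask = exists[:, None] * exists[None, :]  # includes self-pairs
    return (losses * loss_mask).sum() / (loss_mask.sum() + eps)


# Two hypothetical 3-residue chains whose centers are 20 Å apart in the truth.
asym_id = torch.tensor([0, 0, 0, 1, 1, 1])
mask = torch.ones(6, dtype=torch.float64)
true_pos = torch.zeros(6, 3, dtype=torch.float64)
true_pos[3:, 0] = 20.0

pred_same = true_pos.clone()   # identical geometry: no penalty
pred_close = true_pos.clone()
pred_close[3:, 0] = 5.0        # chains 15 Å closer than in the truth

print(chain_com_loss_sketch(pred_same, true_pos, mask, asym_id))   # zero
# ≈ 0.15125: (0.05 * -11)**2 for each of the two off-diagonal pairs,
# averaged over all 4 mask entries
print(chain_com_loss_sketch(pred_close, true_pos, mask, asym_id))
```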

Here, the loss returned by the function is already multiplied by `weight`, but the cumulative loss calculation then multiplies it by the weight again:

```python
cum_loss = cum_loss + weight * loss
losses[loss_name] = loss.detach().clone()
```

So doesn't this apply the weight to this loss twice?
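A quick numerical check of the concern, as a sketch with made-up distance differences: since `weight > 0`, `clamp(weight * x, max=0) ** 2` equals `weight**2 * clamp(x, max=0) ** 2`, so the value the function returns already carries a factor of `weight**2`, and the extra multiplication in the cumulative loss would make it `weight**3`. This assumes the outer `weight` is the same value passed into the function, and uses a plain `.mean()` in place of `masked_mean` (a mean is linear, so the scaling is unaffected).

```python
import torch

weight = 0.05
clamp_distance = -4.0

# Hypothetical values of pred_dist - true_dist (Å) for four chain pairs.
delta = torch.tensor([-10.0, -6.0, -2.0, 3.0], dtype=torch.float64)

# As in the highlighted line: weight applied inside the clamp, then squared.
inner = torch.clamp(weight * (delta - clamp_distance), max=0) ** 2
# The same expression with no weight at all, for comparison.
unweighted = torch.clamp(delta - clamp_distance, max=0) ** 2

per_fn = inner.mean()         # stands in for what the loss function returns
cumulative = weight * per_fn  # as in `cum_loss = cum_loss + weight * loss`

print(per_fn / unweighted.mean())      # ≈ 0.0025, i.e. weight**2
print(cumulative / unweighted.mean())  # ≈ 0.000125, i.e. weight**3
```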

Metadata

- Assignees: No one assigned
- Labels: No labels
- Type: No type
- Projects: No projects
- Milestone: No milestone
- Relationships: None yet
- Development: No branches or pull requests