def chain_center_of_mass_loss(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    asym_id: torch.Tensor,
    clamp_distance: float = -4.0,
    weight: float = 0.05,
    eps: float = 1e-10, **kwargs
) -> torch.Tensor:
    """Compute a clamped chain centre-of-mass distance loss on CA atoms.

    For every pair of chains, the distance between the predicted chain
    centres is compared with the distance between the ground-truth chain
    centres. Only pairs for which ``pred - true - clamp_distance`` is
    negative contribute (the term is clamped to ``max=0`` before squaring);
    with the default ``clamp_distance=-4.0`` that means pairs whose
    predicted centres are more than 4 distance units closer than the true
    ones.

    Args:
        all_atom_pred_pos: Predicted atom coordinates. Indexed as
            ``[..., atom, xyz]``; presumably ``[*, N_res, N_atom, 3]``
            (TODO confirm against caller).
        all_atom_positions: Ground-truth atom coordinates, indexed the same
            way as ``all_atom_pred_pos``.
        all_atom_mask: Per-atom mask, indexed as ``[..., atom]``. Only the
            CA column is used.
        asym_id: Per-residue chain identifier. It is cast to ``long`` and
            one-hot encoded, so values must be non-negative integers.
        clamp_distance: Offset subtracted from the distance difference
            before clamping.
        weight: Scale applied to the distance difference *inside* the
            clamp, i.e. before squaring. For a positive ``weight`` the
            returned loss therefore scales with ``weight ** 2``.
        eps: Small constant guarding the division by the per-chain atom
            count; also forwarded to ``euclidean_distance``.
        **kwargs: Accepted and ignored.

    Returns:
        The masked mean, over the two chain-pair dimensions, of the squared
        clamped terms.

    NOTE(review): ``weight`` is already applied here. If the caller also
    multiplies the returned loss by a configured weight when accumulating
    the total loss, the scaling is applied in both places -- confirm that
    this is intended.
    """
    # Restrict everything to the CA atom of each residue.
    ca_pos = residue_constants.atom_order["CA"]
    all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
    all_atom_positions = all_atom_positions[..., ca_pos, :]
    # Slice (rather than index) so the trailing dimension is kept for
    # broadcasting against the one-hot chain encoding below.
    all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)]

    # One-hot chain membership per residue; the number of classes is
    # inferred from the largest id present in ``asym_id``.
    one_hot = torch.nn.functional.one_hot(asym_id.long()).to(
        dtype=all_atom_mask.dtype
    )
    # Zero out residues whose CA atom is masked.
    one_hot = one_hot * all_atom_mask
    # Swap the last two dims so the chain axis precedes the residue axis.
    chain_pos_mask = one_hot.transpose(-2, -1)
    # A chain "exists" if at least one of its residues has an unmasked CA.
    chain_exists = torch.any(chain_pos_mask, dim=-1).to(
        dtype=all_atom_positions.dtype
    )

    def get_chain_center_of_mass(pos):
        # Masked sum of CA coordinates per chain, divided by the number of
        # contributing residues (eps avoids 0/0 for absent chains).
        center_sum = (chain_pos_mask[..., None] * pos[..., None, :, :]).sum(dim=-2)
        centers = center_sum / (
            torch.sum(chain_pos_mask, dim=-1, keepdim=True) + eps
        )
        return Vec3Array.from_array(centers)

    pred_centers = get_chain_center_of_mass(all_atom_pred_pos)  # [B, NC, 3]
    true_centers = get_chain_center_of_mass(all_atom_positions)  # [B, NC, 3]

    # Pairwise centre-to-centre distances between all chains.
    pred_dists = euclidean_distance(
        pred_centers[..., None, :], pred_centers[..., :, None], epsilon=eps
    )
    true_dists = euclidean_distance(
        true_centers[..., None, :], true_centers[..., :, None], epsilon=eps
    )

    # Keep only negative (too-close) deviations, then square them.
    losses = torch.clamp(
        (weight * (pred_dists - true_dists - clamp_distance)), max=0
    ) ** 2
    # Only pairs where both chains exist contribute to the mean.
    loss_mask = chain_exists[..., :, None] * chain_exists[..., None, :]

    loss = masked_mean(loss_mask, losses, dim=(-1, -2))
    return loss
Here, the computed loss is already multiplied by `weight` (inside the clamp, before squaring), but the cumulative loss calculation then multiplies it by the weight again:

    cum_loss = cum_loss + weight * loss
    losses[loss_name] = loss.detach().clone()

So isn't the weight applied to this loss twice?
`def chain_center_of_mass_loss(
all_atom_pred_pos: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
asym_id: torch.Tensor,
clamp_distance: float = -4.0,
weight: float = 0.05,
eps: float = 1e-10, **kwargs
) -> torch.Tensor:
ca_pos = residue_constants.atom_order["CA"]
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)]
Here, the computed loss is already multiplied by the weight, but the cumulative loss calculation then multiplies it by the weight again:

    cum_loss = cum_loss + weight * loss
    losses[loss_name] = loss.detach().clone()

So isn't the weight applied to this loss twice?