Skip to content

Commit 8266776

Browse files
committed
Fix bugs related to pAE
1 parent 3415855 commit 8266776

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

alphafold3_pytorch/alphafold3.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -7407,6 +7407,8 @@ def forward(
74077407

74087408
denoised_molecule_pos = denoised_atom_pos.gather(1, distogram_atom_coords_indices)
74097409

7410+
# get frames atom positions
7411+
74107412
# three_atoms = einx.get_at('b [m] c, b n three -> three b n c', atom_pos, atom_indices_for_frame)
74117413
# pred_three_atoms = einx.get_at('b [m] c, b n three -> three b n c', denoised_atom_pos, atom_indices_for_frame)
74127414

@@ -7421,10 +7423,8 @@ def forward(
74217423
three_atoms = three_atom_pos.gather(2, atom_indices_for_frame)
74227424
pred_three_atoms = three_denoised_atom_pos.gather(2, atom_indices_for_frame)
74237425

7424-
# compute frames
7425-
7426-
frames, _ = self.rigid_from_three_points(three_atoms)
7427-
pred_frames, _ = self.rigid_from_three_points(pred_three_atoms)
7426+
frame_atoms = rearrange(three_atoms, "three b n c -> b n c three")
7427+
pred_frame_atoms = rearrange(pred_three_atoms, "three b n c -> b n c three")
74287428

74297429
# determine mask
74307430
# must be amino acid, nucleotide, or ligand with greater than 0 atoms
@@ -7436,8 +7436,8 @@ def forward(
74367436
align_error = self.compute_alignment_error(
74377437
denoised_molecule_pos,
74387438
molecule_pos,
7439-
pred_frames,
7440-
frames,
7439+
pred_frame_atoms, # In the paragraph 2 of section 4.3.2, the Phi_i denotes the coordinates of these frame atoms rather than the rotation matrix.
7440+
frame_atoms,
74417441
mask=align_error_mask,
74427442
)
74437443

0 commit comments

Comments
 (0)