@@ -7407,6 +7407,8 @@ def forward(
7407
7407
7408
7408
denoised_molecule_pos = denoised_atom_pos .gather (1 , distogram_atom_coords_indices )
7409
7409
7410
+ # get frames atom positions
7411
+
7410
7412
# three_atoms = einx.get_at('b [m] c, b n three -> three b n c', atom_pos, atom_indices_for_frame)
7411
7413
# pred_three_atoms = einx.get_at('b [m] c, b n three -> three b n c', denoised_atom_pos, atom_indices_for_frame)
7412
7414
@@ -7421,10 +7423,8 @@ def forward(
7421
7423
three_atoms = three_atom_pos .gather (2 , atom_indices_for_frame )
7422
7424
pred_three_atoms = three_denoised_atom_pos .gather (2 , atom_indices_for_frame )
7423
7425
7424
- # compute frames
7425
-
7426
- frames , _ = self .rigid_from_three_points (three_atoms )
7427
- pred_frames , _ = self .rigid_from_three_points (pred_three_atoms )
7426
+ frame_atoms = rearrange (three_atoms , "three b n c -> b n c three" )
7427
+ pred_frame_atoms = rearrange (pred_three_atoms , "three b n c -> b n c three" )
7428
7428
7429
7429
# determine mask
7430
7430
# must be amino acid, nucleotide, or ligand with greater than 0 atoms
@@ -7436,8 +7436,8 @@ def forward(
7436
7436
align_error = self .compute_alignment_error (
7437
7437
denoised_molecule_pos ,
7438
7438
molecule_pos ,
7439
- pred_frames ,
7440
- frames ,
7439
+ pred_frame_atoms , # In the paragraph 2 of section 4.3.2, the Phi_i denotes the coordinates of these frame atoms rather than the rotation matrix.
7440
+ frame_atoms ,
7441
7441
mask = align_error_mask ,
7442
7442
)
7443
7443
0 commit comments