@@ -2835,6 +2835,7 @@ def sample(
2835
2835
use_tqdm_pbar = True ,
2836
2836
tqdm_pbar_title = 'sampling time step' ,
2837
2837
return_all_timesteps = False ,
2838
+ verifier : Module | None = None ,
2838
2839
** network_condition_kwargs
2839
2840
) -> Float ['b m 3' ] | Float ['ts b m 3' ]:
2840
2841
@@ -6770,6 +6771,7 @@ def forward(
6770
6771
num_recycling_steps : int = 1 ,
6771
6772
diffusion_add_bond_loss : bool = False ,
6772
6773
diffusion_add_smooth_lddt_loss : bool = False ,
6774
+ diffusion_verifier : Module | None = None ,
6773
6775
distogram_atom_indices : Int ['b n' ] | None = None ,
6774
6776
molecule_atom_indices : Int ['b n' ] | None = None , # the 'token centre atoms' mentioned in the paper, unsure where it is used in the architecture
6775
6777
num_sample_steps : int | None = None ,
@@ -7187,7 +7189,8 @@ def forward(
7187
7189
pairwise_trunk = pairwise ,
7188
7190
pairwise_rel_pos_feats = relative_position_encoding ,
7189
7191
molecule_atom_lens = molecule_atom_lens ,
7190
- return_all_timesteps = return_all_diffused_atom_pos
7192
+ return_all_timesteps = return_all_diffused_atom_pos ,
7193
+ verifier = diffusion_verifier
7191
7194
)
7192
7195
7193
7196
if exists (atom_mask ):
0 commit comments