Skip to content

Commit a36b554

Browse files
committed
verifier needed for inference time scaling
1 parent 8617eaf commit a36b554

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

alphafold3_pytorch/alphafold3.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -2835,6 +2835,7 @@ def sample(
28352835
use_tqdm_pbar = True,
28362836
tqdm_pbar_title = 'sampling time step',
28372837
return_all_timesteps = False,
2838+
verifier: Module | None = None,
28382839
**network_condition_kwargs
28392840
) -> Float['b m 3'] | Float['ts b m 3']:
28402841

@@ -6770,6 +6771,7 @@ def forward(
67706771
num_recycling_steps: int = 1,
67716772
diffusion_add_bond_loss: bool = False,
67726773
diffusion_add_smooth_lddt_loss: bool = False,
6774+
diffusion_verifier: Module | None = None,
67736775
distogram_atom_indices: Int['b n'] | None = None,
67746776
molecule_atom_indices: Int['b n'] | None = None, # the 'token centre atoms' mentioned in the paper, unsure where it is used in the architecture
67756777
num_sample_steps: int | None = None,
@@ -7187,7 +7189,8 @@ def forward(
71877189
pairwise_trunk = pairwise,
71887190
pairwise_rel_pos_feats = relative_position_encoding,
71897191
molecule_atom_lens = molecule_atom_lens,
7190-
return_all_timesteps = return_all_diffused_atom_pos
7192+
return_all_timesteps = return_all_diffused_atom_pos,
7193+
verifier = diffusion_verifier
71917194
)
71927195

71937196
if exists(atom_mask):

0 commit comments

Comments
 (0)