Skip to content

Commit efb7b9b

Browse files
committed
Subsample via mask
1 parent 5028e71 commit efb7b9b

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

chai_lab/chai1.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from chai_lab.data.dataset.msas.colabfold import generate_colabfold_msas
3333
from chai_lab.data.dataset.msas.load import get_msa_contexts
3434
from chai_lab.data.dataset.msas.msa_context import MSAContext
35+
from chai_lab.data.dataset.msas.utils import subsample_msa_rows
3536
from chai_lab.data.dataset.structure.all_atom_structure_context import (
3637
AllAtomStructureContext,
3738
)
@@ -441,6 +442,7 @@ def run_inference(
441442
msa_directory: Path | None = None,
442443
constraint_path: Path | None = None,
443444
# expose some params for easy tweaking
445+
recycle_msa_subsample: int = 0,
444446
num_trunk_recycles: int = 3,
445447
num_diffn_timesteps: int = 200,
446448
num_diffn_samples: int = 5,
@@ -472,6 +474,7 @@ def run_inference(
472474
num_trunk_recycles=num_trunk_recycles,
473475
num_diffn_timesteps=num_diffn_timesteps,
474476
num_diffn_samples=num_diffn_samples,
477+
recycle_msa_subsample=recycle_msa_subsample,
475478
seed=seed,
476479
device=torch_device,
477480
low_memory=low_memory,
@@ -488,6 +491,7 @@ def run_folding_on_context(
488491
*,
489492
output_dir: Path,
490493
# expose some params for easy tweaking
494+
recycle_msa_subsample: int = 0,
491495
num_trunk_recycles: int = 3,
492496
num_diffn_timesteps: int = 200,
493497
# all diffusion samples come from the same trunk
@@ -647,7 +651,7 @@ def run_folding_on_context(
647651
token_single_trunk_repr=token_single_trunk_repr, # recycled
648652
token_pair_trunk_repr=token_pair_trunk_repr, # recycled
649653
msa_input_feats=msa_input_feats,
650-
msa_mask=msa_mask,
654+
msa_mask=subsample_msa_rows(msa_mask, select_n_rows=recycle_msa_subsample),
651655
template_input_feats=template_input_feats,
652656
template_input_masks=template_input_masks,
653657
token_single_mask=token_single_mask,
chai_lab/data/dataset/msas/utils.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright (c) 2024 Chai Discovery, Inc.
2+
# Licensed under the Apache License, Version 2.0.
3+
# See the LICENSE file for details.
4+
5+
import torch
6+
from einops import rearrange, repeat
7+
from torch import Tensor
8+
9+
from chai_lab.utils.typing import Bool
10+
11+
12+
def subsample_msa_rows(
    mask: Bool[Tensor, "1 depth tokens"],
    select_n_rows: int = 4096,
    generator: torch.Generator | None = None,
) -> Bool[Tensor, "1 depth tokens"]:
    """Restrict an MSA mask to a random subset of its non-empty rows.

    Args:
        mask: Boolean MSA mask with a leading singleton batch dimension.
        select_n_rows: Number of non-empty rows to keep. Values <= 0 disable
            subsampling.
        generator: Optional RNG forwarded to ``torch.randperm`` so the
            selection can be made reproducible.

    Returns:
        The input mask itself if ``select_n_rows <= 0`` or if the number of
        rows containing at least one True entry is <= ``select_n_rows``.
        Otherwise a new mask of the same shape in which ``select_n_rows``
        randomly chosen non-empty rows keep their original values and every
        other row is False.
    """
    # Subsampling disabled: return the input untouched without inspecting it.
    if select_n_rows <= 0:
        return mask

    # A row is a candidate only if it is not fully masked out.
    nonnull_rows_mask = rearrange(mask.any(dim=-1), "1 d -> d")
    (nonnull_row_indices,) = torch.where(nonnull_rows_mask)

    # Bind the count in a plain statement (not inside an ``assert``) so the
    # function still works when asserts are stripped with ``python -O``.
    n_nonnull = nonnull_row_indices.numel()
    if n_nonnull <= select_n_rows:
        # Already at or below the requested depth; nothing to drop.
        return mask

    # Pick select_n_rows distinct candidates uniformly at random.
    permuted = torch.randperm(n_nonnull, device=mask.device, generator=generator)
    selected_row_indices = nonnull_row_indices[permuted[:select_n_rows]]

    # Build a per-row keep mask and broadcast it over batch and token dims.
    selection_mask = torch.zeros_like(nonnull_rows_mask)
    selection_mask[selected_row_indices] = True
    return mask & repeat(selection_mask, "d -> 1 d 1")

0 commit comments

Comments
 (0)