Skip to content
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
e78a40b
Add some helper functions
wukevin Feb 6, 2025
0f6b236
Add index_select helper function
wukevin Feb 6, 2025
bbdb2c2
fasta helpers
wukevin Feb 6, 2025
be4ee26
Check for kalign binary
wukevin Feb 7, 2025
f2d6c5c
Parse m8 template hits
wukevin Feb 7, 2025
8be4714
Add TemplateHit object
wukevin Feb 7, 2025
f52613a
Remove complexity
wukevin Feb 7, 2025
865d3de
logging
wukevin Feb 7, 2025
5df3f79
logging
wukevin Feb 7, 2025
1b62b53
Add example
wukevin Feb 7, 2025
b03574c
Add type
wukevin Feb 7, 2025
d5a2d8a
Add type
wukevin Feb 7, 2025
1306c00
Fill out some implementations
wukevin Feb 7, 2025
d208370
Fixes
wukevin Feb 7, 2025
117a955
Misc.
wukevin Feb 7, 2025
e1ce093
Enable MSA for templates example
wukevin Feb 7, 2025
fee1d7d
Load template hits
wukevin Feb 7, 2025
0ebea02
add rigid
wukevin Feb 7, 2025
8f52888
Tokenization
wukevin Feb 7, 2025
941480d
Update flags
wukevin Feb 7, 2025
16b5915
Create a folder for templates
wukevin Feb 8, 2025
07561f2
Logging
wukevin Feb 8, 2025
f357ec2
Fixes
wukevin Feb 8, 2025
01b1f2f
Remove debugging call
wukevin Feb 8, 2025
165e0db
Move check for kalign
wukevin Feb 8, 2025
618854d
Refactor
wukevin Feb 8, 2025
8c97ec8
Ensure correct types
wukevin Feb 8, 2025
8a62b3d
Support for querying server for templates
wukevin Feb 10, 2025
7056f46
Merge branch 'main' into kevin/templates
wukevin Feb 12, 2025
0bb67dd
Update comments
wukevin Feb 12, 2025
6009c33
Update script entry command
wukevin Feb 12, 2025
d12293f
Relax ruff version
wukevin Feb 12, 2025
7fc5e46
Fix typo
jackdent Feb 12, 2025
ef35618
Add kalign to dockerfile
wukevin Feb 12, 2025
316b935
Add note for kalign
wukevin Feb 12, 2025
79d0bba
Revert "Relax ruff version"
wukevin Feb 12, 2025
fc436e6
Revert "Update script entry command"
wukevin Feb 12, 2025
28bff14
Update README
wukevin Feb 12, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 38 additions & 6 deletions chai_lab/chai1.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@
from chai_lab.data.dataset.structure.bond_utils import (
get_atom_covalent_bond_pairs_from_constraints,
)
from chai_lab.data.dataset.templates.context import TemplateContext
from chai_lab.data.dataset.templates.context import (
TemplateContext,
get_template_context,
)
from chai_lab.data.features.feature_factory import FeatureFactory
from chai_lab.data.features.feature_type import FeatureType
from chai_lab.data.features.generators.atom_element import AtomElementOneHot
Expand Down Expand Up @@ -326,11 +329,16 @@ def make_all_atom_feature_context(
msa_server_url: str = "https://api.colabfold.com",
msa_directory: Path | None = None,
constraint_path: Path | None = None,
use_templates_server: bool = False,
templates_path: Path | None = None,
esm_device: torch.device = torch.device("cpu"),
):
assert not (
use_msa_server and msa_directory
), "Cannot specify both MSA server and directory"
assert not (
use_templates_server and templates_path
), "Cannot specify both templates server and path"

# Prepare inputs
assert fasta_file.exists(), fasta_file
Expand Down Expand Up @@ -366,8 +374,13 @@ def make_all_atom_feature_context(
generate_colabfold_msas(
protein_seqs=protein_sequences,
msa_dir=msa_dir,
search_templates=use_templates_server,
msa_server_url=msa_server_url,
)
if use_templates_server: # Override templates path with server path
assert templates_path is None
templates_path = msa_dir / "all_chain_templates.m8"
assert templates_path.is_file()
msa_context, msa_profile_context = get_msa_contexts(
chains, msa_directory=msa_dir
)
Expand All @@ -388,10 +401,24 @@ def make_all_atom_feature_context(
), f"Discrepant tokens in input and MSA: {merged_context.num_tokens} != {msa_context.num_tokens}"

# Load templates
template_context = TemplateContext.empty(
n_tokens=n_actual_tokens,
n_templates=MAX_NUM_TEMPLATES,
)
if templates_path is None:
assert (
not use_templates_server
), "Templates path should never be none when querying server for templates"
template_context = TemplateContext.empty(
n_tokens=n_actual_tokens,
n_templates=MAX_NUM_TEMPLATES,
)
else:
# NOTE templates m8 file should contain hits with query name matching chain entity_names
# or the hash of the chain sequence. When we query the server, we use the hash of the
# sequence to identify each hit.
template_context = get_template_context(
chains=chains,
use_sequence_hash_for_lookup=use_templates_server,
template_hits_m8=templates_path,
template_cif_cache_folder=output_dir / "templates",
)

# Load ESM embeddings
if use_esm_embeddings:
Expand Down Expand Up @@ -456,12 +483,15 @@ def run_inference(
fasta_file: Path,
*,
output_dir: Path,
# Configuration for ESM, MSA, constraints, and templates
use_esm_embeddings: bool = True,
use_msa_server: bool = False,
msa_server_url: str = "https://api.colabfold.com",
msa_directory: Path | None = None,
constraint_path: Path | None = None,
# expose some params for easy tweaking
use_templates_server: bool = False,
template_hits_path: Path | None = None,
# Parameters controlling how we do inference
recycle_msa_subsample: int = 0,
num_trunk_recycles: int = 3,
num_diffn_timesteps: int = 200,
Expand All @@ -487,6 +517,8 @@ def run_inference(
msa_server_url=msa_server_url,
msa_directory=msa_directory,
constraint_path=constraint_path,
use_templates_server=use_templates_server,
templates_path=template_hits_path,
esm_device=torch_device,
)

Expand Down
99 changes: 30 additions & 69 deletions chai_lab/data/dataset/msas/colabfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from chai_lab.data.parsing.fasta import Fasta, read_fasta
from chai_lab.data.parsing.msas.aligned_pqt import expected_basename, hash_sequence
from chai_lab.data.parsing.msas.data_source import MSADataSource
from chai_lab.data.parsing.templates.m8 import parse_m8_file

logger = logging.getLogger(__name__)

Expand All @@ -29,6 +30,7 @@

# N.B. this function (and this function only) is copied from https://github.com/sokrypton/ColabFold
# and follows the license in that repository
# We have made modifications to how templates are returned from this function.
@typing.no_type_check # Original ColabFold code was not well typed
def _run_mmseqs2(
x,
Expand All @@ -41,8 +43,8 @@ def _run_mmseqs2(
pairing_strategy="greedy",
host_url="https://api.colabfold.com",
user_agent: str = "",
) -> list[str] | tuple[list[str], list[str]]:
"""Return a block of a3m lines for each of the input sequences in x."""
) -> tuple[list[str], str | None]:
"""Return a block of a3m lines and optionally template hits for each of the input sequences in x."""
submission_endpoint = "ticket/pair" if use_pairing else "ticket/msa"

headers = {}
Expand Down Expand Up @@ -257,59 +259,13 @@ def download(ID, path):
tar_gz.extractall(path)

# templates
template_path: str | None = None
if use_templates:
templates = {}
# print("seq\tpdb\tcid\tevalue")
for line in open(f"{path}/pdb70.m8", "r"):
p = line.rstrip().split()
M, pdb, _, _ = p[0], p[1], p[2], p[10]
M = int(M)
if M not in templates:
templates[M] = []
templates[M].append(pdb)
# if len(templates[M]) <= 20:
# print(f"{int(M)-N}\t{pdb}\t{qid}\t{e_value}")

template_paths = {}
for k, TMPL in templates.items():
TMPL_PATH = f"{prefix}_{mode}/templates_{k}"
if not os.path.isdir(TMPL_PATH):
os.mkdir(TMPL_PATH)
TMPL_LINE = ",".join(TMPL[:20])
response = None
while True:
error_count = 0
try:
# https://requests.readthedocs.io/en/latest/user/advanced/#advanced
# "good practice to set connect timeouts to slightly larger than a multiple of 3"
response = requests.get(
f"{host_url}/template/{TMPL_LINE}",
stream=True,
timeout=6.02,
headers=headers,
)
except requests.exceptions.Timeout:
logger.warning(
"Timeout while submitting to template server. Retrying..."
)
continue
except Exception as e:
error_count += 1
logger.warning(
f"Error while fetching result from template server. Retrying... ({error_count}/5)"
)
logger.warning(f"Error: {e}")
time.sleep(5)
if error_count > 5:
raise
continue
break
with tarfile.open(fileobj=response.raw, mode="r|gz") as tar:
tar.extractall(path=TMPL_PATH)
os.symlink("pdb70_a3m.ffindex", f"{TMPL_PATH}/pdb70_cs219.ffindex")
with open(f"{TMPL_PATH}/pdb70_cs219.ffdata", "w") as f:
f.write("")
template_paths[k] = TMPL_PATH
# NOTE this section has been significantly reduced to enable Chai-1 to take m8 files
# as a common input format, while also reducing how much we ping the server.
template_path = os.path.join(path, "pdb70.m8")
assert os.path.isfile(template_path)

# gather a3m lines
a3m_lines = {}
Expand All @@ -327,21 +283,8 @@ def download(ID, path):
a3m_lines[M] = []
a3m_lines[M].append(line)

# return results

a3m_lines = ["".join(a3m_lines[n]) for n in Ms]

if use_templates:
template_paths_ = []
for n in Ms:
if n not in template_paths:
template_paths_.append(None)
# print(f"{n-N}\tno_templates_found")
else:
template_paths_.append(template_paths[n])
template_paths = template_paths_

return (a3m_lines, template_paths) if use_templates else a3m_lines
return a3m_lines, template_path


def _is_padding_msa_row(sequence: str) -> bool:
Expand All @@ -354,6 +297,7 @@ def generate_colabfold_msas(
protein_seqs: list[str],
msa_dir: Path,
msa_server_url: str,
search_templates: bool = False,
write_a3m_to_msa_dir: bool = False, # Useful for manual inspection + debugging
):
"""
Expand Down Expand Up @@ -398,10 +342,11 @@ def generate_colabfold_msas(
# as the i-th index of the sequence so long as it isn't a padding sequence (all -)
paired_msas: list[str]
if len(protein_seqs) > 1:
paired_msas = _run_mmseqs2(
paired_msas, _ = _run_mmseqs2(
protein_seqs,
mmseqs_paired_dir,
use_pairing=True,
use_templates=False, # No templates when running paired search
host_url=msa_server_url,
user_agent=user_agent,
)
Expand All @@ -411,13 +356,29 @@ def generate_colabfold_msas(

# MSAs without pairing logic attached; may include sequences not contained in the paired MSA
# Needs a second call as the colabfold server returns either paired or unpaired, not both
per_chain_msas = _run_mmseqs2(
per_chain_msas, template_hits_file = _run_mmseqs2(
protein_seqs,
mmseqs_dir,
use_pairing=False,
use_templates=search_templates,
host_url=msa_server_url,
user_agent=user_agent,
)
if search_templates:
assert template_hits_file is not None and os.path.isfile(template_hits_file)
all_templates = parse_m8_file(Path(template_hits_file))
# query IDs are 101, 102, ... from the server; remap IDs
query_map = {}
for orig_query_id, orig_seq in enumerate(protein_seqs, start=101):
h = hash_sequence(orig_seq)
query_map[orig_query_id] = h
all_templates["query_id"] = all_templates["query_id"].apply(query_map.get)
assert not pd.isnull(all_templates["query_id"]).any()

logger.info(f"Found {len(all_templates)} template hits")
all_templates.to_csv(
msa_dir / "all_chain_templates.m8", index=False, header=False, sep="\t"
)

# Process the MSAs into our internal format
for protein_seq, pair_msa, single_msa in zip(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,9 @@ def _tokenize_entity(
for residue in entity_data.residues
]

valid_residues = [x for x in tokenized_residues if x is not None]
valid_residues: list[TokenSpan] = [
x for x in tokenized_residues if x is not None
]
if len(valid_residues) == 0:
logger.warning(
f"Got no residues for entity {entity_data.entity_id} with residues {entity_data.residues}"
Expand Down
98 changes: 98 additions & 0 deletions chai_lab/data/dataset/structure/all_atom_structure_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,104 @@ def __post_init__(self):
def residue_names(self) -> list[str]:
return batch_tensorcode_to_string(self.token_residue_name)

@typecheck
def index_select(self, idxs: Int[Tensor, "n"]) -> "AllAtomStructureContext":
    """
    Selects a subset of the data in the context, reindexing the tokens and atoms in
    the new context (i.e. the new context will be indexed from 0).

    Parameters
    ----------
    idxs : Int[Tensor, "n"]
        The indices of the tokens to select.

    Returns
    -------
    AllAtomStructureContext
        A new context with the selected tokens and atoms.
    """
    # Every requested token index must be in range for this context.
    assert ((idxs >= 0) & (idxs < self.num_tokens)).all()

    # get atoms to keep: an atom survives iff the token it belongs to is in
    # idxs. torch.where over the any-reduced mask yields atom indices in
    # ascending (original) atom order.
    selected_atom_index = torch.where(
        (self.atom_token_index == idxs[..., None]).any(dim=0)
    )[0]

    # rebuild token index and atom-token index
    token_index = torch.arange(len(idxs))

    # For each kept atom: atom_token_index is the NEW token index (the row of
    # idxs that matched), selected_atom_idx is the OLD atom index.
    # NOTE(review): torch.where iterates the 2-D mask row-major, so
    # selected_atom_idx is ordered by idxs, while selected_atom_index above is
    # sorted ascending. The two orderings coincide only when idxs is sorted —
    # the code below mixes both, so callers presumably pass sorted idxs;
    # TODO confirm.
    atom_token_index, selected_atom_idx = torch.where(
        self.atom_token_index == idxs[..., None]
    )

    def _reselect_atom_indices(
        prior_atom_index: Int[Tensor, "n_tokens"],
    ) -> Int[Tensor, "n_tokens_new"]:
        # Remap old per-token atom indices into positions within the kept-atom
        # subset: flag the referenced atoms over the full atom range, restrict
        # the flags to the kept atoms, and return the surviving positions.
        mask = torch.zeros(self.num_atoms, dtype=torch.bool)
        mask[prior_atom_index] = True
        selected_mask = mask[selected_atom_idx]
        return torch.where(selected_mask)[0]

    token_centre_atom_index = _reselect_atom_indices(self.token_centre_atom_index)
    token_ref_atom_index = _reselect_atom_indices(self.token_ref_atom_index)

    # Covalent bonds are kept only when BOTH endpoint atoms survive, and their
    # endpoints are re-expressed in the new atom numbering via a dense
    # atom-pair adjacency matrix.
    atom_covalent_bond_indices = None
    if self.atom_covalent_bond_indices is not None:
        left_idx, right_idx = self.atom_covalent_bond_indices
        atom_pairs = torch.zeros(self.num_atoms, self.num_atoms, dtype=torch.bool)
        atom_pairs[left_idx, right_idx] = True
        selected_atom_pairs = atom_pairs[selected_atom_idx][:, selected_atom_idx]
        new_left, new_right = torch.where(selected_atom_pairs)
        atom_covalent_bond_indices = new_left, new_right

    # Backbone frames reference three atoms per token; remap each of the three
    # index columns independently, then restack.
    token_backbone_frame_atom_index = torch.stack(
        [
            _reselect_atom_indices(x)
            for x in torch.unbind(self.token_backbone_frame_index, dim=-1)
        ],
        dim=-1,
    )

    return AllAtomStructureContext(
        # token-level
        token_residue_type=self.token_residue_type[idxs],
        token_residue_index=self.token_residue_index[idxs],
        token_index=token_index,
        token_centre_atom_index=token_centre_atom_index,
        token_ref_atom_index=token_ref_atom_index,
        token_exists_mask=self.token_exists_mask[idxs],
        token_backbone_frame_mask=self.token_backbone_frame_mask[idxs],
        token_backbone_frame_index=token_backbone_frame_atom_index,
        token_asym_id=self.token_asym_id[idxs],
        token_entity_id=self.token_entity_id[idxs],
        token_sym_id=self.token_sym_id[idxs],
        token_entity_type=self.token_entity_type[idxs],
        token_residue_name=self.token_residue_name[idxs],
        token_b_factor_or_plddt=self.token_b_factor_or_plddt[idxs],
        # atom-level
        atom_token_index=atom_token_index,
        atom_within_token_index=self.atom_within_token_index[selected_atom_index],
        atom_ref_pos=self.atom_ref_pos[selected_atom_index],
        atom_ref_mask=self.atom_ref_mask[selected_atom_index],
        atom_ref_element=self.atom_ref_element[selected_atom_index],
        atom_ref_charge=self.atom_ref_charge[selected_atom_index],
        atom_ref_name=[self.atom_ref_name[i] for i in selected_atom_index],
        atom_ref_name_chars=self.atom_ref_name_chars[selected_atom_index],
        atom_ref_space_uid=self.atom_ref_space_uid[selected_atom_index],
        atom_is_not_padding_mask=self.atom_is_not_padding_mask[selected_atom_index],
        # supervision-only
        atom_gt_coords=self.atom_gt_coords[selected_atom_index],
        atom_exists_mask=self.atom_exists_mask[selected_atom_index],
        # structure-level
        pdb_id=self.pdb_id[idxs],
        source_pdb_chain_id=self.source_pdb_chain_id[idxs],
        subchain_id=self.subchain_id[idxs],
        resolution=self.resolution,
        is_distillation=self.is_distillation,
        symmetries=self.symmetries[selected_atom_index],
        atom_covalent_bond_indices=atom_covalent_bond_indices,
    )

def report_bonds(self) -> None:
"""Log information about covalent bonds."""
for i, (atom_a, atom_b) in enumerate(zip(*self.atom_covalent_bond_indices)):
Expand Down
Loading