diff --git a/README.md b/README.md
index 6f4d34fc..2bfa51dd 100644
--- a/README.md
+++ b/README.md
@@ -88,7 +88,16 @@ For user convenience, we also support automatic MSA generation via the ColabFold
How can I customize the inputs to the model further?
-For more advanced use cases, we also expose the `chai_lab.chai1.run_folding_on_context`, which allows users to construct an `AllAtomFeatureContext` manually. This allows users to specify their own templates, MSAs, embeddings, and constraints, including support for specifying covalent bonds (for example, for specifying branched ligands). We currently provide examples of how to construct an embeddings context, an MSA context, restraint contexts, and covalent bonds. We will be releasing helper methods to build template contexts soon.
+For more advanced use cases, we also expose the `chai_lab.chai1.run_folding_on_context`, which allows users to construct an `AllAtomFeatureContext` manually. This allows users to specify their own templates, MSAs, embeddings, and constraints, including support for specifying covalent bonds (for example, for specifying branched ligands). We currently provide examples of how to construct an embeddings context, an MSA context, template contexts, restraint contexts, and covalent bonds.
+
+
+
+
+
+How can I provide custom templates to Chai-1?
+
+
+Templates are loaded in two steps - (1) a `m8` file is read, providing a table of template hits to load (2) we load each hit by downloading the corresponding identifier from RCSB and parsing the corresponding chain. You can provide your own `m8` file to specify template hits of your choice, and you can also place structure cif files in the directory specified by the environment variable `CHAI_TEMPLATE_CIF_FOLDER` to specify custom (non-RCSB) structures corresponding to each identifier in the `m8` file. Note that the template loading code expects cif files to be named as `$CHAI_TEMPLATE_CIF_FOLDER/identifier.cif.gz` where `identifier` matches that provided in the `m8` file.
diff --git a/chai_lab/chai1.py b/chai_lab/chai1.py
index c0a821c7..27d44b94 100644
--- a/chai_lab/chai1.py
+++ b/chai_lab/chai1.py
@@ -437,7 +437,6 @@ def make_all_atom_feature_context(
chains=chains,
use_sequence_hash_for_lookup=use_templates_server,
template_hits_m8=templates_path,
- template_cif_cache_folder=output_dir / "templates",
)
# Load ESM embeddings
diff --git a/chai_lab/data/dataset/templates/context.py b/chai_lab/data/dataset/templates/context.py
index 4e93d166..3430ee7f 100644
--- a/chai_lab/data/dataset/templates/context.py
+++ b/chai_lab/data/dataset/templates/context.py
@@ -3,6 +3,7 @@
# See the LICENSE file for details.
import logging
+import os
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Iterator
@@ -27,10 +28,15 @@
from chai_lab.data.parsing.templates.template_hit import TemplateHit
from chai_lab.data.sources.rdkit import RefConformerGenerator
from chai_lab.utils.defaults import default
+from chai_lab.utils.paths import downloads_path
from chai_lab.utils.typing import Bool, Float, Int, typecheck
logger = logging.getLogger(__name__)
+TEMPLATE_CIF_FOLDER = Path(
+ os.environ.get("CHAI_TEMPLATE_CIF_FOLDER", downloads_path / "template_cifs")
+)
+
@typecheck
@dataclass(frozen=True)
@@ -329,7 +335,7 @@ def get_template_context(
chains: list[Chain],
template_hits_m8: Path,
use_sequence_hash_for_lookup: bool = False,
- template_cif_cache_folder: Path | None = None,
+ template_cif_cache_folder: Path = TEMPLATE_CIF_FOLDER,
) -> TemplateContext:
"""
For each example, loads templates for cropped chain, collate the templates.
diff --git a/chai_lab/data/io/rcsb.py b/chai_lab/data/io/rcsb.py
index 3db61664..2e53d5fa 100644
--- a/chai_lab/data/io/rcsb.py
+++ b/chai_lab/data/io/rcsb.py
@@ -7,7 +7,10 @@
def download_cif_file(pdb_id: str, directory: Path) -> Path:
- """Download the cif file for the given PDB ID from RCSB into the directory."""
+ """Download the cif file for the given PDB ID from RCSB into the directory.
+
+ No-op if the directory/pdb_id.cif.gz already exists.
+ """
outfile = directory / f"{pdb_id}.cif.gz"
source_url = f"https://files.rcsb.org/download/{pdb_id}.cif.gz"
download_if_not_exists(source_url, outfile)