From 61af92e18e9f831c06447531ad44f306c4ed1c51 Mon Sep 17 00:00:00 2001 From: tintinrevient Date: Tue, 29 Jul 2025 17:02:29 +0200 Subject: [PATCH 1/3] Centralise manifest in pg2-benchmark --- models/esm/Dockerfile | 12 +++- models/esm/{esm.toml => manifest.toml} | 4 +- models/esm/pyproject.toml | 2 + models/esm/src/pg2_model_esm/__main__.py | 14 ++--- models/pls/Dockerfile | 12 +++- models/pls/{pls.toml => manifest.toml} | 0 models/pls/pyproject.toml | 2 + models/pls/src/pg2_model_pls/__main__.py | 2 +- models/pls/src/pg2_model_pls/manifest.py | 13 ----- models/pls/src/pg2_model_pls/utils.py | 2 +- models/pls/uv.lock | 58 +++++++++++++++++-- .../pg2_benchmark}/manifest.py | 7 +-- supervised/dvc.lock | 48 ++++++++------- supervised/dvc.yaml | 8 +-- zero_shot/dvc.lock | 17 +++--- zero_shot/dvc.yaml | 8 +-- 16 files changed, 128 insertions(+), 81 deletions(-) rename models/esm/{esm.toml => manifest.toml} (100%) rename models/pls/{pls.toml => manifest.toml} (100%) delete mode 100644 models/pls/src/pg2_model_pls/manifest.py rename {models/esm/src/pg2_model_esm => src/pg2_benchmark}/manifest.py (76%) diff --git a/models/esm/Dockerfile b/models/esm/Dockerfile index ff64d733..2084d6c0 100644 --- a/models/esm/Dockerfile +++ b/models/esm/Dockerfile @@ -12,8 +12,14 @@ COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/ WORKDIR /opt/program -COPY ./README.md ./README.md -COPY ./pyproject.toml ./pyproject.toml +# Copy pg2-benchmark source +COPY ./README.md ./pg2-benchmark/README.md +COPY ./pyproject.toml ./pg2-benchmark/pyproject.toml +COPY ./src ./pg2-benchmark/src + +# Copy pg2-model-esm dependencies +COPY ./models/esm/README.md ./README.md +COPY ./models/esm/pyproject.toml ./pyproject.toml # TODO: After pg2-dataset is public, below can be removed: # Currently, it is required to git clone pg2-dataset, which is a private repo. @@ -28,6 +34,6 @@ RUN --mount=type=secret,id=git_auth \ RUN uv sync --no-cache -COPY ./src ./src +COPY ./models/esm/src ./src ENTRYPOINT ["uv", "run", "pg2-model"] diff --git a/models/esm/esm.toml b/models/esm/manifest.toml similarity index 100% rename from models/esm/esm.toml rename to models/esm/manifest.toml index e2d4b5a6..68b62572 100644 --- a/models/esm/esm.toml +++ b/models/esm/manifest.toml @@ -1,6 +1,6 @@ name = "esm" -location = "esm2_t30_150M_UR50D" -scoring_strategy = "wt-marginals" [hyper_params] +location = "esm2_t30_150M_UR50D" +scoring_strategy = "wt-marginals" offset_idx = 24 \ No newline at end of file diff --git a/models/esm/pyproject.toml b/models/esm/pyproject.toml index 69b3d5ba..f4e73231 100644 --- a/models/esm/pyproject.toml +++ b/models/esm/pyproject.toml @@ -7,6 +7,7 @@ requires-python = ">=3.12" dependencies = [ "typer~=0.16.0", "pg2-dataset[biopython]", + "pg2-benchmark", "fair-esm~=2.0.0", "torch~=2.7.1", "bio~=1.8.0", @@ -23,6 +24,7 @@ build-backend = "hatchling.build" [tool.uv.sources] pg2-dataset = { git = "https://github.com/ProteinGym2/pg2-dataset.git", rev = "58c327e13bade1effe1312eb2b8d5445016a5a8f" } +pg2-benchmark = { path = "./pg2-benchmark", editable = true } [tool.hatch.build.targets.wheel] packages = ["src/pg2_model_esm"] diff --git a/models/esm/src/pg2_model_esm/__main__.py b/models/esm/src/pg2_model_esm/__main__.py index 4b108187..54215284 100644 --- a/models/esm/src/pg2_model_esm/__main__.py +++ b/models/esm/src/pg2_model_esm/__main__.py @@ -5,7 +5,7 @@ from tqdm import tqdm from esm import pretrained from pg2_model_esm.utils import compute_pppl, label_row -from pg2_model_esm.manifest import Manifest as ModelManifest +from pg2_benchmark.manifest import Manifest as ModelManifest app = typer.Typer( @@ -42,15 +42,13 @@ def train( model_manifest = ModelManifest.from_path(model_toml_file) model_name = model_manifest.name - location = model_manifest.location - scoring_strategy = model_manifest.scoring_strategy hyper_params = model_manifest.hyper_params - model, alphabet = pretrained.load_model_and_alphabet(location) + model, alphabet = pretrained.load_model_and_alphabet(hyper_params["location"]) model.eval() console.print( - f"Loaded the model from {location} with scoring strategy {scoring_strategy}." + f"Loaded the model from {hyper_params['location']} with scoring strategy {hyper_params['scoring_strategy']}." ) if torch.cuda.is_available() and not nogpu: @@ -65,7 +63,7 @@ def train( batch_labels, batch_strs, batch_tokens = batch_converter(data) - match scoring_strategy: + match hyper_params["scoring_strategy"]: case "wt-marginals": with torch.no_grad(): token_probs = torch.log_softmax(model(batch_tokens)["logits"], dim=-1) @@ -123,7 +121,9 @@ def train( ) case _: - err_console.print(f"Error: Invalid scoring strategy: {scoring_strategy}") + err_console.print( + f"Error: Invalid scoring strategy: {hyper_params['scoring_strategy']}" + ) df.rename(columns={targets[0]: "test"}, inplace=True) df.to_csv(f"/output/{dataset_name}_{model_name}.csv", index=False) diff --git a/models/pls/Dockerfile b/models/pls/Dockerfile index ff64d733..95887a8f 100644 --- a/models/pls/Dockerfile +++ b/models/pls/Dockerfile @@ -12,8 +12,14 @@ COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/ WORKDIR /opt/program -COPY ./README.md ./README.md -COPY ./pyproject.toml ./pyproject.toml +# Copy pg2-benchmark source +COPY ./README.md ./pg2-benchmark/README.md +COPY ./pyproject.toml ./pg2-benchmark/pyproject.toml +COPY ./src ./pg2-benchmark/src + +# Copy pg2-model-pls dependencies +COPY ./models/pls/README.md ./README.md +COPY ./models/pls/pyproject.toml ./pyproject.toml # TODO: After pg2-dataset is public, below can be removed: # Currently, it is required to git clone pg2-dataset, which is a private repo. @@ -28,6 +34,6 @@ RUN --mount=type=secret,id=git_auth \ RUN uv sync --no-cache -COPY ./src ./src +COPY ./models/pls/src ./src ENTRYPOINT ["uv", "run", "pg2-model"] diff --git a/models/pls/pls.toml b/models/pls/manifest.toml similarity index 100% rename from models/pls/pls.toml rename to models/pls/manifest.toml diff --git a/models/pls/pyproject.toml b/models/pls/pyproject.toml index 7123b343..19282a03 100644 --- a/models/pls/pyproject.toml +++ b/models/pls/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "numpy~=2.2.5", "scikit-learn~=1.7.0", "pg2-dataset[biopython]", + "pg2-benchmark", "polars[pyarrow]~=1.31.0", ] @@ -21,6 +22,7 @@ build-backend = "hatchling.build" [tool.uv.sources] pg2-dataset = { git = "https://github.com/ProteinGym2/pg2-dataset.git", rev = "58c327e13bade1effe1312eb2b8d5445016a5a8f" } +pg2-benchmark = { path = "./pg2-benchmark", editable = true } [tool.hatch.build.targets.wheel] packages = ["src/pg2_model_pls"] diff --git a/models/pls/src/pg2_model_pls/__main__.py b/models/pls/src/pg2_model_pls/__main__.py index 624ab12d..1f7bf8b2 100644 --- a/models/pls/src/pg2_model_pls/__main__.py +++ b/models/pls/src/pg2_model_pls/__main__.py @@ -2,7 +2,7 @@ from rich.console import Console from pg2_dataset.dataset import Manifest from pg2_dataset.splits.abstract_split_strategy import TrainTestValid -from pg2_model_pls.manifest import Manifest as ModelManifest +from pg2_benchmark.manifest import Manifest as ModelManifest from pg2_model_pls.utils import load_x_and_y, train_model, predict_model import typer diff --git a/models/pls/src/pg2_model_pls/manifest.py b/models/pls/src/pg2_model_pls/manifest.py deleted file mode 100644 index fd299c3b..00000000 --- a/models/pls/src/pg2_model_pls/manifest.py +++ /dev/null @@ -1,13 +0,0 @@ -from pydantic import BaseModel, Field -from pathlib import Path -from typing import Self, Any -import toml - - -class Manifest(BaseModel): - name: str = "" - hyper_params: dict[str, Any] = Field(default_factory=dict) - - @classmethod - def from_path(cls, toml_file: Path) -> Self: - return cls.model_validate(toml.load(toml_file)) diff --git a/models/pls/src/pg2_model_pls/utils.py b/models/pls/src/pg2_model_pls/utils.py index 93a8e8e3..d65e2d57 100644 --- a/models/pls/src/pg2_model_pls/utils.py +++ b/models/pls/src/pg2_model_pls/utils.py @@ -5,7 +5,7 @@ from pg2_dataset.dataset import Manifest from pg2_dataset.backends.assays import SPLIT_STRATEGY_MAPPING from pg2_dataset.splits.abstract_split_strategy import TrainTestValid -from pg2_model_pls.manifest import Manifest as ModelManifest +from pg2_benchmark.manifest import Manifest as ModelManifest import logging logger = logging.getLogger(__name__) diff --git a/models/pls/uv.lock b/models/pls/uv.lock index 303463f9..6d9d100e 100644 --- a/models/pls/uv.lock +++ b/models/pls/uv.lock @@ -158,6 +158,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3b/00/2344469e2084fb287c2e0b57b72910309874c3245463acd6cf5e3db69324/appdirs-1.4.4-py2.py3-none-any.whl", hash = "sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128", size = 9566, upload_time = "2020-05-11T07:59:49.499Z" }, ] +[[package]] +name = "art" +version = "6.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d4/7d/7d80509bbd19fb747edef94ba487dbadd2747944774ccc0528ad0d005a36/art-6.5.tar.gz", hash = "sha256:a98d77b42c278697ec6cf4b5bdcdfd997f6b2425332da078d4e31e31377d1844", size = 672902, upload_time = "2025-04-12T17:02:20.279Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/67/29/57b06fdb3abdf52c621d3ca3caea735e2db4c8d48288ebd26af448e8e247/art-6.5-py3-none-any.whl", hash = "sha256:70706408144c45c666caab690627d5c74aea7b6c7ce8cc968408ddeef8d84afd", size = 610382, upload_time = "2025-04-12T17:02:21.97Z" }, +] + [[package]] name = "asyncssh" version = "2.21.0" @@ -1385,6 +1394,30 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191, upload_time = "2023-12-10T22:30:43.14Z" }, ] +[[package]] +name = "pg2-benchmark" +version = "0.1.0" +source = { editable = "../../" } +dependencies = [ + { name = "dvc" }, + { name = "pg2-dataset", extra = ["biopython"] }, + { name = "pycm" }, + { name = "scipy" }, + { name = "typer" }, +] + +[package.metadata] +requires-dist = [ + { name = "dvc", specifier = ">=3.59.2" }, + { name = "pg2-dataset", extras = ["biopython"], git = "https://github.com/ProteinGym2/pg2-dataset.git?rev=58c327e13bade1effe1312eb2b8d5445016a5a8f" }, + { name = "pycm", specifier = ">=4.3" }, + { name = "scipy", specifier = ">=1.15.3" }, + { name = "typer", specifier = ">=0.15.2" }, +] + +[package.metadata.requires-dev] +dev = [{ name = "pre-commit", specifier = ">=4.2.0" }] + [[package]] name = "pg2-dataset" version = "0.1.0" @@ -1410,6 +1443,7 @@ version = "0.1.0" source = { editable = "." } dependencies = [ { name = "numpy" }, + { name = "pg2-benchmark" }, { name = "pg2-dataset", extra = ["biopython"] }, { name = "polars", extra = ["pyarrow"] }, { name = "scikit-learn" }, @@ -1423,15 +1457,16 @@ dev = [ [package.metadata] requires-dist = [ - { name = "numpy", specifier = "==2.2.5" }, + { name = "numpy", specifier = "~=2.2.5" }, + { name = "pg2-benchmark", editable = "../../" }, { name = "pg2-dataset", extras = ["biopython"], git = "https://github.com/ProteinGym2/pg2-dataset.git?rev=58c327e13bade1effe1312eb2b8d5445016a5a8f" }, - { name = "polars", extras = ["pyarrow"], specifier = "==1.31.0" }, - { name = "scikit-learn", specifier = "==1.7.0" }, - { name = "typer", specifier = "==0.16.0" }, + { name = "polars", extras = ["pyarrow"], specifier = "~=1.31.0" }, + { name = "scikit-learn", specifier = "~=1.7.0" }, + { name = "typer", specifier = "~=0.16.0" }, ] [package.metadata.requires-dev] -dev = [{ name = "pytest", specifier = "==8.4.1" }] +dev = [{ name = "pytest", specifier = "~=8.4.1" }] [[package]] name = "platformdirs" @@ -1636,6 +1671,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/47/8d/d529b5d697919ba8c11ad626e835d4039be708a35b0d22de83a269a6682c/pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a", size = 181259, upload_time = "2025-03-28T02:41:19.028Z" }, ] +[[package]] +name = "pycm" +version = "4.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "art" }, + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/81/fa/98cf4caac67e7574874201706ef9b6388cfe7480606b2d65b6bcf7edbe9b/pycm-4.3.tar.gz", hash = "sha256:63b45a34f716fbd9b169cbcfc2df78982d3c0d10231425faa79dcaf1974f2e1c", size = 918510, upload_time = "2025-04-04T14:07:52.971Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/24/bd/8fab2781cbbd809c06605d5256105d41ca209dec2ffffeed7d47ada5db1d/pycm-4.3-py3-none-any.whl", hash = "sha256:95ebfe0771f871413de6f245c1129bd136eff0028d340526dd7ec1abc67024f4", size = 70912, upload_time = "2025-04-04T14:07:54.976Z" }, +] + [[package]] name = "pycparser" version = "2.22" diff --git a/models/esm/src/pg2_model_esm/manifest.py b/src/pg2_benchmark/manifest.py similarity index 76% rename from models/esm/src/pg2_model_esm/manifest.py rename to src/pg2_benchmark/manifest.py index 75d6e9e0..60a6becf 100644 --- a/models/esm/src/pg2_model_esm/manifest.py +++ b/src/pg2_benchmark/manifest.py @@ -1,16 +1,15 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict from pathlib import Path from typing import Self, Any import toml class Manifest(BaseModel): + model_config = ConfigDict(extra="allow") + name: str = "" hyper_params: dict[str, Any] = Field(default_factory=dict) - location: str = "" - scoring_strategy: str = "" - @classmethod def from_path(cls, toml_file: Path) -> Self: return cls.model_validate(toml.load(toml_file)) diff --git a/supervised/dvc.lock b/supervised/dvc.lock index 9b478797..cfa5f306 100644 --- a/supervised/dvc.lock +++ b/supervised/dvc.lock @@ -178,70 +178,68 @@ stages: size: 1619 create_training_job@dataset0-model0: cmd: - - docker build --secret id=git_auth,src=git-auth.txt -t pls:latest ../models/pls - --no-cache + - docker build --no-cache --secret id=git_auth,src=git-auth.txt -f ../models/pls/Dockerfile + -t pls:latest .. - docker run --rm -v $(realpath ../datasets):/datasets -v $(realpath ../models):/models -v $(realpath output/local):/output pls:latest train --dataset-toml-file /datasets/dummy/charge_ladder.toml - --model-toml-file /models/pls/pls.toml + --model-toml-file /models/pls/manifest.toml - docker image prune -a -f deps: - path: ../datasets/dummy/charge_ladder.toml hash: md5 md5: bcd140fc3152f1fcc9906f1a0667846d size: 258 - - path: ../models/pls + - path: ../models/pls/Dockerfile hash: md5 - md5: 4e1daa5a34b8106a55af9bf0bff97004.dir - size: 545914025 - nfiles: 17728 - - path: ../models/pls/pls.toml + md5: 901a42e51cf4a14665099985ee9c30de + size: 1138 + - path: ../models/pls/manifest.toml hash: md5 md5: 751569f65bbbc4d796ddc02a70dce809 size: 185 outs: - path: output/local/charge_ladder_pls.csv hash: md5 - md5: 46aae60f9e2fb30d2246ac632fcbdb4a + md5: 2a2145f0345bf340c208c6dcc3bf2863 size: 13780 create_training_job@dataset1-model0: cmd: - - docker build --secret id=git_auth,src=git-auth.txt -t pls:latest ../models/pls - --no-cache + - docker build --no-cache --secret id=git_auth,src=git-auth.txt -f ../models/pls/Dockerfile + -t pls:latest .. - docker run --rm -v $(realpath ../datasets):/datasets -v $(realpath ../models):/models -v $(realpath output/local):/output pls:latest train --dataset-toml-file /datasets/neime/neime.toml - --model-toml-file /models/pls/pls.toml + --model-toml-file /models/pls/manifest.toml - docker image prune -a -f deps: - path: ../datasets/neime/neime.toml hash: md5 md5: 894acf89b5058d6e50c454ddc53dd326 size: 276 - - path: ../models/pls + - path: ../models/pls/Dockerfile hash: md5 - md5: 4e1daa5a34b8106a55af9bf0bff97004.dir - size: 545914025 - nfiles: 17728 - - path: ../models/pls/pls.toml + md5: 901a42e51cf4a14665099985ee9c30de + size: 1138 + - path: ../models/pls/manifest.toml hash: md5 md5: 751569f65bbbc4d796ddc02a70dce809 size: 185 outs: - path: output/local/neime_pls.csv hash: md5 - md5: 3cf0dc59920b1e5c9c3a8b1f4aadb281 - size: 35009 + md5: 4ac1900a997dcf92c86ff082c5525a29 + size: 35019 calculate_metric@dataset0-model0: cmd: uv run pg2-benchmark metric calc --output-path output/local/charge_ladder_pls.csv --metric-path metric/local/charge_ladder_pls.csv deps: - path: output/local/charge_ladder_pls.csv hash: md5 - md5: 46aae60f9e2fb30d2246ac632fcbdb4a + md5: 2a2145f0345bf340c208c6dcc3bf2863 size: 13780 outs: - path: metric/local/charge_ladder_pls.csv hash: md5 - md5: 636cda8ace453381282c4fb56deae0ca + md5: d2e9585da62ad366468746533536fa83 size: 1457 calculate_metric@dataset1-model0: cmd: uv run pg2-benchmark metric calc --output-path output/local/neime_pls.csv @@ -249,10 +247,10 @@ stages: deps: - path: output/local/neime_pls.csv hash: md5 - md5: 3cf0dc59920b1e5c9c3a8b1f4aadb281 - size: 35009 + md5: 4ac1900a997dcf92c86ff082c5525a29 + size: 35019 outs: - path: metric/local/neime_pls.csv hash: md5 - md5: 1009bb7e7af63fe48acee67c787f3f39 - size: 1616 + md5: 71acb4373ff1e2e9ec3cab577615189f + size: 1624 diff --git a/supervised/dvc.yaml b/supervised/dvc.yaml index 88777441..2123b569 100644 --- a/supervised/dvc.yaml +++ b/supervised/dvc.yaml @@ -16,9 +16,9 @@ vars: - models: - name: pls - container_path: /models/pls/pls.toml - local_path: ../models/pls/pls.toml - dockerfile: ../models/pls + container_path: /models/pls/manifest.toml + local_path: ../models/pls/manifest.toml + dockerfile: ../models/pls/Dockerfile stages: create_training_job: @@ -27,7 +27,7 @@ stages: model: ${models} cmd: - - docker build --secret id=git_auth,src=git-auth.txt -t ${item.model.name}:latest ${item.model.dockerfile} --no-cache + - docker build --no-cache --secret id=git_auth,src=git-auth.txt -f ${item.model.dockerfile} -t ${item.model.name}:latest .. - docker run --rm -v $(realpath ${local.datasets_dir}):/datasets -v $(realpath ${local.models_dir}):/models -v $(realpath ${local.output_dir}):/output ${item.model.name}:latest train --dataset-toml-file ${item.dataset.container_path} --model-toml-file ${item.model.container_path} - docker image prune -a -f diff --git a/zero_shot/dvc.lock b/zero_shot/dvc.lock index 1d16bb34..3e5fdc6c 100644 --- a/zero_shot/dvc.lock +++ b/zero_shot/dvc.lock @@ -90,25 +90,24 @@ stages: size: 1588 create_training_job@dataset0-model0: cmd: - - docker build --secret id=git_auth,src=git-auth.txt -t esm:latest ../models/esm - --no-cache + - docker build --no-cache --secret id=git_auth,src=git-auth.txt -f ../models/esm/Dockerfile + -t esm:latest .. - docker run --rm -v $(realpath ../datasets):/datasets -v $(realpath ../models):/models -v $(realpath output/local):/output esm:latest train --dataset-toml-file /datasets/ranganathan/ranganathan.toml - --model-toml-file /models/esm/esm.toml + --model-toml-file /models/esm/manifest.toml - docker image prune -a -f deps: - path: ../datasets/ranganathan/ranganathan.toml hash: md5 md5: 2c32981e09b24dfa0ce25b01079cb49a size: 1003 - - path: ../models/esm + - path: ../models/esm/Dockerfile hash: md5 - md5: 1d3484e2b24505d96083206110e4dc35.dir - size: 831070691 - nfiles: 30601 - - path: ../models/esm/esm.toml + md5: 7f13567daf55f157ab05f23602a1f53e + size: 1138 + - path: ../models/esm/manifest.toml hash: md5 - md5: 64b3f9c9ca8ca6553bc695a1fba4d974 + md5: 76c430737597ea876f37d6d4a76bda90 size: 111 outs: - path: output/local/ranganathan_esm.csv diff --git a/zero_shot/dvc.yaml b/zero_shot/dvc.yaml index b374a8fe..01aa39ea 100644 --- a/zero_shot/dvc.yaml +++ b/zero_shot/dvc.yaml @@ -13,9 +13,9 @@ vars: - models: - name: esm - container_path: /models/esm/esm.toml - local_path: ../models/esm/esm.toml - dockerfile: ../models/esm + container_path: /models/esm/manifest.toml + local_path: ../models/esm/manifest.toml + dockerfile: ../models/esm/Dockerfile stages: create_training_job: @@ -24,7 +24,7 @@ stages: model: ${models} cmd: - - docker build --secret id=git_auth,src=git-auth.txt -t ${item.model.name}:latest ${item.model.dockerfile} --no-cache + - docker build --no-cache --secret id=git_auth,src=git-auth.txt -f ${item.model.dockerfile} -t ${item.model.name}:latest .. - docker run --rm -v $(realpath ${local.datasets_dir}):/datasets -v $(realpath ${local.models_dir}):/models -v $(realpath ${local.output_dir}):/output ${item.model.name}:latest train --dataset-toml-file ${item.dataset.container_path} --model-toml-file ${item.model.container_path} - docker image prune -a -f From 250e3292ec3ca9bd9a6db1e2861380d1c9d6871c Mon Sep 17 00:00:00 2001 From: tintinrevient Date: Tue, 29 Jul 2025 17:38:48 +0200 Subject: [PATCH 2/3] Move nogpu into model TOML file --- models/esm/manifest.toml | 1 + models/esm/src/pg2_model_esm/__main__.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/models/esm/manifest.toml b/models/esm/manifest.toml index 68b62572..cf8efa17 100644 --- a/models/esm/manifest.toml +++ b/models/esm/manifest.toml @@ -3,4 +3,5 @@ name = "esm" [hyper_params] location = "esm2_t30_150M_UR50D" scoring_strategy = "wt-marginals" +nogpu = false offset_idx = 24 \ No newline at end of file diff --git a/models/esm/src/pg2_model_esm/__main__.py b/models/esm/src/pg2_model_esm/__main__.py index 54215284..29eb7aee 100644 --- a/models/esm/src/pg2_model_esm/__main__.py +++ b/models/esm/src/pg2_model_esm/__main__.py @@ -21,7 +21,6 @@ def train( dataset_toml_file: str = typer.Option(help="Path to the dataset TOML file"), model_toml_file: str = typer.Option(help="Path to the model TOML file"), - nogpu: bool = typer.Option(False, help="GPUs available"), ): console.print(f"Loading {dataset_toml_file} and {model_toml_file}...") @@ -51,7 +50,7 @@ def train( f"Loaded the model from {hyper_params['location']} with scoring strategy {hyper_params['scoring_strategy']}." ) - if torch.cuda.is_available() and not nogpu: + if torch.cuda.is_available() and not hyper_params["nogpu"]: model = model.cuda() print("Transferred model to GPU") From 2b8c7502b5404e2f586c521a5aa72705ac9182b3 Mon Sep 17 00:00:00 2001 From: tintinrevient Date: Thu, 31 Jul 2025 14:02:04 +0200 Subject: [PATCH 3/3] Add docstring for manifest.py --- src/pg2_benchmark/manifest.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/pg2_benchmark/manifest.py b/src/pg2_benchmark/manifest.py index 60a6becf..64d51a4b 100644 --- a/src/pg2_benchmark/manifest.py +++ b/src/pg2_benchmark/manifest.py @@ -5,6 +5,19 @@ class Manifest(BaseModel): + """A manifest representing configuration for a protein language model. + + This class loads and validates model configuration from TOML files, containing + model metadata and hyperparameters for benchmarking tasks. + + Attributes: + name: The name of the model + hyper_params: Dictionary containing model hyperparameters and configuration + + The model allows extra fields beyond the defined attributes to accommodate + varying model configurations. + """ + model_config = ConfigDict(extra="allow") name: str = ""