Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions models/esm/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,14 @@ COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/

WORKDIR /opt/program

COPY ./README.md ./README.md
COPY ./pyproject.toml ./pyproject.toml
# Copy pg2-benchmark source
COPY ./README.md ./pg2-benchmark/README.md
COPY ./pyproject.toml ./pg2-benchmark/pyproject.toml
COPY ./src ./pg2-benchmark/src

# Copy pg2-model-esm dependencies
COPY ./models/esm/README.md ./README.md
COPY ./models/esm/pyproject.toml ./pyproject.toml

# TODO: After pg2-dataset is public, below can be removed:
# Currently, it is required to git clone pg2-dataset, which is a private repo.
Expand All @@ -28,6 +34,6 @@ RUN --mount=type=secret,id=git_auth \

RUN uv sync --no-cache

COPY ./src ./src
COPY ./models/esm/src ./src

ENTRYPOINT ["uv", "run", "pg2-model"]
5 changes: 3 additions & 2 deletions models/esm/esm.toml → models/esm/manifest.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
name = "esm"
location = "esm2_t30_150M_UR50D"
scoring_strategy = "wt-marginals"

[hyper_params]
location = "esm2_t30_150M_UR50D"
scoring_strategy = "wt-marginals"
nogpu = false
offset_idx = 24
2 changes: 2 additions & 0 deletions models/esm/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ requires-python = ">=3.12"
dependencies = [
"typer~=0.16.0",
"pg2-dataset[biopython]",
"pg2-benchmark",
"fair-esm~=2.0.0",
"torch~=2.7.1",
"bio~=1.8.0",
Expand All @@ -23,6 +24,7 @@ build-backend = "hatchling.build"

[tool.uv.sources]
pg2-dataset = { git = "https://github.com/ProteinGym2/pg2-dataset.git", rev = "58c327e13bade1effe1312eb2b8d5445016a5a8f" }
pg2-benchmark = { path = "./pg2-benchmark", editable = true }

[tool.hatch.build.targets.wheel]
packages = ["src/pg2_model_esm"]
Expand Down
17 changes: 8 additions & 9 deletions models/esm/src/pg2_model_esm/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from tqdm import tqdm
from esm import pretrained
from pg2_model_esm.utils import compute_pppl, label_row
from pg2_model_esm.manifest import Manifest as ModelManifest
from pg2_benchmark.manifest import Manifest as ModelManifest


app = typer.Typer(
Expand All @@ -21,7 +21,6 @@
def train(
dataset_toml_file: str = typer.Option(help="Path to the dataset TOML file"),
model_toml_file: str = typer.Option(help="Path to the model TOML file"),
nogpu: bool = typer.Option(False, help="GPUs available"),
):
console.print(f"Loading {dataset_toml_file} and {model_toml_file}...")

Expand All @@ -42,18 +41,16 @@ def train(
model_manifest = ModelManifest.from_path(model_toml_file)

model_name = model_manifest.name
location = model_manifest.location
scoring_strategy = model_manifest.scoring_strategy
hyper_params = model_manifest.hyper_params

model, alphabet = pretrained.load_model_and_alphabet(location)
model, alphabet = pretrained.load_model_and_alphabet(hyper_params["location"])
model.eval()

console.print(
f"Loaded the model from {location} with scoring strategy {scoring_strategy}."
f"Loaded the model from {hyper_params['location']} with scoring strategy {hyper_params['scoring_strategy']}."
)

if torch.cuda.is_available() and not nogpu:
if torch.cuda.is_available() and not hyper_params["nogpu"]:
model = model.cuda()
print("Transferred model to GPU")

Expand All @@ -65,7 +62,7 @@ def train(

batch_labels, batch_strs, batch_tokens = batch_converter(data)

match scoring_strategy:
match hyper_params["scoring_strategy"]:
case "wt-marginals":
with torch.no_grad():
token_probs = torch.log_softmax(model(batch_tokens)["logits"], dim=-1)
Expand Down Expand Up @@ -123,7 +120,9 @@ def train(
)

case _:
err_console.print(f"Error: Invalid scoring strategy: {scoring_strategy}")
err_console.print(
f"Error: Invalid scoring strategy: {hyper_params['scoring_strategy']}"
)

df.rename(columns={targets[0]: "test"}, inplace=True)
df.to_csv(f"/output/{dataset_name}_{model_name}.csv", index=False)
Expand Down
16 changes: 0 additions & 16 deletions models/esm/src/pg2_model_esm/manifest.py

This file was deleted.

12 changes: 9 additions & 3 deletions models/pls/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,14 @@ COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/

WORKDIR /opt/program

COPY ./README.md ./README.md
COPY ./pyproject.toml ./pyproject.toml
# Copy pg2-benchmark source
COPY ./README.md ./pg2-benchmark/README.md
COPY ./pyproject.toml ./pg2-benchmark/pyproject.toml
COPY ./src ./pg2-benchmark/src

# Copy pg2-model-pls dependencies
COPY ./models/pls/README.md ./README.md
COPY ./models/pls/pyproject.toml ./pyproject.toml

# TODO: After pg2-dataset is public, below can be removed:
# Currently, it is required to git clone pg2-dataset, which is a private repo.
Expand All @@ -28,6 +34,6 @@ RUN --mount=type=secret,id=git_auth \

RUN uv sync --no-cache

COPY ./src ./src
COPY ./models/pls/src ./src

ENTRYPOINT ["uv", "run", "pg2-model"]
File renamed without changes.
2 changes: 2 additions & 0 deletions models/pls/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ dependencies = [
"numpy~=2.2.5",
"scikit-learn~=1.7.0",
"pg2-dataset[biopython]",
"pg2-benchmark",
"polars[pyarrow]~=1.31.0",
]

Expand All @@ -21,6 +22,7 @@ build-backend = "hatchling.build"

[tool.uv.sources]
pg2-dataset = { git = "https://github.com/ProteinGym2/pg2-dataset.git", rev = "58c327e13bade1effe1312eb2b8d5445016a5a8f" }
pg2-benchmark = { path = "./pg2-benchmark", editable = true }

[tool.hatch.build.targets.wheel]
packages = ["src/pg2_model_pls"]
Expand Down
2 changes: 1 addition & 1 deletion models/pls/src/pg2_model_pls/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from rich.console import Console
from pg2_dataset.dataset import Manifest
from pg2_dataset.splits.abstract_split_strategy import TrainTestValid
from pg2_model_pls.manifest import Manifest as ModelManifest
from pg2_benchmark.manifest import Manifest as ModelManifest
from pg2_model_pls.utils import load_x_and_y, train_model, predict_model

import typer
Expand Down
13 changes: 0 additions & 13 deletions models/pls/src/pg2_model_pls/manifest.py

This file was deleted.

2 changes: 1 addition & 1 deletion models/pls/src/pg2_model_pls/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pg2_dataset.dataset import Manifest
from pg2_dataset.backends.assays import SPLIT_STRATEGY_MAPPING
from pg2_dataset.splits.abstract_split_strategy import TrainTestValid
from pg2_model_pls.manifest import Manifest as ModelManifest
from pg2_benchmark.manifest import Manifest as ModelManifest
import logging

logger = logging.getLogger(__name__)
Expand Down
58 changes: 53 additions & 5 deletions models/pls/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 28 additions & 0 deletions src/pg2_benchmark/manifest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from pydantic import BaseModel, Field, ConfigDict
from pathlib import Path
from typing import Self, Any
import toml


class Manifest(BaseModel):
    """Configuration manifest for a protein language model.

    Loads and validates model configuration from a TOML file containing
    model metadata and hyperparameters for benchmarking tasks.

    Attributes:
        name: The name of the model.
        hyper_params: Model hyperparameters and configuration values
            (e.g. model location, scoring strategy — schema varies per model).

    Extra fields beyond the declared attributes are allowed so that
    varying model configurations can be represented.
    """

    # Permit undeclared top-level TOML keys to pass validation.
    model_config = ConfigDict(extra="allow")

    name: str = ""
    hyper_params: dict[str, Any] = Field(default_factory=dict)

    @classmethod
    def from_path(cls, toml_file: str | Path) -> Self:
        """Load and validate a manifest from a TOML file.

        Args:
            toml_file: Path to the TOML file. Accepts ``str`` or ``Path``
                (call sites pass plain strings, e.g. a CLI option value).

        Returns:
            A validated ``Manifest`` instance.

        Raises:
            pydantic.ValidationError: If the TOML contents do not match
                the declared field types.
        """
        # toml.load accepts a path-like or str filename directly.
        return cls.model_validate(toml.load(toml_file))
Loading