Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions pyaptamer/datasets/dataclasses/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
__all__ = ["APIDataset"]

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

Expand Down Expand Up @@ -58,6 +59,58 @@ def __init__(

self.len = len(self.x_apta)

@classmethod
def from_dataframe(
    cls,
    df: pd.DataFrame,
    apta_col: str,
    prot_col: str,
    label_col: str,
    apta_max_len: int,
    prot_max_len: int,
    prot_words: dict[str, int],
    split: str = "train",
) -> "APIDataset":
    """
    Create an APIDataset from a pandas DataFrame.

    The three named columns are extracted as NumPy arrays and forwarded,
    together with the remaining arguments, to the class constructor.

    Parameters
    ----------
    df : pd.DataFrame
        The dataframe containing the data.
    apta_col : str
        The name of the column containing aptamer sequences.
    prot_col : str
        The name of the column containing protein sequences.
    label_col : str
        The name of the column containing interaction labels.
    apta_max_len : int
        Maximum length for aptamer sequences.
    prot_max_len : int
        Maximum length for protein sequences.
    prot_words : dict[str, int]
        Protein k-mer word mapping.
    split : str, optional, default="train"
        If "train", the dataset will augment aptamer sequences. The value
        is passed unchanged to the constructor.

    Returns
    -------
    APIDataset
        An instance of APIDataset.

    Raises
    ------
    KeyError
        If any of ``apta_col``, ``prot_col`` or ``label_col`` is not a
        column of ``df``.
    """
    # Report every missing column at once instead of failing on the first
    # lookup; the exception type matches what ``df[col]`` would raise.
    missing = [
        col for col in (apta_col, prot_col, label_col) if col not in df.columns
    ]
    if missing:
        raise KeyError(f"Column(s) not found in dataframe: {missing}")

    # ``Series.values`` may hand back a pandas ExtensionArray (e.g. for the
    # ``string`` dtype); ``to_numpy()`` always yields a NumPy ndarray.
    return cls(
        x_apta=df[apta_col].to_numpy(),
        x_prot=df[prot_col].to_numpy(),
        y=df[label_col].to_numpy(),
        apta_max_len=apta_max_len,
        prot_max_len=prot_max_len,
        prot_words=prot_words,
        split=split,
    )

def _prepare_data(
self,
x_apta: np.ndarray,
Expand Down
80 changes: 80 additions & 0 deletions pyaptamer/datasets/tests/test_api_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Tests for APIDataset class."""

import numpy as np
import pandas as pd
import pytest
import torch

from pyaptamer.datasets.dataclasses import APIDataset


class TestAPIDataset:
    """Unit tests covering construction of ``APIDataset``."""

    @pytest.fixture
    def dummy_data(self):
        """Return two aptamer/protein/label samples plus a protein word map."""
        words = {"ACD": 1, "EFG": 2}
        y = np.array(["positive", "negative"])
        prot = np.array(["ACD", "EFG"])
        apta = np.array(["AUGC", "GCAU"])
        return apta, prot, y, words

    def test_init_train(self, dummy_data):
        """The train split augments the data and yields tensors."""
        sequences, proteins, labels, vocab = dummy_data
        settings = {
            "x_apta": sequences,
            "x_prot": proteins,
            "y": labels,
            "apta_max_len": 10,
            "prot_max_len": 10,
            "prot_words": vocab,
            "split": "train",
        }
        dataset = APIDataset(**settings)

        # Reverse augmentation turns the two input samples into four.
        assert len(dataset) == 4
        # Each of the first three entries of a sample must be a tensor.
        for position in range(3):
            assert isinstance(dataset[0][position], torch.Tensor)

    def test_init_test(self, dummy_data):
        """The test split keeps the dataset at its original size."""
        sequences, proteins, labels, vocab = dummy_data
        settings = {
            "x_apta": sequences,
            "x_prot": proteins,
            "y": labels,
            "apta_max_len": 10,
            "prot_max_len": 10,
            "prot_words": vocab,
            "split": "test",
        }
        dataset = APIDataset(**settings)

        assert len(dataset) == 2

    def test_from_dataframe(self, dummy_data):
        """The ``from_dataframe`` factory builds an equivalent dataset."""
        sequences, proteins, labels, vocab = dummy_data
        columns = dict(zip(("apta", "prot", "label"), (sequences, proteins, labels)))
        frame = pd.DataFrame(columns)

        dataset = APIDataset.from_dataframe(
            df=frame,
            apta_col="apta",
            prot_col="prot",
            label_col="label",
            apta_max_len=10,
            prot_max_len=10,
            prot_words=vocab,
            split="test",
        )

        assert len(dataset) == 2
        # Labels must survive the round trip through the dataframe:
        # 'positive' is encoded as 1 and 'negative' as 0.
        for row, expected in enumerate((1, 0)):
            assert dataset[row][2].item() == expected

    def test_invalid_split(self, dummy_data):
        """An unrecognised split name is rejected with ``ValueError``."""
        sequences, proteins, labels, vocab = dummy_data
        expectation = pytest.raises(ValueError, match="Unknown split")
        with expectation:
            APIDataset(sequences, proteins, labels, 10, 10, vocab, split="val")