From a01f31b253253c3b893af811b9cd013e5df83381 Mon Sep 17 00:00:00 2001
From: Vaishnav88sk
Date: Sat, 2 May 2026 14:48:50 +0530
Subject: [PATCH] [ENH] Add from_dataframe factory method to APIDataset

Implemented a class method from_dataframe() in APIDataset to facilitate
dataset creation directly from pandas DataFrames. This is highly useful as
many users store interaction data in tabular formats.

Features:
- Convenience method for pandas integration
- Automatic extraction of columns to numpy arrays
- Included comprehensive test suite for APIDataset
---
 pyaptamer/datasets/dataclasses/_api.py       | 53 +++++++++++++
 pyaptamer/datasets/tests/test_api_dataset.py | 80 ++++++++++++++++++++
 2 files changed, 133 insertions(+)
 create mode 100644 pyaptamer/datasets/tests/test_api_dataset.py

diff --git a/pyaptamer/datasets/dataclasses/_api.py b/pyaptamer/datasets/dataclasses/_api.py
index 5ccd7970..f0bd39b3 100644
--- a/pyaptamer/datasets/dataclasses/_api.py
+++ b/pyaptamer/datasets/dataclasses/_api.py
@@ -2,6 +2,7 @@
 __all__ = ["APIDataset"]
 
 import numpy as np
+import pandas as pd
 import torch
 from torch.utils.data import Dataset
 
@@ -58,6 +59,58 @@ def __init__(
 
         self.len = len(self.x_apta)
 
+    @classmethod
+    def from_dataframe(
+        cls,
+        df: pd.DataFrame,
+        apta_col: str,
+        prot_col: str,
+        label_col: str,
+        apta_max_len: int,
+        prot_max_len: int,
+        prot_words: dict[str, int],
+        split: str = "train",
+    ) -> "APIDataset":
+        """
+        Create an APIDataset from a pandas DataFrame.
+
+        Parameters
+        ----------
+        df : pd.DataFrame
+        Table with one aptamer, protein and label value per row.
+        apta_col : str
+        The name of the column containing aptamer sequences.
+        prot_col : str
+        The name of the column containing protein sequences.
+        label_col : str
+        The name of the column containing interaction labels.
+        apta_max_len : int
+        Maximum length for aptamer sequences.
+        prot_max_len : int
+        Maximum length for protein sequences.
+        prot_words : dict[str, int]
+        Protein k-mer word mapping.
+        split : str, optional, default="train"
+        If "train", the dataset will augment aptamer sequences.
+
+        Returns
+        -------
+        APIDataset
+        Instance of ``cls`` built from the three extracted columns.
+        """
+        x_apta = df[apta_col].values
+        x_prot = df[prot_col].values
+        y = df[label_col].values
+        return cls(
+            x_apta=x_apta,
+            x_prot=x_prot,
+            y=y,
+            apta_max_len=apta_max_len,
+            prot_max_len=prot_max_len,
+            prot_words=prot_words,
+            split=split,
+        )
+
     def _prepare_data(
         self,
         x_apta: np.ndarray,
diff --git a/pyaptamer/datasets/tests/test_api_dataset.py b/pyaptamer/datasets/tests/test_api_dataset.py
new file mode 100644
index 00000000..e135474b
--- /dev/null
+++ b/pyaptamer/datasets/tests/test_api_dataset.py
@@ -0,0 +1,80 @@
+"""Tests for APIDataset class."""
+
+import numpy as np
+import pandas as pd
+import pytest
+import torch
+
+from pyaptamer.datasets.dataclasses import APIDataset
+
+
+class TestAPIDataset:
+    """Tests for APIDataset."""
+
+    @pytest.fixture
+    def dummy_data(self):
+        apta = np.array(["AUGC", "GCAU"])
+        prot = np.array(["ACD", "EFG"])
+        y = np.array(["positive", "negative"])
+        words = {"ACD": 1, "EFG": 2}
+        return apta, prot, y, words
+
+    def test_init_train(self, dummy_data):
+        """Test initialization with train split (augmentation)."""
+        apta, prot, y, words = dummy_data
+        ds = APIDataset(
+            x_apta=apta,
+            x_prot=prot,
+            y=y,
+            apta_max_len=10,
+            prot_max_len=10,
+            prot_words=words,
+            split="train",
+        )
+        # Augmentation (reverse) doubles the size
+        assert len(ds) == 4
+        assert isinstance(ds[0][0], torch.Tensor)
+        assert isinstance(ds[0][1], torch.Tensor)
+        assert isinstance(ds[0][2], torch.Tensor)
+
+    def test_init_test(self, dummy_data):
+        """Test initialization with test split (no augmentation)."""
+        apta, prot, y, words = dummy_data
+        ds = APIDataset(
+            x_apta=apta,
+            x_prot=prot,
+            y=y,
+            apta_max_len=10,
+            prot_max_len=10,
+            prot_words=words,
+            split="test",
+        )
+        assert len(ds) == 2
+
+    def test_from_dataframe(self, dummy_data):
+        """Test from_dataframe factory method."""
+        apta, prot, y, words = dummy_data
+        df = pd.DataFrame({"apta": apta, "prot": prot, "label": y})
+
+        ds = APIDataset.from_dataframe(
+            df=df,
+            apta_col="apta",
+            prot_col="prot",
+            label_col="label",
+            apta_max_len=10,
+            prot_max_len=10,
+            prot_words=words,
+            split="test",
+        )
+
+        assert len(ds) == 2
+        # Check that data was correctly passed
+        # Label 'positive' becomes 1, 'negative' becomes 0
+        assert ds[0][2].item() == 1
+        assert ds[1][2].item() == 0
+
+    def test_invalid_split(self, dummy_data):
+        """Test that invalid split raises ValueError."""
+        apta, prot, y, words = dummy_data
+        with pytest.raises(ValueError, match="Unknown split"):
+            APIDataset(apta, prot, y, 10, 10, words, split="val")