Skip to content

Commit a01f31b

Browse files
committed
[ENH] Add from_dataframe factory method to APIDataset
Implemented a class method, from_dataframe(), in APIDataset to create a dataset directly from a pandas DataFrame. This is useful because many users store interaction data in tabular formats.
Features:
- Convenience method for pandas integration
- Automatic extraction of columns to NumPy arrays
- Comprehensive test suite for APIDataset
1 parent dc1dd8d commit a01f31b

2 files changed

Lines changed: 133 additions & 0 deletions

File tree

pyaptamer/datasets/dataclasses/_api.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
__all__ = ["APIDataset"]
33

44
import numpy as np
5+
import pandas as pd
56
import torch
67
from torch.utils.data import Dataset
78

@@ -58,6 +59,58 @@ def __init__(
5859

5960
self.len = len(self.x_apta)
6061

62+
@classmethod
def from_dataframe(
    cls,
    df: pd.DataFrame,
    apta_col: str,
    prot_col: str,
    label_col: str,
    apta_max_len: int,
    prot_max_len: int,
    prot_words: dict[str, int],
    split: str = "train",
) -> "APIDataset":
    """
    Create an APIDataset from a pandas DataFrame.

    The three named columns are converted to NumPy arrays and forwarded,
    together with the remaining arguments, to the class constructor.

    Parameters
    ----------
    df : pd.DataFrame
        The dataframe containing the data.
    apta_col : str
        The name of the column containing aptamer sequences.
    prot_col : str
        The name of the column containing protein sequences.
    label_col : str
        The name of the column containing interaction labels.
    apta_max_len : int
        Maximum length for aptamer sequences.
    prot_max_len : int
        Maximum length for protein sequences.
    prot_words : dict[str, int]
        Protein k-mer word mapping.
    split : str, optional, default="train"
        If "train", the dataset will augment aptamer sequences.

    Returns
    -------
    APIDataset
        An instance of APIDataset.

    Raises
    ------
    KeyError
        If `apta_col`, `prot_col` or `label_col` is not a column of `df`.
    """
    # Report every missing column at once instead of failing on the first
    # lookup with a bare column name.
    missing = [
        col for col in (apta_col, prot_col, label_col) if col not in df.columns
    ]
    if missing:
        raise KeyError(f"Column(s) not found in dataframe: {missing}")

    # `Series.to_numpy()` always yields an ndarray, whereas `.values` can
    # return a pandas extension array (e.g. for the "string" dtype).
    return cls(
        x_apta=df[apta_col].to_numpy(),
        x_prot=df[prot_col].to_numpy(),
        y=df[label_col].to_numpy(),
        apta_max_len=apta_max_len,
        prot_max_len=prot_max_len,
        prot_words=prot_words,
        split=split,
    )
113+
61114
def _prepare_data(
62115
self,
63116
x_apta: np.ndarray,
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
"""Tests for APIDataset class."""
2+
3+
import numpy as np
4+
import pandas as pd
5+
import pytest
6+
import torch
7+
8+
from pyaptamer.datasets.dataclasses import APIDataset
9+
10+
11+
class TestAPIDataset:
    """Unit tests for the construction paths of APIDataset."""

    @pytest.fixture
    def dummy_data(self):
        """Return two aptamers, two proteins, their labels and a word map."""
        aptamers = np.array(["AUGC", "GCAU"])
        proteins = np.array(["ACD", "EFG"])
        labels = np.array(["positive", "negative"])
        word_map = {"ACD": 1, "EFG": 2}
        return aptamers, proteins, labels, word_map

    @staticmethod
    def _make_dataset(data, split):
        """Build an APIDataset from the fixture tuple with max lengths of 10."""
        aptamers, proteins, labels, word_map = data
        return APIDataset(
            x_apta=aptamers,
            x_prot=proteins,
            y=labels,
            apta_max_len=10,
            prot_max_len=10,
            prot_words=word_map,
            split=split,
        )

    def test_init_train(self, dummy_data):
        """The train split augments the data and yields tensors."""
        dataset = self._make_dataset(dummy_data, "train")
        # Two input pairs become four samples once the reverse
        # augmentation has been applied.
        assert len(dataset) == 4
        for position in range(3):
            assert isinstance(dataset[0][position], torch.Tensor)

    def test_init_test(self, dummy_data):
        """The test split keeps the data at its original size."""
        dataset = self._make_dataset(dummy_data, "test")
        assert len(dataset) == 2

    def test_from_dataframe(self, dummy_data):
        """The from_dataframe factory forwards the column contents."""
        aptamers, proteins, labels, word_map = dummy_data
        frame = pd.DataFrame(
            {"apta": aptamers, "prot": proteins, "label": labels}
        )

        dataset = APIDataset.from_dataframe(
            df=frame,
            apta_col="apta",
            prot_col="prot",
            label_col="label",
            apta_max_len=10,
            prot_max_len=10,
            prot_words=word_map,
            split="test",
        )

        assert len(dataset) == 2
        # 'positive' is expected to be encoded as 1 and 'negative' as 0,
        # in the same order as the rows of the dataframe.
        for row, expected_label in enumerate((1, 0)):
            assert dataset[row][2].item() == expected_label

    def test_invalid_split(self, dummy_data):
        """An unrecognised split name is rejected with ValueError."""
        aptamers, proteins, labels, word_map = dummy_data
        with pytest.raises(ValueError, match="Unknown split"):
            APIDataset(aptamers, proteins, labels, 10, 10, word_map, split="val")

0 commit comments

Comments
 (0)