Skip to content

Commit 0e980af

Browse files
committed
Add torch-compatible transformations
1 parent 50d4328 commit 0e980af

10 files changed

Lines changed: 249 additions & 1 deletion

File tree

pyaptamer/trafos/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
"""Transformations."""
1+
"""Transformations module."""

pyaptamer/trafos/base/_base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,3 +171,7 @@ def _check_X(self, X): # noqa: N802
171171
"""
172172
X, _ = self._check_X_y(X, None)
173173
return X
174+
175+
def get_torch_transform(self):
176+
"""Return torch-compatible version of this transform, or None."""
177+
return None

pyaptamer/trafos/encode/_greedy.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,16 @@ def get_test_params(self):
148148
"word_max_len": 2,
149149
}
150150
return [param0, param1]
151+
152+
def get_torch_transform(self):
153+
"""Return torch-compatible version of this encoder."""
154+
from pyaptamer.trafos.torch import GreedyEncode
155+
156+
if self.max_len is None:
157+
raise ValueError("max_len must be set for torch transform")
158+
159+
return GreedyEncode(
160+
vocab=self.words,
161+
max_len=self.max_len,
162+
token_max_len=self.word_max_len,
163+
)

pyaptamer/trafos/torch/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
"""PyTorch-compatible transformations."""
2+
3+
from pyaptamer.trafos.torch._base import BaseTorchTransform
4+
from pyaptamer.trafos.torch._encode import GreedyEncode
5+
from pyaptamer.trafos.torch._mask import RandomMask
6+
from pyaptamer.trafos.torch._string import DNAtoRNA, Reverse
7+
8+
__all__ = [
9+
"BaseTorchTransform",
10+
"DNAtoRNA",
11+
"GreedyEncode",
12+
"RandomMask",
13+
"Reverse",
14+
]

pyaptamer/trafos/torch/_base.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""Base class for PyTorch-compatible transformations."""
2+
3+
from abc import ABC, abstractmethod
4+
5+
from torch import Tensor
6+
7+
8+
class BaseTorchTransform(ABC):
9+
"""Base class for torch-compatible transformations.
10+
11+
Subclasses must implement ``__call__``.
12+
"""
13+
14+
@abstractmethod
15+
def __call__(self, x: str | Tensor) -> str | Tensor:
16+
"""Apply transformation."""
17+
18+
def __repr__(self) -> str:
19+
return f"{self.__class__.__name__}()"

pyaptamer/trafos/torch/_encode.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""String-to-tensor transformations."""
2+
3+
import torch
4+
from torch import Tensor
5+
6+
from pyaptamer.trafos.torch._base import BaseTorchTransform
7+
8+
9+
class GreedyEncode(BaseTorchTransform):
10+
"""Greedy tokenization of sequence to tensor.
11+
12+
Matches longest possible token at each position. Unknown chars map to 0.
13+
"""
14+
15+
def __init__(self, vocab: dict[str, int], max_len: int, token_max_len: int = None):
16+
self.vocab = vocab
17+
self.max_len = max_len
18+
self.token_max_len = token_max_len or max(len(k) for k in vocab)
19+
20+
def __call__(self, x: str) -> Tensor:
21+
tokens = []
22+
i = 0
23+
while i < len(x):
24+
matched = False
25+
for j in range(self.token_max_len, 0, -1):
26+
if i + j <= len(x):
27+
substr = x[i : i + j]
28+
if substr in self.vocab:
29+
tokens.append(self.vocab[substr])
30+
i += j
31+
matched = True
32+
break
33+
if not matched:
34+
tokens.append(0)
35+
i += 1
36+
37+
if len(tokens) < self.max_len:
38+
tokens.extend([0] * (self.max_len - len(tokens)))
39+
else:
40+
tokens = tokens[: self.max_len]
41+
42+
return torch.tensor(tokens, dtype=torch.long)
43+
44+
def __repr__(self) -> str:
45+
return f"{self.__class__.__name__}(max_len={self.max_len})"

pyaptamer/trafos/torch/_mask.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""Tensor-to-tensor transformations."""
2+
3+
import random
4+
5+
from torch import Tensor
6+
7+
from pyaptamer.trafos.torch._base import BaseTorchTransform
8+
9+
10+
class RandomMask(BaseTorchTransform):
11+
"""Randomly mask positions in a sequence tensor."""
12+
13+
def __init__(self, mask_idx: int, mask_rate: float = 0.15, padding_idx: int = 0):
14+
self.mask_idx = mask_idx
15+
self.mask_rate = mask_rate
16+
self.padding_idx = padding_idx
17+
18+
def __call__(self, x: Tensor) -> Tensor:
19+
x_masked = x.clone()
20+
valid_pos = (x != self.padding_idx).nonzero(as_tuple=True)[0].tolist()
21+
n_mask = int(len(valid_pos) * self.mask_rate)
22+
23+
if n_mask > 0:
24+
mask_pos = random.sample(valid_pos, n_mask)
25+
x_masked[mask_pos] = self.mask_idx
26+
27+
return x_masked
28+
29+
def __repr__(self) -> str:
30+
return f"{self.__class__.__name__}(mask_idx={self.mask_idx})"

pyaptamer/trafos/torch/_string.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""String-to-string transformations."""
2+
3+
from pyaptamer.trafos.torch._base import BaseTorchTransform
4+
5+
6+
class Reverse(BaseTorchTransform):
7+
"""Reverse a sequence string."""
8+
9+
def __call__(self, x: str) -> str:
10+
return x[::-1]
11+
12+
13+
class DNAtoRNA(BaseTorchTransform):
14+
"""Convert DNA to RNA (T -> U)."""
15+
16+
def __call__(self, x: str) -> str:
17+
return x.replace("T", "U")
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Tests for torch transformations."""
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
"""Tests for torch transforms."""
2+
3+
import pytest
4+
import torch
5+
6+
from pyaptamer.trafos.torch import (
7+
DNAtoRNA,
8+
GreedyEncode,
9+
RandomMask,
10+
Reverse,
11+
)
12+
13+
14+
class TestReverse:
15+
def test_reverse(self):
16+
t = Reverse()
17+
assert t("ACGT") == "TGCA"
18+
assert t("AAA") == "AAA"
19+
assert t("") == ""
20+
21+
def test_repr(self):
22+
assert "Reverse" in repr(Reverse())
23+
24+
25+
class TestDNAtoRNA:
26+
def test_convert(self):
27+
t = DNAtoRNA()
28+
assert t("ACGT") == "ACGU"
29+
assert t("TTT") == "UUU"
30+
assert t("ACG") == "ACG"
31+
32+
def test_repr(self):
33+
assert "DNAtoRNA" in repr(DNAtoRNA())
34+
35+
36+
class TestGreedyEncode:
37+
@pytest.fixture
38+
def vocab(self):
39+
return {"A": 1, "C": 2, "G": 3, "T": 4, "AC": 5, "GT": 6}
40+
41+
def test_encode_simple(self, vocab):
42+
t = GreedyEncode(vocab, max_len=5)
43+
result = t("ACGT")
44+
assert result.shape == (5,)
45+
assert result[0].item() == 5
46+
assert result[1].item() == 6
47+
48+
def test_padding(self, vocab):
49+
t = GreedyEncode(vocab, max_len=10)
50+
result = t("A")
51+
assert result.shape == (10,)
52+
assert result[0].item() == 1
53+
assert result[1].item() == 0
54+
55+
def test_truncation(self, vocab):
56+
t = GreedyEncode(vocab, max_len=2)
57+
result = t("ACGTACGT")
58+
assert result.shape == (2,)
59+
60+
def test_unknown_char(self, vocab):
61+
t = GreedyEncode(vocab, max_len=5)
62+
result = t("XYZ")
63+
assert result[0].item() == 0
64+
65+
def test_repr(self, vocab):
66+
t = GreedyEncode(vocab, max_len=10)
67+
assert "GreedyEncode" in repr(t)
68+
69+
70+
class TestRandomMask:
71+
def test_shape(self):
72+
t = RandomMask(mask_idx=99, mask_rate=0.5)
73+
x = torch.tensor([1, 2, 3, 4, 0, 0])
74+
assert t(x).shape == x.shape
75+
76+
def test_preserves_padding(self):
77+
t = RandomMask(mask_idx=99, mask_rate=1.0)
78+
x = torch.tensor([1, 2, 0, 0])
79+
result = t(x)
80+
assert result[2].item() == 0
81+
assert result[3].item() == 0
82+
83+
def test_applies_mask(self):
84+
torch.manual_seed(42)
85+
t = RandomMask(mask_idx=99, mask_rate=0.5)
86+
x = torch.tensor([1, 2, 3, 4])
87+
assert (t(x) == 99).any()
88+
89+
def test_repr(self):
90+
t = RandomMask(mask_idx=99)
91+
assert "RandomMask" in repr(t)
92+
93+
94+
class TestChaining:
95+
def test_str_transforms(self):
96+
t1 = DNAtoRNA()
97+
t2 = Reverse()
98+
assert t2(t1("ACGT")) == "UGCA"
99+
100+
def test_tensor_transforms(self):
101+
vocab = {"A": 1, "C": 2, "G": 3, "U": 4}
102+
encode = GreedyEncode(vocab, max_len=5)
103+
mask = RandomMask(mask_idx=99, mask_rate=0.5)
104+
result = mask(encode("ACGU"))
105+
assert result.shape == (5,)

0 commit comments

Comments
 (0)