Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyaptamer/trafos/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
"""Transformations."""
"""Transformations module."""

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why add module?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank u , i will correct it!

4 changes: 4 additions & 0 deletions pyaptamer/trafos/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,7 @@ def _check_X(self, X): # noqa: N802
"""
X, _ = self._check_X_y(X, None)
return X

def get_torch_transform(self):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the reasoning and use of the function that is being added here? Please explain, since it's just returning None?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dear @siddharth7113 The get_torch_transform() is designed as an extension point in the sklearn base class.

By default it returns None, meaning "this transform doesn't have a torch equivalent."

Subclasses that have torch-compatible versions can override this to return them. For example, the sklearn GreedyEncode encoder overrides this method to return the torch GreedyEncode transform with the same parameters.

This allows flexibility - not every sklearn transform needs a torch version, but those that do can provide it through this interface.

"""Return torch-compatible version of this transform, or None."""
return None
13 changes: 13 additions & 0 deletions pyaptamer/trafos/encode/_greedy.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,16 @@ def get_test_params(self):
"word_max_len": 2,
}
return [param0, param1]

def get_torch_transform(self):
"""Return torch-compatible version of this encoder."""
from pyaptamer.trafos.torch import GreedyEncode

if self.max_len is None:
raise ValueError("max_len must be set for torch transform")

return GreedyEncode(
vocab=self.words,
max_len=self.max_len,
token_max_len=self.word_max_len,
)
14 changes: 14 additions & 0 deletions pyaptamer/trafos/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""PyTorch-compatible transformations."""

from pyaptamer.trafos.torch._base import BaseTorchTransform
from pyaptamer.trafos.torch._encode import GreedyEncode
from pyaptamer.trafos.torch._mask import RandomMask
from pyaptamer.trafos.torch._string import DNAtoRNA, Reverse

__all__ = [
"BaseTorchTransform",
"DNAtoRNA",
"GreedyEncode",
"RandomMask",
"Reverse",
]
19 changes: 19 additions & 0 deletions pyaptamer/trafos/torch/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Base class for PyTorch-compatible transformations."""

from abc import ABC, abstractmethod

from torch import Tensor


class BaseTorchTransform(ABC):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @fkiraly , This is the proposed design for the torch base class, but the currently it only puts call and repr, do we have something similar in sktime to borrow for torch related transformers?

"""Base class for torch-compatible transformations.

Subclasses must implement ``__call__``.
"""

@abstractmethod
def __call__(self, x: str | Tensor) -> str | Tensor:
"""Apply transformation."""

def __repr__(self) -> str:
return f"{self.__class__.__name__}()"
45 changes: 45 additions & 0 deletions pyaptamer/trafos/torch/_encode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""String-to-tensor transformations."""

import torch
from torch import Tensor

from pyaptamer.trafos.torch._base import BaseTorchTransform


class GreedyEncode(BaseTorchTransform):
"""Greedy tokenization of sequence to tensor.

Matches longest possible token at each position. Unknown chars map to 0.
"""

def __init__(self, vocab: dict[str, int], max_len: int, token_max_len: int = None):
self.vocab = vocab
self.max_len = max_len
self.token_max_len = token_max_len or max(len(k) for k in vocab)

def __call__(self, x: str) -> Tensor:
tokens = []
i = 0
while i < len(x):
matched = False
for j in range(self.token_max_len, 0, -1):
if i + j <= len(x):
substr = x[i : i + j]
if substr in self.vocab:
tokens.append(self.vocab[substr])
i += j
matched = True
break
if not matched:
tokens.append(0)
i += 1

if len(tokens) < self.max_len:
tokens.extend([0] * (self.max_len - len(tokens)))
else:
tokens = tokens[: self.max_len]

return torch.tensor(tokens, dtype=torch.long)

def __repr__(self) -> str:
return f"{self.__class__.__name__}(max_len={self.max_len})"
30 changes: 30 additions & 0 deletions pyaptamer/trafos/torch/_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Tensor-to-tensor transformations."""

import random

from torch import Tensor

from pyaptamer.trafos.torch._base import BaseTorchTransform


class RandomMask(BaseTorchTransform):
"""Randomly mask positions in a sequence tensor."""

def __init__(self, mask_idx: int, mask_rate: float = 0.15, padding_idx: int = 0):
self.mask_idx = mask_idx
self.mask_rate = mask_rate
self.padding_idx = padding_idx

def __call__(self, x: Tensor) -> Tensor:
x_masked = x.clone()
valid_pos = (x != self.padding_idx).nonzero(as_tuple=True)[0].tolist()
n_mask = int(len(valid_pos) * self.mask_rate)

if n_mask > 0:
mask_pos = random.sample(valid_pos, n_mask)
x_masked[mask_pos] = self.mask_idx

return x_masked

def __repr__(self) -> str:
return f"{self.__class__.__name__}(mask_idx={self.mask_idx})"
21 changes: 21 additions & 0 deletions pyaptamer/trafos/torch/_string.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""String-to-string transformations."""

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this still doesn't make sense, why have a _string in torch ? when there are no tensor operations involved.



class Reverse:
"""Reverse a sequence string."""

def __call__(self, x: str) -> str:
return x[::-1]

def __repr__(self) -> str:
return f"{self.__class__.__name__}()"


class DNAtoRNA:
"""Convert DNA to RNA (T -> U)."""

def __call__(self, x: str) -> str:
return x.replace("T", "U")

def __repr__(self) -> str:
return f"{self.__class__.__name__}()"
1 change: 1 addition & 0 deletions pyaptamer/trafos/torch/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Tests for torch transformations."""
105 changes: 105 additions & 0 deletions pyaptamer/trafos/torch/tests/test_torch_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""Tests for torch transforms."""

import pytest
import torch

from pyaptamer.trafos.torch import (
DNAtoRNA,
GreedyEncode,
RandomMask,
Reverse,
)


class TestReverse:
def test_reverse(self):
t = Reverse()
assert t("ACGT") == "TGCA"
assert t("AAA") == "AAA"
assert t("") == ""

def test_repr(self):
assert "Reverse" in repr(Reverse())


class TestDNAtoRNA:
def test_convert(self):
t = DNAtoRNA()
assert t("ACGT") == "ACGU"
assert t("TTT") == "UUU"
assert t("ACG") == "ACG"

def test_repr(self):
assert "DNAtoRNA" in repr(DNAtoRNA())


class TestGreedyEncode:
@pytest.fixture
def vocab(self):
return {"A": 1, "C": 2, "G": 3, "T": 4, "AC": 5, "GT": 6}

def test_encode_simple(self, vocab):
t = GreedyEncode(vocab, max_len=5)
result = t("ACGT")
assert result.shape == (5,)
assert result[0].item() == 5
assert result[1].item() == 6

def test_padding(self, vocab):
t = GreedyEncode(vocab, max_len=10)
result = t("A")
assert result.shape == (10,)
assert result[0].item() == 1
assert result[1].item() == 0

def test_truncation(self, vocab):
t = GreedyEncode(vocab, max_len=2)
result = t("ACGTACGT")
assert result.shape == (2,)

def test_unknown_char(self, vocab):
t = GreedyEncode(vocab, max_len=5)
result = t("XYZ")
assert result[0].item() == 0

def test_repr(self, vocab):
t = GreedyEncode(vocab, max_len=10)
assert "GreedyEncode" in repr(t)


class TestRandomMask:
def test_shape(self):
t = RandomMask(mask_idx=99, mask_rate=0.5)
x = torch.tensor([1, 2, 3, 4, 0, 0])
assert t(x).shape == x.shape

def test_preserves_padding(self):
t = RandomMask(mask_idx=99, mask_rate=1.0)
x = torch.tensor([1, 2, 0, 0])
result = t(x)
assert result[2].item() == 0
assert result[3].item() == 0

def test_applies_mask(self):
torch.manual_seed(42)
t = RandomMask(mask_idx=99, mask_rate=0.5)
x = torch.tensor([1, 2, 3, 4])
assert (t(x) == 99).any()

def test_repr(self):
t = RandomMask(mask_idx=99)
assert "RandomMask" in repr(t)


class TestChaining:
def test_str_transforms(self):
t1 = DNAtoRNA()
t2 = Reverse()
assert t2(t1("ACGT")) == "UGCA"

def test_tensor_transforms(self):
vocab = {"A": 1, "C": 2, "G": 3, "U": 4}
encode = GreedyEncode(vocab, max_len=5)
mask = RandomMask(mask_idx=99, mask_rate=0.5)
result = mask(encode("ACGU"))
assert result.shape == (5,)