-
Notifications
You must be signed in to change notification settings - Fork 134
Add torch-compatible transformations #246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
0e980af
960ce06
287b671
798fa2a
c5f73f9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1 +1 @@ | ||
| """Transformations.""" | ||
| """Transformations module.""" | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -171,3 +171,7 @@ def _check_X(self, X): # noqa: N802 | |
| """ | ||
| X, _ = self._check_X_y(X, None) | ||
| return X | ||
|
|
||
| def get_torch_transform(self): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the reasoning and use of the function that is being added here? Please explain, since it's just returning None?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Dear @siddharth7113 The By default it returns None, meaning "this transform doesn't have a torch equivalent." Subclasses that have torch-compatible versions can override this to return them. For example, the sklearn GreedyEncode encoder overrides this method to return the torch GreedyEncode transform with the same parameters. This allows flexibility - not every sklearn transform needs a torch version, but those that do can provide it through this interface. |
||
| """Return torch-compatible version of this transform, or None.""" | ||
| return None | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| """PyTorch-compatible transformations.""" | ||
|
|
||
| from pyaptamer.trafos.torch._base import BaseTorchTransform | ||
| from pyaptamer.trafos.torch._encode import GreedyEncode | ||
| from pyaptamer.trafos.torch._mask import RandomMask | ||
| from pyaptamer.trafos.torch._string import DNAtoRNA, Reverse | ||
|
|
||
| __all__ = [ | ||
| "BaseTorchTransform", | ||
| "DNAtoRNA", | ||
| "GreedyEncode", | ||
| "RandomMask", | ||
| "Reverse", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| """Base class for PyTorch-compatible transformations.""" | ||
|
|
||
| from abc import ABC, abstractmethod | ||
|
|
||
| from torch import Tensor | ||
|
|
||
|
|
||
| class BaseTorchTransform(ABC): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cc @fkiraly , This is the proposed design for the torch base class, but the currently it only puts call and repr, do we have something similar in |
||
| """Base class for torch-compatible transformations. | ||
|
|
||
| Subclasses must implement ``__call__``. | ||
| """ | ||
|
|
||
| @abstractmethod | ||
| def __call__(self, x: str | Tensor) -> str | Tensor: | ||
| """Apply transformation.""" | ||
|
|
||
| def __repr__(self) -> str: | ||
| return f"{self.__class__.__name__}()" | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| """String-to-tensor transformations.""" | ||
|
|
||
| import torch | ||
| from torch import Tensor | ||
|
|
||
| from pyaptamer.trafos.torch._base import BaseTorchTransform | ||
|
|
||
|
|
||
| class GreedyEncode(BaseTorchTransform): | ||
| """Greedy tokenization of sequence to tensor. | ||
|
|
||
| Matches longest possible token at each position. Unknown chars map to 0. | ||
| """ | ||
|
|
||
| def __init__(self, vocab: dict[str, int], max_len: int, token_max_len: int = None): | ||
| self.vocab = vocab | ||
| self.max_len = max_len | ||
| self.token_max_len = token_max_len or max(len(k) for k in vocab) | ||
|
|
||
| def __call__(self, x: str) -> Tensor: | ||
| tokens = [] | ||
| i = 0 | ||
| while i < len(x): | ||
| matched = False | ||
| for j in range(self.token_max_len, 0, -1): | ||
| if i + j <= len(x): | ||
| substr = x[i : i + j] | ||
| if substr in self.vocab: | ||
| tokens.append(self.vocab[substr]) | ||
| i += j | ||
| matched = True | ||
| break | ||
| if not matched: | ||
| tokens.append(0) | ||
| i += 1 | ||
|
|
||
| if len(tokens) < self.max_len: | ||
| tokens.extend([0] * (self.max_len - len(tokens))) | ||
| else: | ||
| tokens = tokens[: self.max_len] | ||
|
|
||
| return torch.tensor(tokens, dtype=torch.long) | ||
|
|
||
| def __repr__(self) -> str: | ||
| return f"{self.__class__.__name__}(max_len={self.max_len})" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| """Tensor-to-tensor transformations.""" | ||
|
|
||
| import random | ||
|
|
||
| from torch import Tensor | ||
|
|
||
| from pyaptamer.trafos.torch._base import BaseTorchTransform | ||
|
|
||
|
|
||
| class RandomMask(BaseTorchTransform): | ||
| """Randomly mask positions in a sequence tensor.""" | ||
|
|
||
| def __init__(self, mask_idx: int, mask_rate: float = 0.15, padding_idx: int = 0): | ||
| self.mask_idx = mask_idx | ||
| self.mask_rate = mask_rate | ||
| self.padding_idx = padding_idx | ||
|
|
||
| def __call__(self, x: Tensor) -> Tensor: | ||
| x_masked = x.clone() | ||
| valid_pos = (x != self.padding_idx).nonzero(as_tuple=True)[0].tolist() | ||
| n_mask = int(len(valid_pos) * self.mask_rate) | ||
|
|
||
| if n_mask > 0: | ||
| mask_pos = random.sample(valid_pos, n_mask) | ||
| x_masked[mask_pos] = self.mask_idx | ||
|
|
||
| return x_masked | ||
|
|
||
| def __repr__(self) -> str: | ||
| return f"{self.__class__.__name__}(mask_idx={self.mask_idx})" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| """String-to-string transformations.""" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this still doesn't make sense, why have a _string in torch ? when there are no tensor operations involved. |
||
|
|
||
|
|
||
| class Reverse: | ||
| """Reverse a sequence string.""" | ||
|
|
||
| def __call__(self, x: str) -> str: | ||
| return x[::-1] | ||
|
|
||
| def __repr__(self) -> str: | ||
| return f"{self.__class__.__name__}()" | ||
|
|
||
|
|
||
| class DNAtoRNA: | ||
| """Convert DNA to RNA (T -> U).""" | ||
|
|
||
| def __call__(self, x: str) -> str: | ||
| return x.replace("T", "U") | ||
|
|
||
| def __repr__(self) -> str: | ||
| return f"{self.__class__.__name__}()" | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| """Tests for torch transformations.""" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,105 @@ | ||
| """Tests for torch transforms.""" | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from pyaptamer.trafos.torch import ( | ||
| DNAtoRNA, | ||
| GreedyEncode, | ||
| RandomMask, | ||
| Reverse, | ||
| ) | ||
|
|
||
|
|
||
| class TestReverse: | ||
| def test_reverse(self): | ||
| t = Reverse() | ||
| assert t("ACGT") == "TGCA" | ||
| assert t("AAA") == "AAA" | ||
| assert t("") == "" | ||
|
|
||
| def test_repr(self): | ||
| assert "Reverse" in repr(Reverse()) | ||
|
|
||
|
|
||
| class TestDNAtoRNA: | ||
| def test_convert(self): | ||
| t = DNAtoRNA() | ||
| assert t("ACGT") == "ACGU" | ||
| assert t("TTT") == "UUU" | ||
| assert t("ACG") == "ACG" | ||
|
|
||
| def test_repr(self): | ||
| assert "DNAtoRNA" in repr(DNAtoRNA()) | ||
|
|
||
|
|
||
| class TestGreedyEncode: | ||
| @pytest.fixture | ||
| def vocab(self): | ||
| return {"A": 1, "C": 2, "G": 3, "T": 4, "AC": 5, "GT": 6} | ||
|
|
||
| def test_encode_simple(self, vocab): | ||
| t = GreedyEncode(vocab, max_len=5) | ||
| result = t("ACGT") | ||
| assert result.shape == (5,) | ||
| assert result[0].item() == 5 | ||
| assert result[1].item() == 6 | ||
|
|
||
| def test_padding(self, vocab): | ||
| t = GreedyEncode(vocab, max_len=10) | ||
| result = t("A") | ||
| assert result.shape == (10,) | ||
| assert result[0].item() == 1 | ||
| assert result[1].item() == 0 | ||
|
|
||
| def test_truncation(self, vocab): | ||
| t = GreedyEncode(vocab, max_len=2) | ||
| result = t("ACGTACGT") | ||
| assert result.shape == (2,) | ||
|
|
||
| def test_unknown_char(self, vocab): | ||
| t = GreedyEncode(vocab, max_len=5) | ||
| result = t("XYZ") | ||
| assert result[0].item() == 0 | ||
|
|
||
| def test_repr(self, vocab): | ||
| t = GreedyEncode(vocab, max_len=10) | ||
| assert "GreedyEncode" in repr(t) | ||
|
|
||
|
|
||
| class TestRandomMask: | ||
| def test_shape(self): | ||
| t = RandomMask(mask_idx=99, mask_rate=0.5) | ||
| x = torch.tensor([1, 2, 3, 4, 0, 0]) | ||
| assert t(x).shape == x.shape | ||
|
|
||
| def test_preserves_padding(self): | ||
| t = RandomMask(mask_idx=99, mask_rate=1.0) | ||
| x = torch.tensor([1, 2, 0, 0]) | ||
| result = t(x) | ||
| assert result[2].item() == 0 | ||
| assert result[3].item() == 0 | ||
|
|
||
| def test_applies_mask(self): | ||
| torch.manual_seed(42) | ||
| t = RandomMask(mask_idx=99, mask_rate=0.5) | ||
| x = torch.tensor([1, 2, 3, 4]) | ||
| assert (t(x) == 99).any() | ||
|
|
||
| def test_repr(self): | ||
| t = RandomMask(mask_idx=99) | ||
| assert "RandomMask" in repr(t) | ||
|
|
||
|
|
||
| class TestChaining: | ||
| def test_str_transforms(self): | ||
| t1 = DNAtoRNA() | ||
| t2 = Reverse() | ||
| assert t2(t1("ACGT")) == "UGCA" | ||
|
|
||
| def test_tensor_transforms(self): | ||
| vocab = {"A": 1, "C": 2, "G": 3, "U": 4} | ||
| encode = GreedyEncode(vocab, max_len=5) | ||
| mask = RandomMask(mask_idx=99, mask_rate=0.5) | ||
| result = mask(encode("ACGU")) | ||
| assert result.shape == (5,) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why add module?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thank u , i will correct it!