Add torch-compatible transformations#246
Conversation
55b9f7b to
0e980af
Compare
f1c587a to
0e980af
Compare
fkiraly
left a comment
There was a problem hiding this comment.
Can you explain how this design would be used in practice? E.g., by providing a basic vignette on how you expect users to interact with the new classes?
669edde to
0e980af
Compare
|
@fkiraly plz sir review it . |
|
Hi @Tarun-goswamii , I think @fkiraly meant to give an example in conversation itself and not put it in code right now, I would suggest to revert the commit and post here in the conversation instead first and confirm with him. |
This reverts commit 960ce06.
Practical Design & UsageDear @fkiraly. Real-world pipelineDirect chaining example: from pyaptamer.trafos.torch import GreedyEncode, RandomMask, DNAtoRNA, Reverse
# String transforms can be chained
dna_sequence = "ACGTACGT"
rna_transform = DNAtoRNA()
reverse_transform = Reverse()
rna_seq = rna_transform(dna_sequence) # "ACGUACGU"
reversed_seq = reverse_transform(rna_seq) # "UGCAUGCA"
# Encoding to tensor
vocab = {"A": 1, "C": 2, "G": 3, "U": 4}
encoder = GreedyEncode(vocab=vocab, max_len=16)
encoded = encoder(rna_seq)
# Result: tensor([1, 2, 3, 4, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0])
# Apply augmentation
masker = RandomMask(mask_idx=99, mask_rate=0.15)
masked = masker(encoded)
# Some non-padded positions are randomly replaced with 99
@fkiraly @siddharth7113 plz review it |
| from pyaptamer.trafos.torch._base import BaseTorchTransform | ||
|
|
||
|
|
||
| class Reverse(BaseTorchTransform): |
|
Also please ensure code quality checks are passing in CI, and pre-commit is installed |
String transforms (Reverse, DNAtoRNA) only operate on Python strings and don't use torch, so they shouldn't inherit from BaseTorchTransform. This separates the architecture properly - only transforms that actually use torch (GreedyEncode, RandomMask) inherit from the base. Also add __repr__ methods for consistency with other transforms. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
| @@ -0,0 +1,21 @@ | |||
| """String-to-string transformations.""" | |||
There was a problem hiding this comment.
this still doesn't make sense, why have a _string in torch ? when there are no tensor operations involved.
| @@ -1 +1 @@ | |||
| """Transformations.""" | |||
| """Transformations module.""" | |||
There was a problem hiding this comment.
thank u , i will correct it!
String transforms (Reverse, DNAtoRNA) don't use torch operations, so they shouldn't be in the torch module. This was architecturally confusing and didn't make sense. - Removed pyaptamer/trafos/torch/_string.py - Removed DNAtoRNA and Reverse from torch module exports - Removed related tests from test_torch_transforms.py - Reverted unnecessary docstring change in trafos/__init__.py Only kept torch-specific transforms: GreedyEncode and RandomMask. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
siddharth7113
left a comment
There was a problem hiding this comment.
Hi @Tarun-goswamii ,
There are some design decision that would be needed for this PR, we can wait for @fkiraly response for some time then raise in discord if needed
| X, _ = self._check_X_y(X, None) | ||
| return X | ||
|
|
||
| def get_torch_transform(self): |
There was a problem hiding this comment.
What is the reasoning and use of the function that is being added here? Please explain, since it's just returning None?
There was a problem hiding this comment.
Dear @siddharth7113 The get_torch_transform() is designed as an extension point in the sklearn base class.
By default it returns None, meaning "this transform doesn't have a torch equivalent."
Subclasses that have torch-compatible versions can override this to return them. For example, the sklearn GreedyEncode encoder overrides this method to return the torch GreedyEncode transform with the same parameters.
This allows flexibility - not every sklearn transform needs a torch version, but those that do can provide it through this interface.
| from torch import Tensor | ||
|
|
||
|
|
||
| class BaseTorchTransform(ABC): |
There was a problem hiding this comment.
cc @fkiraly , This is the proposed design for the torch base class, but the currently it only puts call and repr, do we have something similar in sktime to borrow for torch related transformers?
Closes #171
Added trafos.torch module with transforms for DataLoaders.
Includes str->str (Reverse, DNAtoRNA), str->tensor (GreedyEncode), and tensor->tensor (RandomMask) transforms. Also added get_torch_transform() to sklearn base class per discussion in #170.