Open
Labels: question (Further information is requested)
Description
While we have base classes that can be subclassed, it could also be helpful to define protocols (`typing.Protocol`) so that APIs can depend on the methods an object provides rather than on subclassing directly.
This might be most useful for APIs downstream of training, which can be more lightweight and only need specific methods.
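To illustrate the difference from subclassing: a `typing.Protocol` is matched structurally, so a static type checker accepts any object whose methods fit the protocol, even if it never inherits from a base class. The minimal sketch below assumes nothing about auto_cast itself; `Batch`, `Tensor`, `MeanEncoder` and `embed_for_eval` are hypothetical stand-ins chosen only so the example runs without the package.

```python
from typing import Protocol

# Hypothetical stand-ins for auto_cast.types.Batch / Tensor so the sketch runs
# without the package; the real types are assumed to be tensor-based.
Batch = dict[str, list[float]]
Tensor = list[float]


class EncoderProtocol(Protocol):
    """Anything with an `encode` method of this shape satisfies the protocol."""

    def encode(self, batch: Batch) -> Tensor: ...


class MeanEncoder:
    """A lightweight encoder that does NOT subclass any base class."""

    def encode(self, batch: Batch) -> Tensor:
        # Hypothetical behaviour: average the values of each field.
        return [sum(v) / len(v) for v in batch.values()]


def embed_for_eval(encoder: EncoderProtocol, batch: Batch) -> Tensor:
    # Downstream code only depends on the method it calls, not on a base class.
    return encoder.encode(batch)


z = embed_for_eval(MeanEncoder(), {"u": [1.0, 3.0], "v": [2.0, 4.0]})
print(z)  # [2.0, 3.0]
```

The downstream helper stays decoupled from the training-side class hierarchy: any object with a compatible `encode` can be passed in.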
Possible examples could be:
```python
from typing import Protocol

from auto_cast.types import Batch, EncodedBatch, RolloutOutput, Tensor


class EncoderProtocol(Protocol):
    """Encoder Protocol."""
    def encode(self, batch: Batch) -> Tensor: ...
    def __call__(self, batch: Batch) -> Tensor: ...


class ProcessorProtocol(Protocol):
    """Processor Protocol."""
    def map(self, x: Tensor) -> Tensor: ...
    def rollout(self, batch: EncodedBatch) -> RolloutOutput: ...
    def __call__(self, batch: EncodedBatch) -> Tensor: ...


class DecoderProtocol(Protocol):
    """Decoder Protocol."""
    def decode(self, z: Tensor) -> Tensor: ...
    def __call__(self, batch: EncodedBatch) -> Tensor: ...


class EncoderDecoderProtocol(Protocol):
    """Encoder-Decoder Protocol."""
    def encode(self, batch: Batch) -> Tensor: ...
    def decode(self, z: Tensor) -> Tensor: ...
    def encode_decode(self, x: Batch) -> Tensor: ...
    def __call__(self, batch: Batch) -> Tensor: ...


class EncoderProcessorDecoderProtocol(Protocol):
    """Encoder-Processor-Decoder Protocol."""
    def encode(self, batch: Batch) -> Tensor: ...
    def decode(self, z: Tensor) -> Tensor: ...
    def map(self, batch: EncodedBatch) -> Tensor: ...
    def encode_decode(self, batch: Batch) -> Tensor: ...
    def rollout(self, batch: Batch) -> RolloutOutput: ...
    def __call__(self, batch: Batch) -> Tensor: ...
```
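If downstream code also needs to check conformance at runtime (for example, validating a user-supplied model), the protocols could be decorated with `typing.runtime_checkable`. One caveat worth weighing: `isinstance` against a runtime-checkable protocol only verifies that the named members exist, not their signatures, so full conformance still relies on a static type checker. The sketch below uses a trimmed, hypothetical copy of `DecoderProtocol` (only `decode`) and toy decoder classes; whether auto_cast should adopt `@runtime_checkable` at all is an open assumption.

```python
from typing import Protocol, runtime_checkable

Tensor = list[float]  # hypothetical stand-in for auto_cast.types.Tensor


@runtime_checkable
class DecoderProtocol(Protocol):
    """Trimmed illustration: only the `decode` method is required."""

    def decode(self, z: Tensor) -> Tensor: ...


class IdentityDecoder:
    def decode(self, z: Tensor) -> Tensor:
        return list(z)


class WrongSignatureDecoder:
    # Wrong signature (no `z` argument), yet the runtime check below still
    # passes, because isinstance() only looks for a `decode` attribute.
    def decode(self) -> Tensor:
        return []


class NotADecoder:
    def encode(self, z: Tensor) -> Tensor:
        return z


print(isinstance(IdentityDecoder(), DecoderProtocol))        # True
print(isinstance(WrongSignatureDecoder(), DecoderProtocol))  # True (signature not checked)
print(isinstance(NotADecoder(), DecoderProtocol))            # False
```

In practice this means runtime checks are a cheap sanity filter, while signature mismatches are only caught by a tool such as mypy or pyright.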
Project status: In progress