Add protocols to define type structure alongside base classes #39

@sgreenbury

While we have base classes that can be subclassed, it could also be helpful to define protocols, so that parts of the API do not have to rely on subclassing directly.

This might be most useful for APIs downstream of training, which can be more lightweight and only need specific methods to be present.

Possible examples could be:

from typing import Protocol

from auto_cast.types import Batch, EncodedBatch, RolloutOutput, Tensor


class EncoderProtocol(Protocol):
    """Encoder Protocol."""

    def encode(self, batch: Batch) -> Tensor: ...
    def __call__(self, batch: Batch) -> Tensor: ...


class ProcessorProtocol(Protocol):
    """Processor Protocol."""

    def map(self, x: Tensor) -> Tensor: ...
    def rollout(self, batch: EncodedBatch) -> RolloutOutput: ...
    def __call__(self, batch: EncodedBatch) -> Tensor: ...


class DecoderProtocol(Protocol):
    """Decoder Protocol."""

    def decode(self, z: Tensor) -> Tensor: ...
    def __call__(self, batch: EncodedBatch) -> Tensor: ...


class EncoderDecoderProtocol(Protocol):
    """Encoder-Decoder Protocol."""

    def encode(self, batch: Batch) -> Tensor: ...
    def decode(self, z: Tensor) -> Tensor: ...
    def encode_decode(self, x: Batch) -> Tensor: ...
    def __call__(self, batch: Batch) -> Tensor: ...


class EncoderProcessorDecoderProtocol(Protocol):
    """Encoder-Processor-Decoder Protocol."""

    def encode(self, batch: Batch) -> Tensor: ...
    def decode(self, z: Tensor) -> Tensor: ...
    def map(self, batch: EncodedBatch) -> Tensor: ...
    def encode_decode(self, batch: Batch) -> Tensor: ...
    def rollout(self, batch: Batch) -> RolloutOutput: ...
    def __call__(self, batch: Batch) -> Tensor: ...
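
To illustrate how downstream code could rely on such a protocol without subclassing, here is a minimal sketch. It does not import from auto_cast: `Tensor` and `Batch` are simplified stand-ins for the real types, and `ScalingEncoder` and `embed` are hypothetical names made up for this example. Marking the protocol `@runtime_checkable` is also an assumption, added only so that `isinstance` works.

```python
from dataclasses import dataclass
from typing import Protocol, runtime_checkable

# Stand-ins for auto_cast.types.Tensor and Batch (assumption: simplified
# so that the sketch runs without the package installed).
Tensor = list[float]


@dataclass
class Batch:
    inputs: Tensor


@runtime_checkable
class EncoderProtocol(Protocol):
    """Encoder Protocol."""

    def encode(self, batch: Batch) -> Tensor: ...
    def __call__(self, batch: Batch) -> Tensor: ...


class ScalingEncoder:
    """Hypothetical encoder that does not subclass any base class."""

    def __init__(self, scale: float) -> None:
        self.scale = scale

    def encode(self, batch: Batch) -> Tensor:
        return [self.scale * value for value in batch.inputs]

    def __call__(self, batch: Batch) -> Tensor:
        return self.encode(batch)


def embed(encoder: EncoderProtocol, batch: Batch) -> Tensor:
    """Hypothetical downstream helper that only needs the protocol's methods."""
    return encoder(batch)


# ScalingEncoder is accepted because it has the required methods,
# not because of its inheritance tree.
latent = embed(ScalingEncoder(2.0), Batch(inputs=[1.0, 2.0]))
```

A static type checker would accept `ScalingEncoder` wherever `EncoderProtocol` is expected. Note that the runtime `isinstance` check only verifies that the methods exist, not that their signatures match.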

Metadata

Assignees

No one assigned

    Labels

    question: Further information is requested

    Type

    No type

    Projects

    Status

    In progress

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
