Type hinting a class that wraps function with __call__ method that conditionally modifies returned type
#1933
Asked by soneill-pure in Q&A. Answered by hauntsaninja.
I'm trying to type hint a class that essentially wraps a function and adds some extra method(s). The original function can return any one of the …

```python
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
from pathlib import Path
from typing import Generic, TypeVar

import pandas
import polars

InitialOutputFrame = TypeVar(
    "InitialOutputFrame",
    polars.DataFrame,
    Iterable[polars.DataFrame],  # maps to polars.LazyFrame
    polars.LazyFrame,
    pandas.DataFrame,
    Iterable[pandas.DataFrame],  # maps to pandas.DataFrame
)

OutputFrame = TypeVar("OutputFrame", polars.DataFrame, polars.LazyFrame, pandas.DataFrame)

# Keys include generic aliases such as Iterable[polars.DataFrame], which are not `type` instances
type_mapping: dict[object, type] = {
    polars.DataFrame: polars.DataFrame,
    Iterable[polars.DataFrame]: polars.LazyFrame,
    polars.LazyFrame: polars.LazyFrame,
    pandas.DataFrame: pandas.DataFrame,
    Iterable[pandas.DataFrame]: pandas.DataFrame,
}


class AbstractCachedFunction(ABC, Generic[InitialOutputFrame, OutputFrame]):
    """Abstract base class for cached functions."""

    @abstractmethod
    def __init__(self, func: Callable[..., InitialOutputFrame]) -> None: ...

    # TODO: how to map the correct OutputFrame type to each InitialOutputFrame type?
    @abstractmethod
    def __call__(self, *args, show_progress_bar: bool = True, **kwargs) -> OutputFrame:
        # Call `func` here and apply additional caching logic that leads to a conditionally modified return type
        ...

    @abstractmethod
    def get_path(self, *args, show_progress_bar: bool = True, **kwargs) -> Path:
        """Get the cached file path for the given arguments."""
```
Answer by hauntsaninja, Feb 25, 2025 (selected as the answer by soneill-pure):
You could try using overloads on `__call__` where the overloads match on generic self.
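A minimal sketch of that suggestion, under stated assumptions: the class keeps a single type parameter for the wrapped function's return type, and each overload of `__call__` annotates `self` with one concrete parameterization of the class, so a type checker picks the return type from the instance's type. To stay self-contained it uses hypothetical stand-in classes `EagerFrame` and `LazyFrame` in place of `polars.DataFrame` and `polars.LazyFrame` and omits the pandas cases (they would follow the same pattern); `CachedFunction`, `load_one`, and `load_many` are illustrative names, not part of the original code.

```python
from __future__ import annotations

from collections.abc import Callable, Iterable
from typing import Any, Generic, TypeVar, overload


class EagerFrame:
    """Hypothetical stand-in for polars.DataFrame."""

    def __init__(self, rows: list[int]) -> None:
        self.rows = rows


class LazyFrame:
    """Hypothetical stand-in for polars.LazyFrame."""

    def __init__(self, parts: list[EagerFrame]) -> None:
        self.parts = parts


# A single type parameter: whatever the wrapped function returns.
InitialOutputFrame = TypeVar("InitialOutputFrame", EagerFrame, Iterable[EagerFrame], LazyFrame)


class CachedFunction(Generic[InitialOutputFrame]):
    def __init__(self, func: Callable[..., InitialOutputFrame]) -> None:
        self.func = func

    # Each overload pins `self` to one parameterization of the class and
    # declares the output type that parameterization maps to.
    @overload
    def __call__(
        self: CachedFunction[EagerFrame], *args: Any, show_progress_bar: bool = True, **kwargs: Any
    ) -> EagerFrame: ...
    @overload
    def __call__(
        self: CachedFunction[Iterable[EagerFrame]], *args: Any, show_progress_bar: bool = True, **kwargs: Any
    ) -> LazyFrame: ...
    @overload
    def __call__(
        self: CachedFunction[LazyFrame], *args: Any, show_progress_bar: bool = True, **kwargs: Any
    ) -> LazyFrame: ...

    def __call__(self, *args: Any, show_progress_bar: bool = True, **kwargs: Any) -> EagerFrame | LazyFrame:
        # show_progress_bar is accepted for parity with the question; unused here.
        result = self.func(*args, **kwargs)
        # Caching logic would go here; this sketch only performs the type conversion.
        if isinstance(result, (EagerFrame, LazyFrame)):
            return result
        return LazyFrame(list(result))


def load_one() -> EagerFrame:
    return EagerFrame([1, 2])


def load_many() -> list[EagerFrame]:
    return [EagerFrame([1]), EagerFrame([2])]


one = CachedFunction(load_one)  # a checker should infer CachedFunction[EagerFrame]
many = CachedFunction(load_many)  # ...and CachedFunction[Iterable[EagerFrame]]
```

With mypy or pyright, `reveal_type(one())` and `reveal_type(many())` should report `EagerFrame` and `LazyFrame` respectively. The overloads carry no runtime behavior, so the implementation still has to do the conversion itself; in exchange, the second type parameter (`OutputFrame`) is no longer needed, because the overloads encode the mapping.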