From 76947c01c06be94d9fcd1a8c27d7cca29cda20e1 Mon Sep 17 00:00:00 2001 From: Cameron Riddell Date: Wed, 12 Feb 2025 15:38:47 -0800 Subject: [PATCH 1/2] enh pluggable backends draft --- narwhals/backends.py | 360 +++++++++++++++++++++++++++++++++++++ narwhals/dataframe.py | 3 +- narwhals/translate.py | 401 ++++-------------------------------------- 3 files changed, 395 insertions(+), 369 deletions(-) create mode 100644 narwhals/backends.py diff --git a/narwhals/backends.py b/narwhals/backends.py new file mode 100644 index 0000000000..1b44bca9e3 --- /dev/null +++ b/narwhals/backends.py @@ -0,0 +1,360 @@ +from __future__ import annotations + +import sys +from dataclasses import dataclass +from dataclasses import field +from dataclasses import replace +from importlib import import_module +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable +from typing import Generator +from typing import TypeVar +from typing import cast + +from narwhals.dataframe import DataFrame +from narwhals.dataframe import LazyFrame +from narwhals.series import Series +from narwhals.utils import Implementation +from narwhals.utils import Version +from narwhals.utils import parse_version + +if TYPE_CHECKING: + from types import ModuleType + + T = TypeVar("T") + + +BACKENDS = [] + + +@dataclass +class Adaptation: + narwhals: type + native: str | type + adapter: str | type + level: str + kwargs: dict[str, Any] = field(default_factory=dict) + version: Version | None = None + + @property + def imported_adapter(self) -> type: + if isinstance(self.adapter, str): + return dynamic_import(self.adapter) + return self.adapter + + @property + def imported_native(self) -> type: + if isinstance(self.native, str): + return dynamic_import(self.native) + return self.native + + +@dataclass +class Backend: + requires: list[tuple[str, str | Callable, tuple[int, ...]]] + adaptations: list[Adaptation] + implementation: Implementation | None = None + + def __post_init__(self) -> None: + 
adaptations = [] + for adapt in self.adaptations: + if adapt.version in Version: + adaptations.append(adapt) + elif adapt.version is None: + adaptations.extend(replace(adapt, version=v) for v in Version) + else: + msg = "Adaptation.version must be {Version!r} or None, got {adapt.version!r}" + raise TypeError(msg) + self.adaptations = adaptations + + def get_adapter( + self, cls: type, version: Version = Version.MAIN + ) -> Adaptation | None: + module_name, *_ = cls.__module__.split(".", maxsplit=1) + for adapt in self.adaptations: + if adapt.version != version: + continue + + if isinstance(adapt.native, type) and cls is adapt.native: + return adapt + + elif isinstance(adapt.native, str): + adapt.native = cast(str, adapt.native) + adapt_module_name, *_, adapt_cls_name = adapt.native.split(".") + if ( + (adapt_module_name in sys.modules) # base-module is imported + and (module_name == adapt_module_name) # roots match + and (cls.__name__ == adapt_cls_name) # tips match + and (cls is dynamic_import(adapt.native)) # types are identical + ): + return adapt + return None + + def validate_backend_version(self) -> None: + for module_name, version_getter, min_version in self.requires: + # TODO(camriddell): this logic may be better suited for a Version namedtuple or dataclass + if callable(version_getter): + version_str = version_getter() + elif isinstance(version_getter, str): + version_str = dynamic_import(version_getter) + else: + msg = "version_getter {version_getter!r} must be a string or callable, got {type(version_getter)}" + raise TypeError(msg) + + installed_version = parse_version(version_str) + if installed_version < min_version: + msg = f"{module_name} must be updated to at least {min_version}, got {installed_version}" + raise ValueError(msg) + + def version(self) -> tuple[int, ...]: + version_getter = self.requires[0][1] + # TODO(camriddell): this logic may be better suited for a Version namedtuple or dataclass + if callable(version_getter): + version_str = 
version_getter() + elif isinstance(version_getter, str): + version_str = dynamic_import(version_getter) + else: + msg = "version_getter {version_getter!r} must be a string or callable, got {type(version_getter)}" + raise TypeError(msg) + return parse_version(version_str) + + def native_namespace(self) -> ModuleType: + return import_module(self.requires[0][0]) + + def get_native_namespace(self) -> ModuleType | None: + return sys.modules.get(self.requires[0][0], None) + + +def register_backends(*backends: Backend) -> None: + for b in backends: + BACKENDS.append(b) # noqa: PERF402 + + +def traverse_rsplits(text: str, sep: str = " ") -> Generator[tuple[str, list[str]]]: + sep_count = text.count(sep) + if sep_count == 0: + yield (text, []) + + for i in range(1, sep_count + 1): + base, *remaining = text.rsplit(sep, maxsplit=i) + yield base, remaining + + +def dynamic_import(dotted_path: str) -> Any: + for base, attributes in traverse_rsplits(dotted_path, sep="."): + if not attributes: + continue + try: + module = import_module(base) + except ImportError: + pass + else: + obj = module + for attr in attributes: + obj = getattr(obj, attr) + return obj + msg = "Could not import {dotted_path!r}" + raise ImportError(msg) + + +register_backends( + Backend( + requires=[ + ("pandas", "pandas.__version__", (0, 25, 3)), + ], + adaptations=[ + Adaptation( + DataFrame, + "pandas.DataFrame", + "narwhals._pandas_like.dataframe.PandasLikeDataFrame", + level="full", + kwargs={"validate_column_names": True}, + ), + Adaptation( + Series, + "pandas.Series", + "narwhals._pandas_like.dataframe.PandasLikeSeries", + level="full", + ), + ], + implementation=Implementation.PANDAS, + ), + Backend( + requires=[ + ("polars", "polars.__version__", (0, 20, 3)), + ], + adaptations=[ + Adaptation( + LazyFrame, + "polars.LazyFrame", + "narwhals._polars.dataframe.PolarsLazyFrame", + level="full", + ), + Adaptation( + DataFrame, + "polars.DataFrame", + "narwhals._polars.dataframe.PolarsDataFrame", + 
level="full", + ), + Adaptation( + Series, + "polars.Series", + "narwhals._polars.series.PolarsSeries", + level="full", + ), + ], + ), + Backend( + requires=[("modin.pandas", "modin.__version__", (0, 25, 3))], + adaptations=[ + Adaptation( + DataFrame, + "modin.pandas.DataFrame", + "narwhals._pandas_like.dataframe.PandasLikeDataFrame", + level="full", + kwargs={"validate_column_names": True}, + ), + Adaptation( + Series, + "modin.pandas.Series", + "narwhals._pandas_like.dataframe.PandasLikeSeries", + level="full", + ), + ], + implementation=Implementation.MODIN, + ), + Backend( + requires=[ + ("cudf", "cudf.__version__", (24, 10)), + ], + adaptations=[ + Adaptation( + DataFrame, + "cudf.DataFrame", + "narwhals._pandas_like.dataframe.PandasLikeDataFrame", + level="full", + kwargs={"validate_column_names": True}, + ), + Adaptation( + Series, + "cudf.Series", + "narwhals._pandas_like.dataframe.PandasLikeSeries", + level="full", + ), + ], + implementation=Implementation.CUDF, + ), + Backend( + requires=[ + ("pyarrow", "pyarrow.__version__", (11,)), + ], + adaptations=[ + Adaptation( + DataFrame, + "pyarrow.Table", + "narwhals._arrow.dataframe.ArrowDataFrame", + level="full", + kwargs={"validate_column_names": True}, + ), + Adaptation( + Series, + "pyarrow.ChunkedArray", + "narwhals._arrow.series.ArrowSeries", + level="full", + kwargs={"name": ""}, + ), + ], + ), + Backend( + requires=[("pyspark.sql", "pyspark.__version__", (3, 5))], + adaptations=[ + Adaptation( + LazyFrame, + "pyspark.sql.DataFrame", + "narwhals._spark.dataframe.SparkLikeLazyFrame", + level="full", + kwargs={"validate_column_names": True}, + ), + Adaptation( + Series, + "pyspark.sql.Series", + "narwhals._arrow.dataframe.ArrowSeries", + level="full", + ), + ], + implementation=Implementation.PYSPARK, + ), + Backend( + requires=[ + ("dask.dataframe", "dask.__version__", (2024, 8)), + ("dask_expr", "dask_expr.__version__", (0,)), + ], + adaptations=[ + Adaptation( + LazyFrame, + 
"dask.dataframe.DataFrame", + "narwhals._dask.dataframe.DaskLazyFrame", + level="full", + kwargs={"validate_column_names": True}, + ), + Adaptation( + LazyFrame, + "dask_expr.DataFrame", + "narwhals._dask.dataframe.DaskLazyFrame", + level="full", + kwargs={"validate_column_names": True}, + ), + ], + ), + Backend( + requires=[("duckdb", "duckdb.__version__", (1,))], + adaptations=[ + Adaptation( + LazyFrame, + "duckdb.DuckDBPyRelation", + "narwhals._duckdb.dataframe.DuckDBLazyFrame", + level="full", + kwargs={"validate_column_names": True}, + version=Version.MAIN, + ), + Adaptation( + DataFrame, + "duckdb.DuckDBPyRelation", + "narwhals._duckdb.dataframe.DuckDBLazyFrame", + level="interchange", + version=Version.V1, + kwargs={"validate_column_names": True}, + ), + ], + ), + Backend( + requires=[ + ("ibis", "ibis.__version__", (6,)), + ], + adaptations=[ + Adaptation( + LazyFrame, + "ibis.expr.types.Table", + "narwhals._ibis.dataframe.IbisLazyFrame", + level="full", + kwargs={"validate_column_names": True}, + ), + ], + ), + Backend( + requires=[ + ("sqlframe", "sqlframe._version.__version__", (3, 14, 2)), + ], + adaptations=[ + Adaptation( + LazyFrame, + "sqlframe.base.dataframe.BaseDataFrame", + "narwhals._spark.dataframe.SparkLikeLazyFrame", + level="full", + kwargs={"validate_column_names": True}, + ), + ], + implementation=Implementation.SQLFRAME, + ), +) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index b7c53ff32d..3557bfb4b5 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -26,7 +26,6 @@ from narwhals.exceptions import LengthChangingExprError from narwhals.exceptions import OrderDependentExprError from narwhals.schema import Schema -from narwhals.translate import to_native from narwhals.utils import Implementation from narwhals.utils import find_stacklevel from narwhals.utils import flatten @@ -2364,6 +2363,8 @@ def to_native(self: Self) -> FrameT: └───────┴───────┘ """ + from narwhals.translate import to_native + return 
to_native(narwhals_object=self, pass_through=False) # inherited diff --git a/narwhals/translate.py b/narwhals/translate.py index e4c08507ac..b2939798fb 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -11,32 +11,11 @@ from typing import TypeVar from typing import overload -from narwhals.dependencies import get_cudf +import narwhals.backends from narwhals.dependencies import get_cupy -from narwhals.dependencies import get_dask -from narwhals.dependencies import get_dask_expr -from narwhals.dependencies import get_modin from narwhals.dependencies import get_numpy from narwhals.dependencies import get_pandas -from narwhals.dependencies import get_polars from narwhals.dependencies import get_pyarrow -from narwhals.dependencies import get_pyspark -from narwhals.dependencies import is_cudf_dataframe -from narwhals.dependencies import is_cudf_series -from narwhals.dependencies import is_dask_dataframe -from narwhals.dependencies import is_duckdb_relation -from narwhals.dependencies import is_ibis_table -from narwhals.dependencies import is_modin_dataframe -from narwhals.dependencies import is_modin_series -from narwhals.dependencies import is_pandas_dataframe -from narwhals.dependencies import is_pandas_series -from narwhals.dependencies import is_polars_dataframe -from narwhals.dependencies import is_polars_lazyframe -from narwhals.dependencies import is_polars_series -from narwhals.dependencies import is_pyarrow_chunked_array -from narwhals.dependencies import is_pyarrow_table -from narwhals.dependencies import is_pyspark_dataframe -from narwhals.dependencies import is_sqlframe_dataframe from narwhals.utils import Version if TYPE_CHECKING: @@ -364,54 +343,28 @@ def _from_native_impl( # noqa: PLR0915 from narwhals.dataframe import DataFrame from narwhals.dataframe import LazyFrame from narwhals.series import Series - from narwhals.utils import Implementation from narwhals.utils import _supports_dataframe_interchange from narwhals.utils import 
is_compliant_dataframe from narwhals.utils import is_compliant_lazyframe from narwhals.utils import is_compliant_series - from narwhals.utils import parse_version # Early returns if isinstance(native_object, (DataFrame, LazyFrame)) and not series_only: return native_object if isinstance(native_object, Series) and (series_only or allow_series): return native_object - if series_only: if allow_series is False: msg = "Invalid parameter combination: `series_only=True` and `allow_series=False`" raise ValueError(msg) allow_series = True + if eager_only and eager_or_interchange_only: msg = "Invalid parameter combination: `eager_only=True` and `eager_or_interchange_only=True`" raise ValueError(msg) - # SQLFrame - # This one needs checking before extensions as `hasattr` always returns `True`. - if is_sqlframe_dataframe(native_object): # pragma: no cover - from narwhals._spark_like.dataframe import SparkLikeLazyFrame - - if series_only: - msg = "Cannot only use `series_only` with SQLFrame DataFrame" - raise TypeError(msg) - if eager_only or eager_or_interchange_only: - msg = "Cannot only use `eager_only` or `eager_or_interchange_only` with SQLFrame DataFrame" - raise TypeError(msg) - import sqlframe._version - - backend_version = parse_version(sqlframe._version) - return LazyFrame( - SparkLikeLazyFrame( - native_object, - backend_version=backend_version, - version=version, - implementation=Implementation.SQLFRAME, - ), - level="lazy", - ) - # Extensions - elif is_compliant_dataframe(native_object): + if is_compliant_dataframe(native_object): if series_only: if not pass_through: msg = "Cannot only use `series_only` with dataframe" @@ -447,326 +400,47 @@ def _from_native_impl( # noqa: PLR0915 level="full", ) - # Polars - elif is_polars_dataframe(native_object): - from narwhals._polars.dataframe import PolarsDataFrame - - if series_only: - if not pass_through: - msg = "Cannot only use `series_only` with polars.DataFrame" - raise TypeError(msg) - return native_object - pl = 
get_polars() - return DataFrame( - PolarsDataFrame( - native_object, backend_version=parse_version(pl), version=version - ), - level="full", - ) - elif is_polars_lazyframe(native_object): - from narwhals._polars.dataframe import PolarsLazyFrame + for backend in reversed(narwhals.backends.BACKENDS): + adapter = backend.get_adapter(type(native_object), version=version) + if adapter is None: + continue - if series_only: - if not pass_through: - msg = "Cannot only use `series_only` with polars.LazyFrame" - raise TypeError(msg) - return native_object - if eager_only or eager_or_interchange_only: - if not pass_through: - msg = "Cannot only use `eager_only` or `eager_or_interchange_only` with polars.LazyFrame" - raise TypeError(msg) - return native_object - pl = get_polars() - return LazyFrame( - PolarsLazyFrame( - native_object, backend_version=parse_version(pl), version=version - ), - level="lazy", - ) - elif is_polars_series(native_object): - from narwhals._polars.series import PolarsSeries - - pl = get_polars() - if not allow_series: - if not pass_through: - msg = "Please set `allow_series=True` or `series_only=True`" - raise TypeError(msg) - return native_object - return Series( - PolarsSeries( - native_object, backend_version=parse_version(pl), version=version - ), - level="full", - ) + kwargs = adapter.kwargs + if backend.implementation is not None: + kwargs = kwargs.copy() + kwargs.setdefault("implementation", backend.implementation) - # pandas - elif is_pandas_dataframe(native_object): - from narwhals._pandas_like.dataframe import PandasLikeDataFrame - - if series_only: - if not pass_through: - msg = "Cannot only use `series_only` with dataframe" - raise TypeError(msg) - return native_object - pd = get_pandas() - return DataFrame( - PandasLikeDataFrame( - native_object, - backend_version=parse_version(pd), - implementation=Implementation.PANDAS, - version=version, - validate_column_names=True, - ), - level="full", - ) - elif is_pandas_series(native_object): - 
from narwhals._pandas_like.series import PandasLikeSeries - - if not allow_series: - if not pass_through: - msg = "Please set `allow_series=True` or `series_only=True`" - raise TypeError(msg) - return native_object - pd = get_pandas() - return Series( - PandasLikeSeries( - native_object, - implementation=Implementation.PANDAS, - backend_version=parse_version(pd), - version=version, - ), - level="full", - ) - - # Modin - elif is_modin_dataframe(native_object): # pragma: no cover - from narwhals._pandas_like.dataframe import PandasLikeDataFrame - - mpd = get_modin() - if series_only: - if not pass_through: - msg = "Cannot only use `series_only` with modin.DataFrame" - raise TypeError(msg) - return native_object - return DataFrame( - PandasLikeDataFrame( - native_object, - implementation=Implementation.MODIN, - backend_version=parse_version(mpd), - version=version, - validate_column_names=True, - ), - level="full", - ) - elif is_modin_series(native_object): # pragma: no cover - from narwhals._pandas_like.series import PandasLikeSeries - - mpd = get_modin() - if not allow_series: - if not pass_through: - msg = "Please set `allow_series=True` or `series_only=True`" - raise TypeError(msg) - return native_object - return Series( - PandasLikeSeries( - native_object, - implementation=Implementation.MODIN, - backend_version=parse_version(mpd), - version=version, - ), - level="full", - ) - - # cuDF - elif is_cudf_dataframe(native_object): # pragma: no cover - from narwhals._pandas_like.dataframe import PandasLikeDataFrame - - cudf = get_cudf() - if series_only: - if not pass_through: - msg = "Cannot only use `series_only` with cudf.DataFrame" - raise TypeError(msg) - return native_object - return DataFrame( - PandasLikeDataFrame( - native_object, - implementation=Implementation.CUDF, - backend_version=parse_version(cudf), - version=version, - validate_column_names=True, - ), - level="full", - ) - elif is_cudf_series(native_object): # pragma: no cover - from 
narwhals._pandas_like.series import PandasLikeSeries - - cudf = get_cudf() - if not allow_series: - if not pass_through: - msg = "Please set `allow_series=True` or `series_only=True`" - raise TypeError(msg) - return native_object - return Series( - PandasLikeSeries( - native_object, - implementation=Implementation.CUDF, - backend_version=parse_version(cudf), - version=version, - ), - level="full", - ) - - # PyArrow - elif is_pyarrow_table(native_object): - from narwhals._arrow.dataframe import ArrowDataFrame - - pa = get_pyarrow() - if series_only: - if not pass_through: - msg = "Cannot only use `series_only` with arrow table" - raise TypeError(msg) - return native_object - return DataFrame( - ArrowDataFrame( - native_object, - backend_version=parse_version(pa), - version=version, - validate_column_names=True, - ), - level="full", - ) - elif is_pyarrow_chunked_array(native_object): - from narwhals._arrow.series import ArrowSeries - - pa = get_pyarrow() - if not allow_series: - if not pass_through: - msg = "Please set `allow_series=True` or `series_only=True`" - raise TypeError(msg) - return native_object - return Series( - ArrowSeries( - native_object, backend_version=parse_version(pa), name="", version=version - ), - level="full", - ) - - # Dask - elif is_dask_dataframe(native_object): - from narwhals._dask.dataframe import DaskLazyFrame - - if series_only: - if not pass_through: - msg = "Cannot only use `series_only` with dask DataFrame" - raise TypeError(msg) - return native_object - if eager_only or eager_or_interchange_only: - if not pass_through: - msg = "Cannot only use `eager_only` or `eager_or_interchange_only` with dask DataFrame" - raise TypeError(msg) - return native_object - if ( - parse_version(get_dask()) <= (2024, 12, 1) and get_dask_expr() is None - ): # pragma: no cover - msg = "Please install dask-expr" - raise ImportError(msg) - return LazyFrame( - DaskLazyFrame( - native_object, - backend_version=parse_version(get_dask()), - version=version, - 
validate_column_names=True, - ), - level="lazy", - ) - - # DuckDB - elif is_duckdb_relation(native_object): - from narwhals._duckdb.dataframe import DuckDBLazyFrame - - if eager_only or series_only: # pragma: no cover - if not pass_through: - msg = ( - "Cannot only use `series_only=True` or `eager_only=False` " - "with DuckDBPyRelation" - ) - else: + if adapter.narwhals is Series and not (allow_series or series_only): + if pass_through: return native_object + msg = "Please set `allow_series=True` or `series_only=True`" raise TypeError(msg) - import duckdb # ignore-banned-import - - backend_version = parse_version(duckdb) - if version is Version.V1: - return DataFrame( - DuckDBLazyFrame( - native_object, - backend_version=backend_version, - version=version, - validate_column_names=True, - ), - level="interchange", - ) - return LazyFrame( - DuckDBLazyFrame( - native_object, - backend_version=backend_version, - version=version, - validate_column_names=True, - ), - level="lazy", - ) - # Ibis - elif is_ibis_table(native_object): # pragma: no cover - from narwhals._ibis.dataframe import IbisLazyFrame - - if eager_only or series_only: + if adapter.narwhals is not Series and series_only: if not pass_through: - msg = ( - "Cannot only use `series_only=True` or `eager_only=False` " - "with Ibis table" - ) + msg = f"Cannot only use `series_only` with {type(native_object)!r}" raise TypeError(msg) return native_object - import ibis # ignore-banned-import - - backend_version = parse_version(ibis) - if version is Version.V1: - return DataFrame( - IbisLazyFrame( - native_object, backend_version=backend_version, version=version - ), - level="interchange", - ) - return LazyFrame( - IbisLazyFrame( - native_object, backend_version=backend_version, version=version - ), - level="lazy", - ) - # PySpark - elif is_pyspark_dataframe(native_object): # pragma: no cover - from narwhals._spark_like.dataframe import SparkLikeLazyFrame + elif (adapter.narwhals is LazyFrame and eager_only) or 
eager_or_interchange_only: + if pass_through: + return native_object + elif not eager_or_interchange_only: + msg = f"Cannot only use `eager_only` or `eager_or_interchange_only` with {type(native_object)!r}" + raise TypeError(msg) - if series_only: - msg = "Cannot only use `series_only` with pyspark DataFrame" - raise TypeError(msg) - if eager_only or eager_or_interchange_only: - msg = "Cannot only use `eager_only` or `eager_or_interchange_only` with pyspark DataFrame" - raise TypeError(msg) - return LazyFrame( - SparkLikeLazyFrame( + return adapter.narwhals( + adapter.imported_adapter( native_object, - backend_version=parse_version(get_pyspark()), - version=version, - implementation=Implementation.PYSPARK, + version=adapter.version, + backend_version=backend.version(), + **kwargs, ), - level="lazy", + level=adapter.level, ) # Interchange protocol - elif _supports_dataframe_interchange(native_object): + if _supports_dataframe_interchange(native_object): from narwhals._interchange.dataframe import InterchangeFrame if eager_only or series_only: @@ -823,19 +497,10 @@ def get_native_namespace( if has_native_namespace(obj): return obj.__native_namespace__() - if is_pandas_dataframe(obj) or is_pandas_series(obj): - return get_pandas() - if is_modin_dataframe(obj) or is_modin_series(obj): # pragma: no cover - return get_modin() - if is_pyarrow_table(obj) or is_pyarrow_chunked_array(obj): - return get_pyarrow() - if is_cudf_dataframe(obj) or is_cudf_series(obj): # pragma: no cover - return get_cudf() - if is_dask_dataframe(obj): # pragma: no cover - return get_dask() - if is_polars_dataframe(obj) or is_polars_lazyframe(obj) or is_polars_series(obj): - return get_polars() - msg = f"Could not get native namespace from object of type: {type(obj)}" + for backend in reversed(narwhals.backends.BACKENDS): + if backend.get_adapter(type(obj)) is not None: + return backend.native_namespace() + msg = "Could not get native namespace" raise TypeError(msg) From 
24bc4051997aaa6eeaf02a7fa77869166ae34357 Mon Sep 17 00:00:00 2001 From: Cameron Riddell Date: Mon, 10 Mar 2025 10:50:48 -0700 Subject: [PATCH 2/2] Refactor narwhals backends - split Backend requirements into the core package requirement and extras (would be nice to refactor this out once the Implementation enum is refactored) - add Requirement class to store backend package/module requirements - add MROAdaptation class for matching against inherited types - BACKENDS represented by a deque for priority left/right appends - refactor adaptation matching logic to exist in the Adaptation class - refactor version validation to Backend class --- narwhals/backends.py | 346 ++++++++++++++++++++++++++++++------------ narwhals/translate.py | 18 +-- narwhals/utils.py | 3 +- 3 files changed, 258 insertions(+), 109 deletions(-) diff --git a/narwhals/backends.py b/narwhals/backends.py index 0b0b44df98..f42826c474 100644 --- a/narwhals/backends.py +++ b/narwhals/backends.py @@ -1,14 +1,17 @@ from __future__ import annotations import sys +from collections import deque from dataclasses import dataclass from dataclasses import field -from dataclasses import replace from importlib import import_module +from textwrap import indent +from types import ModuleType from typing import TYPE_CHECKING from typing import Any from typing import Callable from typing import Generator +from typing import Literal from typing import TypeVar from narwhals.dataframe import DataFrame @@ -19,25 +22,38 @@ from narwhals.utils import parse_version if TYPE_CHECKING: - from types import ModuleType - T = TypeVar("T") -BACKENDS = [] - - @dataclass class Adaptation: - narwhals: type + """Links a Narwhals `interface` to a `native` type through the `adapter` class. + + interface: the narwhals type + native: the native type (e.g. Polars.LazyFrame, pandas.DataFrame, modin.pandas.DataFrame) + adapter: the class that implements the API of the interface class using the mechanics of the native class.
+ level: The degree of support that Narwhals has for this native class. + kwargs: Additional kwargs that should be passed to the adapter class when creating an instance. + version: The version(s) of the Narwhals API this adaptation supports. + """ + + interface: type[LazyFrame[Any] | DataFrame[Any] | Series[Any]] native: str | type adapter: str | type - level: str + level: Literal["full", "lazy", "interchange"] kwargs: dict[str, Any] = field(default_factory=dict) - version: Version | None = None + version: Version = Version.MAIN @property def imported_adapter(self) -> type: + """The object returned by importing `self.adapter`. + + Returns: + The type specified by `self.adapter`. + + Raises: + ImportError: If the type specified by the string is unimportable. + """ obj = dynamic_import(self.adapter) if not isinstance(obj, type): msg = f"Attempted to import {self.adapter!r}, expected an instance of type but got {obj}" @@ -46,94 +62,191 @@ def imported_adapter(self) -> type: @property def imported_native(self) -> type: + """The object returned by importing `self.native`. + + Returns: + The type specified by `self.native`. + + Raises: + ImportError: If the type specified by the string is unimportable. + """ obj = dynamic_import(self.adapter) if not isinstance(obj, type): msg = f"Attempted to import {self.adapter!r}, expected an instance of type but got {obj}" raise TypeError(msg) return obj + def matches(self, cls: type, version: Version) -> bool: + """Determines whether this Adapter matches the passed `cls` and `version`. + + Returns: + True if the native object and version in this Adaptation matches the passed type and version. + False otherwise. 
+ """ + module_name, *_ = cls.__module__.split(".", maxsplit=1) + + if version not in self.version: + return False + + if isinstance(self.native, type) and cls is self.native: + return True + + elif isinstance(self.native, str): + adapt_module_name, *_, adapt_cls_name = self.native.split(".") + if ( + (adapt_module_name in sys.modules) # base-module is imported + and (module_name == adapt_module_name) # roots match + and (cls.__name__ == adapt_cls_name) # tips match + and (cls is dynamic_import(self.native)) # types are identical + ): + return True + return False + @dataclass -class Backend: - requires: list[tuple[str, str | Callable[[], str], tuple[int, ...]]] - adaptations: list[Adaptation] - implementation: Implementation | None = None +class MROAdaptation(Adaptation): + """An Adaptation that matches the native object to any type in the mro of a passed type. - def __post_init__(self) -> None: - adaptations = [] - for adapt in self.adaptations: - if adapt.version in Version: - adaptations.append(adapt) - elif adapt.version is None: - adaptations.extend(replace(adapt, version=v) for v in Version) - else: - msg = "Adaptation.version must be {Version!r} or None, got {adapt.version!r}" - raise TypeError(msg) - self.adaptations = adaptations + Useful if a downstream package has a base-class with multiple subclasses + that can all be represented with the same Adaptation options. 
+ """ - def get_adapter( - self, cls: type, version: Version = Version.MAIN - ) -> Adaptation | None: - module_name, *_ = cls.__module__.split(".", maxsplit=1) - for adapt in self.adaptations: - if adapt.version != version: - continue + def matches(self, cls: type, version: Version) -> bool: + match_func = super().matches + return any(match_func(cls=base_cls, version=version) for base_cls in cls.mro()) - if isinstance(adapt.native, type) and cls is adapt.native: - return adapt - elif isinstance(adapt.native, str): - adapt_module_name, *_, adapt_cls_name = adapt.native.split(".") - if ( - (adapt_module_name in sys.modules) # base-module is imported - and (module_name == adapt_module_name) # roots match - and (cls.__name__ == adapt_cls_name) # tips match - and (cls is dynamic_import(adapt.native)) # types are identical - ): - return adapt - return None +@dataclass +class Requirement: + """Represents a package/module requirement with a specified minimum version.""" - def validate_backend_version(self) -> None: - for module_name, version_getter, min_version in self.requires: - # TODO(camriddell): this logic may be better suited for a Version namedtuple or dataclass - if callable(version_getter): - version_str = version_getter() - elif isinstance(version_getter, str): - version_str = dynamic_import(version_getter) - else: - msg = "version_getter {version_getter!r} must be a string or callable, got {type(version_getter)}" - raise TypeError(msg) - - installed_version = parse_version(version_str) - if installed_version < min_version: - msg = f"{module_name} must be updated to at least {min_version}, got {installed_version}" - raise ValueError(msg) + module: str + version_getter: str | Callable[[], str] + min_version: tuple[int, ...] 
def version(self) -> tuple[int, ...]: - version_getter = self.requires[0][1] - # TODO(camriddell): this logic may be better suited for a Version namedtuple or dataclass - if callable(version_getter): - version_str = version_getter() - elif isinstance(version_getter, str): - version_str = dynamic_import(version_getter) + """Retrieve the version of the imported module. + + Returns: + The current version of the package/module. + """ + if callable(self.version_getter): + version_str = self.version_getter() + elif isinstance(self.version_getter, str): + version_str = dynamic_import(self.version_getter) else: msg = "version_getter {version_getter!r} must be a string or callable, got {type(version_getter)}" raise TypeError(msg) + return parse_version(version_str) - def native_namespace(self) -> ModuleType: - return import_module(self.requires[0][0]) + def is_valid(self) -> bool: + """Determines whether the imported package meets the requirements specified in this class. - def get_native_namespace(self) -> ModuleType | None: - return sys.modules.get(self.requires[0][0], None) + Returns: + True if the requirements specified for the module are met. False otherwise + """ + return self.version() >= self.min_version -def register_backends(*backends: Backend) -> None: - for b in backends: - BACKENDS.append(b) # noqa: PERF402 +@dataclass +class Backend: + """A collection of metadata that associates a package and its import types to narwhals interface(s). + + requirement: A requirement for the core package that this Backend represents. + adaptations: a list of Adaptations that link types from the package `requirement` to narwhals interfaces. + extra_requirements: any additional requirements that should be checked for use inside of Narwhals. + implementation: The narwhals Implementation to be passed to the adapter class (if it requires one). 
+ """ + + requirement: Requirement + adaptations: list[Adaptation] + extra_requirements: list[Requirement] = field(default_factory=list) + implementation: Implementation | None = None + + @property + def requirements(self) -> Generator[Requirement]: + """Traverse all requirements in this Backend. + + Yields: + Each requirement specified across `self.requirement` and `self.extra_requirements` + """ + yield self.requirement + yield from self.extra_requirements + + def imported_package(self) -> ModuleType: + """The imported version of the package specified in `self.requirement`. + + Returns: + The imported package specified in `self.requirement` + """ + module = self.requirement.module + if module in sys.modules: + return sys.modules[module] + obj = dynamic_import(module) + if not isinstance(obj, ModuleType): + msg = f"Attempted to import {self.requirement.module!r}, expected an instance of ModuleType but got {obj}" + raise TypeError(msg) + return obj + + def validate_backend_version(self) -> None: + """Checks if all of the specified package requirements are met for this Backend. + + Returns: None + Raises: ValueError: if any of the package requirements are not met. + """ + messages = [f"{self!r} did not meet the following requirements"] + validity = [] + + for req in self.requirements: + validity.append(valid := req.is_valid()) + indicator = "\N{HEAVY CHECK MARK}" if valid else "\N{CROSS MARK}" + messages.append( + indent( + f"{indicator}: {req.module} installed {req.version()} >= {req.min_version}", + prefix=" " * 4, + ) + ) + + if not all(validity): + raise ValueError("\n".join(messages)) + + def get_adapter( + self, cls: type, version: Version = Version.MAIN + ) -> Adaptation | None: + """Retrieve the adapter that matches the passed information. + + Arguments: + cls: type + version: Version + + Returns: + Adapter if a match was found. None otherwise. 
+ """ + for adapt in self.adaptations: + if adapt.matches(cls=cls, version=version): + return adapt + return None def traverse_rsplits(text: str, sep: str = " ") -> Generator[tuple[str, list[str]]]: + """Generates all possible rsplits of a string. + + Arguments: + text: str + sep: str + The separator that exists within the text argument + + Yields: + A partitioning of each of the possible rsplits of the inputted text. + + Examples: + >>> from narwhals.backends import traverse_rsplits + >>> list(traverse_rsplits("package.subpackage.module.type", sep=".")) + >>> ("package.subpackage.module", ["type"]) + >>> ("package.subpackage", ["module", "type"]) + >>> ("package", ["subpackage", "module", "type"]) + """ sep_count = text.count(sep) if sep_count == 0: yield (text, []) @@ -144,6 +257,22 @@ def traverse_rsplits(text: str, sep: str = " ") -> Generator[tuple[str, list[str def dynamic_import(dotted_path: str | type, /) -> Any: + """Attempts to retrieve a specific object specified by a dotted import path. + + Arguments: + dotted_path: str + A string that represents a valid Python import. + + Returns: + The object specified by the import string. 
+ + Examples: + >>> from narwhals.backends import dynamic_import + >>> dynamic_import("math.log") + <built-in function log> + >>> dynamic_import("os.path.abspath") + <function abspath at 0x...> + """ if isinstance(dotted_path, type): return dotted_path for base, attributes in traverse_rsplits(dotted_path, sep="."): @@ -162,103 +291,113 @@ def dynamic_import(dotted_path: str | type, /) -> Any: raise ImportError(msg) +def register_backends(*backends: Backend) -> None: + """Adds Backend(s) to the global BACKENDS variable.""" + BACKENDS.extendleft(backends) + + +BACKENDS: deque[Backend] = deque() + + register_backends( Backend( - requires=[ - ("pandas", "pandas.__version__", (0, 25, 3)), - ], + Requirement("pandas", "pandas.__version__", (0, 25, 3)), adaptations=[ Adaptation( DataFrame, "pandas.DataFrame", "narwhals._pandas_like.dataframe.PandasLikeDataFrame", level="full", + version=Version.MAIN | Version.V1, kwargs={"validate_column_names": True}, ), Adaptation( Series, "pandas.Series", "narwhals._pandas_like.dataframe.PandasLikeSeries", + version=Version.MAIN | Version.V1, level="full", ), ], implementation=Implementation.PANDAS, ), Backend( - requires=[ - ("polars", "polars.__version__", (0, 20, 3)), - ], + Requirement("polars", "polars.__version__", (0, 20, 3)), adaptations=[ Adaptation( LazyFrame, "polars.LazyFrame", "narwhals._polars.dataframe.PolarsLazyFrame", + version=Version.MAIN | Version.V1, level="full", ), Adaptation( DataFrame, "polars.DataFrame", "narwhals._polars.dataframe.PolarsDataFrame", + version=Version.MAIN | Version.V1, level="full", ), Adaptation( Series, "polars.Series", "narwhals._polars.series.PolarsSeries", + version=Version.MAIN | Version.V1, level="full", ), ], ), Backend( - requires=[("modin.pandas", "modin.__version__", (0, 25, 3))], + Requirement("modin.pandas", "modin.__version__", (0, 25, 3)), adaptations=[ Adaptation( DataFrame, "modin.pandas.DataFrame", "narwhals._pandas_like.dataframe.PandasLikeDataFrame", level="full", + version=Version.MAIN | Version.V1, 
kwargs={"validate_column_names": True}, ), Adaptation( Series, "modin.pandas.Series", "narwhals._pandas_like.dataframe.PandasLikeSeries", + version=Version.MAIN | Version.V1, level="full", ), ], implementation=Implementation.MODIN, ), Backend( - requires=[ - ("cudf", "cudf.__version__", (24, 10)), - ], + Requirement("cudf", "cudf.__version__", (24, 10)), adaptations=[ Adaptation( DataFrame, "cudf.DataFrame", "narwhals._pandas_like.dataframe.PandasLikeDataFrame", level="full", + version=Version.MAIN | Version.V1, kwargs={"validate_column_names": True}, ), Adaptation( Series, "cudf.Series", "narwhals._pandas_like.dataframe.PandasLikeSeries", + version=Version.MAIN | Version.V1, level="full", ), ], implementation=Implementation.CUDF, ), Backend( - requires=[ - ("pyarrow", "pyarrow.__version__", (11,)), - ], + Requirement("pyarrow", "pyarrow.__version__", (11,)), adaptations=[ Adaptation( DataFrame, "pyarrow.Table", "narwhals._arrow.dataframe.ArrowDataFrame", level="full", + version=Version.MAIN | Version.V1, kwargs={"validate_column_names": True}, ), Adaptation( @@ -266,18 +405,20 @@ def dynamic_import(dotted_path: str | type, /) -> Any: "pyarrow.ChunkedArray", "narwhals._arrow.series.ArrowSeries", level="full", + version=Version.MAIN | Version.V1, kwargs={"name": ""}, ), ], ), Backend( - requires=[("pyspark.sql", "pyspark.__version__", (3, 5))], + Requirement("pyspark.sql", "pyspark.__version__", (3, 5)), adaptations=[ Adaptation( LazyFrame, "pyspark.sql.DataFrame", - "narwhals._spark.dataframe.SparkLikeLazyFrame", + "narwhals._spark_like.dataframe.SparkLikeLazyFrame", level="full", + version=Version.MAIN | Version.V1, kwargs={"validate_column_names": True}, ), Adaptation( @@ -285,14 +426,15 @@ def dynamic_import(dotted_path: str | type, /) -> Any: "pyspark.sql.Series", "narwhals._arrow.dataframe.ArrowSeries", level="full", + version=Version.MAIN | Version.V1, ), ], implementation=Implementation.PYSPARK, ), Backend( - requires=[ - ("dask.dataframe", 
"dask.__version__", (2024, 8)), - ("dask_expr", "dask_expr.__version__", (0,)), + Requirement("dask.dataframe", "dask.__version__", (2024, 8)), + extra_requirements=[ + Requirement("dask_expr", "dask_expr.__version__", (0,)), ], adaptations=[ Adaptation( @@ -300,6 +442,7 @@ def dynamic_import(dotted_path: str | type, /) -> Any: "dask.dataframe.DataFrame", "narwhals._dask.dataframe.DaskLazyFrame", level="full", + version=Version.MAIN | Version.V1, kwargs={"validate_column_names": True}, ), Adaptation( @@ -307,12 +450,13 @@ def dynamic_import(dotted_path: str | type, /) -> Any: "dask_expr.DataFrame", "narwhals._dask.dataframe.DaskLazyFrame", level="full", + version=Version.MAIN | Version.V1, kwargs={"validate_column_names": True}, ), ], ), Backend( - requires=[("duckdb", "duckdb.__version__", (1,))], + Requirement("duckdb", "duckdb.__version__", (1,)), adaptations=[ Adaptation( LazyFrame, @@ -333,29 +477,33 @@ def dynamic_import(dotted_path: str | type, /) -> Any: ], ), Backend( - requires=[ - ("ibis", "ibis.__version__", (6,)), - ], + Requirement("ibis", "ibis.__version__", (6,)), adaptations=[ + Adaptation( + DataFrame, + "ibis.expr.types.Table", + "narwhals._ibis.dataframe.IbisLazyFrame", + level="full", + version=Version.V1, + ), Adaptation( LazyFrame, "ibis.expr.types.Table", "narwhals._ibis.dataframe.IbisLazyFrame", level="full", - kwargs={"validate_column_names": True}, + version=Version.MAIN, ), ], ), Backend( - requires=[ - ("sqlframe", "sqlframe._version.__version__", (3, 14, 2)), - ], + Requirement("sqlframe", "sqlframe._version.__version__", (3, 22, 0)), adaptations=[ - Adaptation( + MROAdaptation( LazyFrame, "sqlframe.base.dataframe.BaseDataFrame", - "narwhals._spark.dataframe.SparkLikeLazyFrame", + "narwhals._spark_like.dataframe.SparkLikeLazyFrame", level="full", + version=Version.MAIN | Version.V1, kwargs={"validate_column_names": True}, ), ], diff --git a/narwhals/translate.py b/narwhals/translate.py index 0ef7203a7a..5500fe31a2 100644 --- 
a/narwhals/translate.py +++ b/narwhals/translate.py @@ -401,7 +401,7 @@ def _from_native_impl( # noqa: PLR0915 level="full", ) - for backend in reversed(narwhals.backends.BACKENDS): + for backend in narwhals.backends.BACKENDS: adapter = backend.get_adapter(type(native_object), version=version) if adapter is None: continue @@ -411,30 +411,30 @@ def _from_native_impl( # noqa: PLR0915 kwargs = kwargs.copy() kwargs.setdefault("implementation", backend.implementation) - if adapter.narwhals is Series and not (allow_series or series_only): + if adapter.interface is Series and not (allow_series or series_only): if pass_through: return native_object msg = "Please set `allow_series=True` or `series_only=True`" raise TypeError(msg) - if adapter.narwhals is not Series and series_only: + if adapter.interface is not Series and series_only: if not pass_through: msg = f"Cannot only use `series_only` with {type(native_object)!r}" raise TypeError(msg) return native_object - elif (adapter.narwhals is LazyFrame and eager_only) or eager_or_interchange_only: + elif adapter.interface is LazyFrame and (eager_only or eager_or_interchange_only): if pass_through: return native_object elif not eager_or_interchange_only: msg = f"Cannot only use `eager_only` or `eager_or_interchange_only` with {type(native_object)!r}" raise TypeError(msg) - return adapter.narwhals( + return adapter.interface( adapter.imported_adapter( native_object, - version=adapter.version, - backend_version=backend.version(), + version=version, + backend_version=backend.requirement.version(), **kwargs, ), level=adapter.level, @@ -498,9 +498,9 @@ def get_native_namespace( if has_native_namespace(obj): return obj.__native_namespace__() - for backend in reversed(narwhals.backends.BACKENDS): + for backend in narwhals.backends.BACKENDS: if backend.get_adapter(type(obj)) is not None: - return backend.native_namespace() + return backend.imported_package() msg = "Could not get native namespace" raise TypeError(msg) diff --git 
a/narwhals/utils.py b/narwhals/utils.py index a75ead5f3d..f34202cab7 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -4,6 +4,7 @@ import re from datetime import timezone from enum import Enum +from enum import Flag from enum import auto from inspect import getattr_static from secrets import token_hex @@ -114,7 +115,7 @@ class _StoresColumns(Protocol): def columns(self) -> Sequence[str]: ... -class Version(Enum): +class Version(Flag): V1 = auto() MAIN = auto()