Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions narwhals/_plan/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
min,
min_horizontal,
nth,
read_csv,
read_parquet,
sum,
sum_horizontal,
when,
Expand Down Expand Up @@ -58,6 +60,8 @@
"min",
"min_horizontal",
"nth",
"read_csv",
"read_parquet",
"selectors",
"sum",
"sum_horizontal",
Expand Down
21 changes: 13 additions & 8 deletions narwhals/_plan/arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from narwhals._plan.arrow.series import ArrowSeries as Series
from narwhals._plan.common import temp
from narwhals._plan.compliant.dataframe import EagerDataFrame
from narwhals._plan.compliant.typing import namespace
from narwhals._plan.compliant.typing import LazyFrameAny, namespace
from narwhals._plan.exceptions import shape_error
from narwhals._plan.expressions import NamedIR, named_ir
from narwhals._utils import Version, generate_repr
Expand All @@ -34,6 +34,7 @@
from narwhals._plan.expressions import ExprIR, NamedIR
from narwhals._plan.options import ExplodeOptions, SortMultipleOptions
from narwhals._plan.typing import NonCrossJoinStrategy
from narwhals._typing import _LazyAllowedImpl
from narwhals.dtypes import DType
from narwhals.typing import IntoSchema, UniqueKeepStrategy

Expand All @@ -57,6 +58,10 @@ def _group_by(self) -> type[GroupBy]:
def shape(self) -> tuple[int, int]:
return self.native.shape

def lazy(self, backend: _LazyAllowedImpl | None, **kwds: Any) -> LazyFrameAny:
    """Convert this eager Arrow frame to a lazy frame.

    Not supported yet for the Arrow backend.

    Raises:
        NotImplementedError: always, until a lazy path exists for Arrow.
    """
    msg = "ArrowDataFrame.lazy"
    raise NotImplementedError(msg)

def group_by_resolver(self, resolver: GroupByResolver, /) -> GroupBy:
return self._group_by.from_resolver(self, resolver)

Expand Down Expand Up @@ -193,23 +198,23 @@ def with_row_index_by(
return self._with_native(self.native.add_column(0, name, column))

@overload
def write_csv(self, file: None) -> str: ...
def write_csv(self, target: None, /) -> str: ...
@overload
def write_csv(self, file: str | BytesIO) -> None: ...
def write_csv(self, file: str | BytesIO | None) -> str | None:
def write_csv(self, target: str | BytesIO, /) -> None: ...
def write_csv(self, target: str | BytesIO | None, /) -> str | None:
import pyarrow.csv as pa_csv

if file is None:
if target is None:
csv_buffer = pa.BufferOutputStream()
pa_csv.write_csv(self.native, csv_buffer)
return csv_buffer.getvalue().to_pybytes().decode()
pa_csv.write_csv(self.native, file)
pa_csv.write_csv(self.native, target)
return None

def write_parquet(self, file: str | BytesIO) -> None:
def write_parquet(self, target: str | BytesIO, /) -> None:
import pyarrow.parquet as pp

pp.write_table(self.native, file)
pp.write_table(self.native, target)

def to_struct(self, name: str = "") -> Series:
native = self.native
Expand Down
16 changes: 16 additions & 0 deletions narwhals/_plan/arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from narwhals._arrow.utils import narwhals_to_native_dtype
from narwhals._plan._guards import is_tuple_of
from narwhals._plan.arrow import functions as fn
from narwhals._plan.common import todo
from narwhals._plan.compliant.namespace import EagerNamespace
from narwhals._plan.expressions.expr import RangeExpr
from narwhals._plan.expressions.literal import is_literal_scalar
Expand Down Expand Up @@ -328,3 +329,18 @@ def _concat_vertical(self, items: Iterable[Frame | Series]) -> Frame | Series:
raise TypeError(msg)
return df._with_native(fn.concat_tables(df.native for df in dfs))
raise TypeError(items)

def read_csv(self, source: str, /, **kwds: Any) -> Frame:
    """Eagerly read the CSV file at *source* into a compliant DataFrame.

    Any extra keyword arguments are forwarded to `pyarrow.csv.read_csv`.
    """
    import pyarrow.csv as pcsv

    table = pcsv.read_csv(source, **kwds)
    return self._dataframe.from_native(table, version=self.version)

def read_parquet(self, source: str, /, **kwds: Any) -> Frame:
    """Eagerly read the Parquet file at *source* into a compliant DataFrame.

    Any extra keyword arguments are forwarded to `pyarrow.parquet.read_table`.
    """
    import pyarrow.parquet as pq

    table = pq.read_table(source, **kwds)
    return self._dataframe.from_native(table, version=self.version)

# Lazy `scan_*` entry points are stubbed via `todo()` (from narwhals._plan.common);
# lazy scanning is not implemented for this backend yet.
scan_csv = todo()
scan_parquet = todo()
65 changes: 54 additions & 11 deletions narwhals/_plan/compliant/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,21 @@
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Protocol, overload

from narwhals._plan.compliant import io
from narwhals._plan.compliant.group_by import Grouped
from narwhals._plan.compliant.typing import ColumnT_co, HasVersion, SeriesT
from narwhals._plan.compliant.typing import (
ColumnT_co,
DataFrameAny,
HasVersion,
LazyFrameAny,
SeriesT,
)
from narwhals._plan.typing import (
IncompleteCyclic,
IntoExpr,
NativeDataFrameT,
NativeFrameT_co,
NativeLazyFrameT,
NativeSeriesT,
NonCrossJoinStrategy,
OneOrIterable,
Expand All @@ -34,7 +42,7 @@
from narwhals._plan.expressions import NamedIR
from narwhals._plan.options import ExplodeOptions, SortMultipleOptions
from narwhals._plan.typing import Seq
from narwhals._typing import _EagerAllowedImpl
from narwhals._typing import _EagerAllowedImpl, _LazyAllowedImpl
from narwhals._utils import Implementation, Version
from narwhals.dtypes import DType
from narwhals.typing import IntoSchema, UniqueKeepStrategy
Expand Down Expand Up @@ -94,13 +102,54 @@ def with_row_index_by(
) -> Self: ...


class CompliantLazyFrame(
    io.LazyOutput,
    CompliantFrame[ColumnT_co, NativeLazyFrameT],
    Protocol[ColumnT_co, NativeLazyFrameT],
):
    """Very incomplete!

    Using mostly as a placeholder for typing lazy I/O.
    """

    # Backing native lazy-frame object wrapped by this compliant frame.
    _native: NativeLazyFrameT

    def __narwhals_lazyframe__(self) -> Self:
        # Narwhals-protocol hook identifying this object as a lazy frame.
        return self

    def _with_native(self, native: NativeLazyFrameT) -> Self:
        # Rewrap a new native object, preserving the current version.
        return self.from_native(native, self.version)

    @classmethod
    def from_native(cls, native: NativeLazyFrameT, /, version: Version) -> Self:
        # Construct via `__new__` so protocol implementors need no
        # particular `__init__` signature.
        obj = cls.__new__(cls)
        obj._native = native
        obj._version = version
        return obj

    def to_narwhals(self) -> Incomplete:
        # Placeholder: conversion to the public narwhals type is not wired up yet.
        msg = f"{type(self).__name__}.to_narwhals"
        raise NotImplementedError(msg)

    @property
    def native(self) -> NativeLazyFrameT:
        """The wrapped native lazy frame."""
        return self._native

    def collect(self, backend: _EagerAllowedImpl | None, **kwds: Any) -> DataFrameAny: ...


class CompliantDataFrame(
io.EagerOutput,
CompliantFrame[SeriesT, NativeDataFrameT],
Protocol[SeriesT, NativeDataFrameT, NativeSeriesT],
):
implementation: ClassVar[_EagerAllowedImpl]
_native: NativeDataFrameT

def __narwhals_dataframe__(self) -> Self:
return self

def lazy(self, backend: _LazyAllowedImpl | None, **kwds: Any) -> LazyFrameAny: ...
@property
def shape(self) -> tuple[int, int]: ...
def __len__(self) -> int: ...
Expand Down Expand Up @@ -209,12 +258,6 @@ def unique_by(
maintain_order: bool = False,
) -> Self: ...
def with_row_index(self, name: str) -> Self: ...
@overload
def write_csv(self, file: None) -> str: ...
@overload
def write_csv(self, file: str | BytesIO) -> None: ...
def write_csv(self, file: str | BytesIO | None) -> str | None: ...
def write_parquet(self, file: str | BytesIO) -> None: ...
def slice(self, offset: int, length: int | None = None) -> Self: ...
def sample_frac(
self, fraction: float, *, with_replacement: bool = False, seed: int | None = None
Expand All @@ -228,6 +271,7 @@ def sample_n(


class EagerDataFrame(
io.LazyOutput,
CompliantDataFrame[SeriesT, NativeDataFrameT, NativeSeriesT],
Protocol[SeriesT, NativeDataFrameT, NativeSeriesT],
):
Expand All @@ -254,6 +298,5 @@ def with_columns(self, irs: Seq[NamedIR]) -> Self:
def to_series(self, index: int = 0) -> SeriesT:
return self.get_column(self.columns[index])

# TODO @dangotbanned: Move to `CompliantLazyFrame` once that's added
def sink_parquet(self, file: str | BytesIO) -> None:
self.write_parquet(file)
def sink_parquet(self, target: str | BytesIO, /) -> None:
    """Write to a Parquet file; for eager frames this simply delegates to `write_parquet`."""
    self.write_parquet(target)
7 changes: 5 additions & 2 deletions narwhals/_plan/compliant/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from narwhals._plan.compliant.typing import (
FrameT_contra,
HasVersion,
LazyFrameT_contra,
LengthT,
SeriesT,
SeriesT_co,
Expand Down Expand Up @@ -44,6 +45,7 @@
from narwhals._plan.typing import IncompleteCyclic


# NOTE: At some point `Series` needs to be swapped out for `Column`
class CompliantExpr(HasVersion, Protocol[FrameT_contra, SeriesT_co]):
"""Everything common to `Expr`/`Series` and `Scalar` literal values."""

Expand Down Expand Up @@ -317,8 +319,9 @@ def __bool__(self) -> Literal[True]:
return True


# NOTE: At some point `Series` needs to be swapped out for `Column`
class LazyExpr(
SupportsBroadcast[SeriesT, LengthT],
CompliantExpr[FrameT_contra, SeriesT],
Protocol[FrameT_contra, SeriesT, LengthT],
CompliantExpr[LazyFrameT_contra, SeriesT],
Protocol[LazyFrameT_contra, SeriesT, LengthT],
): ...
123 changes: 123 additions & 0 deletions narwhals/_plan/compliant/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""Minimal protocols for importing/exporting `*Frame`s from/to files and guards to test for them."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Protocol, overload

from narwhals._plan.compliant.typing import (
DataFrameT,
DataFrameT_co,
LazyFrameT,
LazyFrameT_co,
)
from narwhals._utils import _hasattr_static

if TYPE_CHECKING:
from io import BytesIO

from typing_extensions import TypeIs

__all__ = [
"EagerInput",
"EagerOutput",
"LazyInput",
"LazyOutput",
"ReadCsv",
"ReadParquet",
"ScanCsv",
"ScanParquet",
"SinkParquet",
"WriteCsv",
"WriteParquet",
"can_read_csv",
"can_read_parquet",
"can_scan_csv",
"can_scan_parquet",
"can_sink_parquet",
"can_write_csv",
"can_write_parquet",
]


class ScanCsv(Protocol[LazyFrameT_co]):
    """Can lazily scan a CSV file, producing a lazy frame."""

    def scan_csv(self, source: str, /, **kwds: Any) -> LazyFrameT_co: ...


class ScanParquet(Protocol[LazyFrameT_co]):
    """Can lazily scan a Parquet file, producing a lazy frame."""

    def scan_parquet(self, source: str, /, **kwds: Any) -> LazyFrameT_co: ...


class ReadCsv(Protocol[DataFrameT_co]):
    """Can eagerly read a CSV file, producing a DataFrame."""

    def read_csv(self, source: str, /, **kwds: Any) -> DataFrameT_co: ...


class ReadParquet(Protocol[DataFrameT_co]):
    """Can eagerly read a Parquet file, producing a DataFrame."""

    def read_parquet(self, source: str, /, **kwds: Any) -> DataFrameT_co: ...


class SinkParquet(Protocol):
    """Can lazily write to a Parquet file (path or buffer)."""

    def sink_parquet(self, target: str | BytesIO, /) -> None: ...


class WriteCsv(Protocol):
    """Can eagerly write CSV to a path/buffer, or return it as a `str` when *target* is `None`."""

    @overload
    def write_csv(self, target: None, /) -> str: ...
    @overload
    def write_csv(self, target: str | BytesIO, /) -> None: ...
    def write_csv(self, target: str | BytesIO | None, /) -> str | None: ...


class WriteParquet(Protocol):
    """Can eagerly write to a Parquet file (path or buffer)."""

    def write_parquet(self, target: str | BytesIO, /) -> None: ...


class LazyInput(
    ScanCsv[LazyFrameT_co], ScanParquet[LazyFrameT_co], Protocol[LazyFrameT_co]
):
    """Supports all `scan_*` methods (`scan_csv`, `scan_parquet`), for lazily reading from files."""


class EagerInput(
    ReadCsv[DataFrameT_co], ReadParquet[DataFrameT_co], Protocol[DataFrameT_co]
):
    """Supports all `read_*` methods (`read_csv`, `read_parquet`), for eagerly reading from files."""


class LazyOutput(SinkParquet, Protocol):
    """Supports all `sink_*` methods (currently just `sink_parquet`), for lazily writing to files."""


class EagerOutput(WriteCsv, WriteParquet, Protocol):
    """Supports all `write_*` methods (`write_csv`, `write_parquet`), for eagerly writing to files."""


def can_read_csv(obj: ReadCsv[DataFrameT] | Any) -> TypeIs[ReadCsv[DataFrameT]]:
    """Narrow *obj* to `ReadCsv` when it statically defines `read_csv`."""
    return _hasattr_static(obj, "read_csv")


def can_read_parquet(
    obj: ReadParquet[DataFrameT] | Any,
) -> TypeIs[ReadParquet[DataFrameT]]:
    """Narrow *obj* to `ReadParquet` when it statically defines `read_parquet`."""
    return _hasattr_static(obj, "read_parquet")


def can_scan_csv(obj: ScanCsv[LazyFrameT] | Any) -> TypeIs[ScanCsv[LazyFrameT]]:
    """Narrow *obj* to `ScanCsv` when it statically defines `scan_csv`."""
    return _hasattr_static(obj, "scan_csv")


def can_scan_parquet(
    obj: ScanParquet[LazyFrameT] | Any,
) -> TypeIs[ScanParquet[LazyFrameT]]:
    """Narrow *obj* to `ScanParquet` when it statically defines `scan_parquet`."""
    return _hasattr_static(obj, "scan_parquet")


def can_write_csv(obj: Any) -> TypeIs[WriteCsv]:
    """Narrow *obj* to `WriteCsv` when it statically defines `write_csv`."""
    return _hasattr_static(obj, "write_csv")


def can_write_parquet(obj: Any) -> TypeIs[WriteParquet]:
    """Narrow *obj* to `WriteParquet` when it statically defines `write_parquet`."""
    return _hasattr_static(obj, "write_parquet")


def can_sink_parquet(obj: Any) -> TypeIs[SinkParquet]:
    """Narrow *obj* to `SinkParquet` when it statically defines `sink_parquet`."""
    return _hasattr_static(obj, "sink_parquet")
Loading
Loading