Commit 9748309

feat: start (silently) adding support for SQLFrame (#1883)
1 parent a11147b commit 9748309

15 files changed: +415 -191

.github/workflows/mkdocs.yml (+1 -1)

@@ -29,6 +29,6 @@ jobs:
       - name: griffe
         # hopefully temporary until https://github.com/mkdocstrings/mkdocstrings/issues/716
         run: pip install git+https://github.com/MarcoGorelli/griffe.git@no-overloads
-      - run: pip install -e .[docs,pyspark,dask,duckdb]
+      - run: pip install -e .[docs,dask,duckdb]
 
       - run: mkdocs gh-deploy --force

narwhals/_spark_like/dataframe.py (+59 -29)

@@ -6,9 +6,6 @@
 from typing import Literal
 from typing import Sequence
 
-from pyspark.sql import Window
-from pyspark.sql import functions as F  # noqa: N812
-
 from narwhals._spark_like.utils import ExprKind
 from narwhals._spark_like.utils import native_to_narwhals_dtype
 from narwhals._spark_like.utils import parse_exprs_and_named_exprs
@@ -40,38 +37,73 @@ def __init__(
         *,
         backend_version: tuple[int, ...],
         version: Version,
+        implementation: Implementation,
     ) -> None:
         self._native_frame = native_dataframe
         self._backend_version = backend_version
-        self._implementation = Implementation.PYSPARK
+        self._implementation = implementation
         self._version = version
         validate_backend_version(self._implementation, self._backend_version)
 
-    def __native_namespace__(self: Self) -> ModuleType:  # pragma: no cover
-        if self._implementation is Implementation.PYSPARK:
-            return self._implementation.to_native_namespace()
+    @property
+    def _F(self) -> Any:  # noqa: N802
+        if self._implementation is Implementation.SQLFRAME:
+            from sqlframe.duckdb import functions
+
+            return functions
+        from pyspark.sql import functions
+
+        return functions
+
+    @property
+    def _native_dtypes(self) -> Any:
+        if self._implementation is Implementation.SQLFRAME:
+            from sqlframe.duckdb import types
+
+            return types
+        from pyspark.sql import types
+
+        return types
+
+    @property
+    def _Window(self) -> Any:  # noqa: N802
+        if self._implementation is Implementation.SQLFRAME:
+            from sqlframe.duckdb import Window
 
-        msg = f"Expected pyspark, got: {type(self._implementation)}"  # pragma: no cover
-        raise AssertionError(msg)
+            return Window
+        from pyspark.sql import Window
+
+        return Window
+
+    def __native_namespace__(self: Self) -> ModuleType:  # pragma: no cover
+        return self._implementation.to_native_namespace()
 
     def __narwhals_namespace__(self: Self) -> SparkLikeNamespace:
         from narwhals._spark_like.namespace import SparkLikeNamespace
 
         return SparkLikeNamespace(
-            backend_version=self._backend_version, version=self._version
+            backend_version=self._backend_version,
+            version=self._version,
+            implementation=self._implementation,
         )
 
     def __narwhals_lazyframe__(self: Self) -> Self:
         return self
 
     def _change_version(self: Self, version: Version) -> Self:
         return self.__class__(
-            self._native_frame, backend_version=self._backend_version, version=version
+            self._native_frame,
+            backend_version=self._backend_version,
+            version=version,
+            implementation=self._implementation,
         )
 
     def _from_native_frame(self: Self, df: DataFrame) -> Self:
         return self.__class__(
-            df, backend_version=self._backend_version, version=self._version
+            df,
+            backend_version=self._backend_version,
+            version=self._version,
+            implementation=self._implementation,
         )
 
     @property
@@ -102,10 +134,10 @@ def select(
 
         if not new_columns:
             # return empty dataframe, like Polars does
-            from pyspark.sql.types import StructType
-
             spark_session = self._native_frame.sparkSession
-            spark_df = spark_session.createDataFrame([], StructType([]))
+            spark_df = spark_session.createDataFrame(
+                [], self._native_dtypes.StructType([])
+            )
 
             return self._from_native_frame(spark_df)
 
@@ -116,7 +148,7 @@
             return self._from_native_frame(self._native_frame.agg(*new_columns_list))
         else:
            new_columns_list = [
-                col.over(Window.partitionBy(F.lit(1))).alias(col_name)
+                col.over(self._Window().partitionBy(self._F.lit(1))).alias(col_name)
                 if expr_kind is ExprKind.AGGREGATION
                 else col.alias(col_name)
                 for (col_name, col), expr_kind in zip(new_columns.items(), expr_kinds)
@@ -131,7 +163,7 @@ def with_columns(
         new_columns, expr_kinds = parse_exprs_and_named_exprs(self, *exprs, **named_exprs)
 
         new_columns_map = {
-            col_name: col.over(Window.partitionBy(F.lit(1)))
+            col_name: col.over(self._Window().partitionBy(self._F.lit(1)))
             if expr_kind is ExprKind.AGGREGATION
             else col
             for (col_name, col), expr_kind in zip(new_columns.items(), expr_kinds)
@@ -152,7 +184,9 @@ def filter(self: Self, *predicates: SparkLikeExpr, **constraints: Any) -> Self:
     def schema(self: Self) -> dict[str, DType]:
         return {
             field.name: native_to_narwhals_dtype(
-                dtype=field.dataType, version=self._version
+                dtype=field.dataType,
+                version=self._version,
+                spark_types=self._native_dtypes,
             )
             for field in self._native_frame.schema
         }
@@ -186,18 +220,18 @@ def sort(
         descending: bool | Sequence[bool],
         nulls_last: bool,
     ) -> Self:
-        import pyspark.sql.functions as F  # noqa: N812
-
         if isinstance(descending, bool):
             descending = [descending] * len(by)
 
         if nulls_last:
             sort_funcs = (
-                F.desc_nulls_last if d else F.asc_nulls_last for d in descending
+                self._F.desc_nulls_last if d else self._F.asc_nulls_last
+                for d in descending
             )
         else:
             sort_funcs = (
-                F.desc_nulls_first if d else F.asc_nulls_first for d in descending
+                self._F.desc_nulls_first if d else self._F.asc_nulls_first
+                for d in descending
             )
 
         sort_cols = [sort_f(col) for col, sort_f in zip(by, sort_funcs)]
@@ -207,14 +241,12 @@ def drop_nulls(self: Self, subset: list[str] | None) -> Self:
         return self._from_native_frame(self._native_frame.dropna(subset=subset))
 
     def rename(self: Self, mapping: dict[str, str]) -> Self:
-        import pyspark.sql.functions as F  # noqa: N812
-
         rename_mapping = {
             colname: mapping.get(colname, colname) for colname in self.columns
         }
         return self._from_native_frame(
             self._native_frame.select(
-                [F.col(old).alias(new) for old, new in rename_mapping.items()]
+                [self._F.col(old).alias(new) for old, new in rename_mapping.items()]
             )
         )
 
@@ -238,8 +270,6 @@ def join(
         right_on: str | list[str] | None,
         suffix: str,
     ) -> Self:
-        import pyspark.sql.functions as F  # noqa: N812
-
         self_native = self._native_frame
         other_native = other._native_frame
 
@@ -262,7 +292,7 @@
             },
         }
         other = other_native.select(
-            [F.col(old).alias(new) for old, new in rename_mapping.items()]
+            [self._F.col(old).alias(new) for old, new in rename_mapping.items()]
         )
 
         # If how in {"semi", "anti"}, then resulting columns are same as left columns
@@ -280,5 +310,5 @@
         )
 
         return self._from_native_frame(
-            self_native.join(other=other, on=left_on, how=how).select(col_order)
+            self_native.join(other, on=left_on, how=how).select(col_order)
        )
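
The heart of this file's change is the trio of lazily-dispatching properties (`_F`, `_native_dtypes`, `_Window`): the module-level `pyspark.sql` imports are replaced by per-access lookups keyed on `self._implementation`, so SQLFrame's DuckDB backend (which mirrors the PySpark API) can stand in without PySpark being installed. Below is a distilled sketch of that pattern, not the commit's actual code: `Implementation` here is a simplified stand-in for narwhals' enum, and `SparkLikeFrame` is a hypothetical minimal wrapper.

from __future__ import annotations

from enum import Enum, auto
from typing import Any


class Implementation(Enum):
    # Simplified stand-in for narwhals' Implementation enum.
    PYSPARK = auto()
    SQLFRAME = auto()


class SparkLikeFrame:
    """Hypothetical wrapper: one class, two PySpark-compatible backends."""

    def __init__(self, native_frame: Any, implementation: Implementation) -> None:
        self._native_frame = native_frame
        self._implementation = implementation

    @property
    def _F(self) -> Any:  # noqa: N802
        # Lazy import: neither backend needs to be installed until first use.
        if self._implementation is Implementation.SQLFRAME:
            from sqlframe.duckdb import functions
        else:
            from pyspark.sql import functions
        return functions

    def rename(self, mapping: dict[str, str]) -> Any:
        # The same `functions.col(...).alias(...)` call works on both backends.
        return self._native_frame.select(
            [self._F.col(old).alias(new) for old, new in mapping.items()]
        )

Because SQLFrame keeps PySpark's `functions`/`types`/`Window` surface, the remaining hunks are mostly mechanical substitutions (`F.lit` to `self._F.lit`, `Window` to `self._Window`, `StructType` to `self._native_dtypes.StructType`), plus threading `implementation` through every constructor so the backend choice survives `_from_native_frame` and `_change_version`.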
