@@ -6,9 +6,6 @@
 from typing import Literal
 from typing import Sequence
 
-from pyspark.sql import Window
-from pyspark.sql import functions as F  # noqa: N812
-
 from narwhals._spark_like.utils import ExprKind
 from narwhals._spark_like.utils import native_to_narwhals_dtype
 from narwhals._spark_like.utils import parse_exprs_and_named_exprs
@@ -40,38 +37,73 @@ def __init__(
         *,
         backend_version: tuple[int, ...],
         version: Version,
+        implementation: Implementation,
     ) -> None:
         self._native_frame = native_dataframe
         self._backend_version = backend_version
-        self._implementation = Implementation.PYSPARK
+        self._implementation = implementation
         self._version = version
         validate_backend_version(self._implementation, self._backend_version)
 
-    def __native_namespace__(self: Self) -> ModuleType:  # pragma: no cover
-        if self._implementation is Implementation.PYSPARK:
-            return self._implementation.to_native_namespace()
+    @property
+    def _F(self) -> Any:  # noqa: N802
+        if self._implementation is Implementation.SQLFRAME:
+            from sqlframe.duckdb import functions
+
+            return functions
+        from pyspark.sql import functions
+
+        return functions
+
+    @property
+    def _native_dtypes(self) -> Any:
+        if self._implementation is Implementation.SQLFRAME:
+            from sqlframe.duckdb import types
+
+            return types
+        from pyspark.sql import types
+
+        return types
+
+    @property
+    def _Window(self) -> Any:  # noqa: N802
+        if self._implementation is Implementation.SQLFRAME:
+            from sqlframe.duckdb import Window
 
-        msg = f"Expected pyspark, got: {type(self._implementation)}"  # pragma: no cover
-        raise AssertionError(msg)
+            return Window
+        from pyspark.sql import Window
+
+        return Window
+
+    def __native_namespace__(self: Self) -> ModuleType:  # pragma: no cover
+        return self._implementation.to_native_namespace()
 
     def __narwhals_namespace__(self: Self) -> SparkLikeNamespace:
         from narwhals._spark_like.namespace import SparkLikeNamespace
 
         return SparkLikeNamespace(
-            backend_version=self._backend_version, version=self._version
+            backend_version=self._backend_version,
+            version=self._version,
+            implementation=self._implementation,
         )
 
     def __narwhals_lazyframe__(self: Self) -> Self:
         return self
 
     def _change_version(self: Self, version: Version) -> Self:
         return self.__class__(
-            self._native_frame, backend_version=self._backend_version, version=version
+            self._native_frame,
+            backend_version=self._backend_version,
+            version=version,
+            implementation=self._implementation,
         )
 
     def _from_native_frame(self: Self, df: DataFrame) -> Self:
         return self.__class__(
-            df, backend_version=self._backend_version, version=self._version
+            df,
+            backend_version=self._backend_version,
+            version=self._version,
+            implementation=self._implementation,
         )
 
     @property
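For readers skimming this hunk: the three new properties replace the module-level pyspark imports with lazy, per-instance dispatch, so merely importing the module no longer pulls in pyspark. A minimal self-contained sketch of the pattern follows; `Frame` and this `Implementation` enum are simplified stand-ins for narwhals' own classes, not the real API:

from __future__ import annotations

from enum import Enum, auto
from typing import Any


class Implementation(Enum):  # stand-in for narwhals' Implementation enum
    PYSPARK = auto()
    SQLFRAME = auto()


class Frame:  # simplified stand-in for SparkLikeLazyFrame
    def __init__(self, implementation: Implementation) -> None:
        self._implementation = implementation

    @property
    def _F(self) -> Any:  # noqa: N802
        # Import lazily: only the backend an instance actually uses
        # needs to be installed.
        if self._implementation is Implementation.SQLFRAME:
            from sqlframe.duckdb import functions

            return functions
        from pyspark.sql import functions

        return functions


# Downstream code stays backend-agnostic: frame._F.lit(1) resolves to
# sqlframe's lit for SQLFRAME instances and to pyspark's otherwise.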
@@ -102,10 +134,10 @@ def select(
 
         if not new_columns:
             # return empty dataframe, like Polars does
-            from pyspark.sql.types import StructType
-
             spark_session = self._native_frame.sparkSession
-            spark_df = spark_session.createDataFrame([], StructType([]))
+            spark_df = spark_session.createDataFrame(
+                [], self._native_dtypes.StructType([])
+            )
 
             return self._from_native_frame(spark_df)
 
@@ -116,7 +148,7 @@ def select(
             return self._from_native_frame(self._native_frame.agg(*new_columns_list))
         else:
             new_columns_list = [
-                col.over(Window.partitionBy(F.lit(1))).alias(col_name)
+                col.over(self._Window().partitionBy(self._F.lit(1))).alias(col_name)
                 if expr_kind is ExprKind.AGGREGATION
                 else col.alias(col_name)
                 for (col_name, col), expr_kind in zip(new_columns.items(), expr_kinds)
@@ -131,7 +163,7 @@ def with_columns(
         new_columns, expr_kinds = parse_exprs_and_named_exprs(self, *exprs, **named_exprs)
 
         new_columns_map = {
-            col_name: col.over(Window.partitionBy(F.lit(1)))
+            col_name: col.over(self._Window().partitionBy(self._F.lit(1)))
             if expr_kind is ExprKind.AGGREGATION
             else col
             for (col_name, col), expr_kind in zip(new_columns.items(), expr_kinds)
@@ -152,7 +184,9 @@ def filter(self: Self, *predicates: SparkLikeExpr, **constraints: Any) -> Self:
     def schema(self: Self) -> dict[str, DType]:
         return {
             field.name: native_to_narwhals_dtype(
-                dtype=field.dataType, version=self._version
+                dtype=field.dataType,
+                version=self._version,
+                spark_types=self._native_dtypes,
             )
             for field in self._native_frame.schema
         }
@@ -186,18 +220,18 @@ def sort(
         descending: bool | Sequence[bool],
         nulls_last: bool,
     ) -> Self:
-        import pyspark.sql.functions as F  # noqa: N812
-
         if isinstance(descending, bool):
             descending = [descending] * len(by)
 
         if nulls_last:
             sort_funcs = (
-                F.desc_nulls_last if d else F.asc_nulls_last for d in descending
+                self._F.desc_nulls_last if d else self._F.asc_nulls_last
+                for d in descending
             )
         else:
             sort_funcs = (
-                F.desc_nulls_first if d else F.asc_nulls_first for d in descending
+                self._F.desc_nulls_first if d else self._F.asc_nulls_first
+                for d in descending
             )
 
         sort_cols = [sort_f(col) for col, sort_f in zip(by, sort_funcs)]
@@ -207,14 +241,12 @@ def drop_nulls(self: Self, subset: list[str] | None) -> Self:
         return self._from_native_frame(self._native_frame.dropna(subset=subset))
 
     def rename(self: Self, mapping: dict[str, str]) -> Self:
-        import pyspark.sql.functions as F  # noqa: N812
-
         rename_mapping = {
             colname: mapping.get(colname, colname) for colname in self.columns
         }
         return self._from_native_frame(
             self._native_frame.select(
-                [F.col(old).alias(new) for old, new in rename_mapping.items()]
+                [self._F.col(old).alias(new) for old, new in rename_mapping.items()]
             )
         )
 
@@ -238,8 +270,6 @@ def join(
         right_on: str | list[str] | None,
         suffix: str,
     ) -> Self:
-        import pyspark.sql.functions as F  # noqa: N812
-
         self_native = self._native_frame
         other_native = other._native_frame
 
@@ -262,7 +292,7 @@ def join(
             },
         }
         other = other_native.select(
-            [F.col(old).alias(new) for old, new in rename_mapping.items()]
+            [self._F.col(old).alias(new) for old, new in rename_mapping.items()]
         )
 
         # If how in {"semi", "anti"}, then resulting columns are same as left columns
@@ -280,5 +310,5 @@ def join(
         )
 
         return self._from_native_frame(
-            self_native.join(other=other, on=left_on, how=how).select(col_order)
+            self_native.join(other, on=left_on, how=how).select(col_order)
         )
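Two closing observations. The final hunk passes `other` positionally rather than as `other=`, presumably because sqlframe's `join` does not name its first parameter `other` the way pyspark does, and the positional form satisfies both. As a usage sketch of what the change enables (an assumption-laden example: it presumes a narwhals version whose `from_native` recognizes sqlframe DataFrames, and that `DuckDBSession` is available from `sqlframe.duckdb`):

import narwhals as nw
from sqlframe.duckdb import DuckDBSession  # assumed entry point for sqlframe's DuckDB backend

session = DuckDBSession()  # in-memory DuckDB-backed session
native = session.createDataFrame([(1, "a"), (2, "b")], ["id", "tag"])

# The wrapped frame carries Implementation.SQLFRAME, so the _F / _Window /
# _native_dtypes lookups added above route to sqlframe.duckdb, not pyspark.
lf = nw.from_native(native)
lf.sort("id", descending=True).to_native().show()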