@@ -47,7 +47,9 @@ def test_koalas_spark_graph_adapter(spark_session):
         initial_columns,
         example_module,
         adapter=h_spark.SparkKoalasGraphAdapter(
-            spark_session, result_builder=base.PandasDataFrameResult(), spine_column="spend"
+            spark_session,
+            result_builder=base.PandasDataFrameResult(),
+            spine_column="spend",
         ),
     )
     output_columns = [
@@ -79,7 +81,9 @@ def test_smoke_screen_module(spark_session):
         config,
         smoke_screen_module,
         adapter=h_spark.SparkKoalasGraphAdapter(
-            spark_session, result_builder=base.PandasDataFrameResult(), spine_column="weeks"
+            spark_session,
+            result_builder=base.PandasDataFrameResult(),
+            spine_column="weeks",
         ),
     )
     output_columns = [
@@ -110,7 +114,12 @@ def test_smoke_screen_module(spark_session):
         (lambda df: ({"a": df}, (df, {}))),
         (lambda df: ({"a": df, "b": 1}, (df, {"b": 1}))),
     ],
-    ids=["no_kwargs", "one_plain_kwarg", "one_df_kwarg", "one_df_kwarg_and_one_plain_kwarg"],
+    ids=[
+        "no_kwargs",
+        "one_plain_kwarg",
+        "one_df_kwarg",
+        "one_df_kwarg_and_one_plain_kwarg",
+    ],
 )
 def test__inspect_kwargs(input_and_expected_fn, spark_session):
     """A unit test for inspect_kwargs."""
@@ -230,7 +239,11 @@ def base_func(a: int, b: int) -> int:
     base_spark_df = spark_session.createDataFrame(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
     node_ = node.Node.from_fn(base_func)
     new_df = h_spark._lambda_udf(base_spark_df, node_, {})
-    assert new_df.collect() == [Row(a=1, b=4, test=5), Row(a=2, b=5, test=7), Row(a=3, b=6, test=9)]
+    assert new_df.collect() == [
+        Row(a=1, b=4, test=5),
+        Row(a=2, b=5, test=7),
+        Row(a=3, b=6, test=9),
+    ]


 def test__lambda_udf_pandas_func(spark_session):
@@ -243,7 +256,11 @@ def base_func(a: pd.Series, b: pd.Series) -> htypes.column[pd.Series, int]:
     node_ = node.Node.from_fn(base_func)

     new_df = h_spark._lambda_udf(base_spark_df, node_, {})
-    assert new_df.collect() == [Row(a=1, b=4, test=5), Row(a=2, b=5, test=7), Row(a=3, b=6, test=9)]
+    assert new_df.collect() == [
+        Row(a=1, b=4, test=5),
+        Row(a=2, b=5, test=7),
+        Row(a=3, b=6, test=9),
+    ]


 def test__lambda_udf_pandas_func_error(spark_session):
@@ -348,11 +365,13 @@ def test_get_spark_type_numpy_types(return_type, expected_spark_type):

 # 4. Unsupported types
 @pytest.mark.parametrize(
-    "unsupported_return_type", [dict, set, tuple]  # Add other unsupported types as needed
+    "unsupported_return_type",
+    [dict, set, tuple],  # Add other unsupported types as needed
 )
 def test_get_spark_type_unsupported(unsupported_return_type):
     with pytest.raises(
-        ValueError, match=f"Currently unsupported return type {unsupported_return_type}."
+        ValueError,
+        match=f"Currently unsupported return type {unsupported_return_type}.",
     ):
         h_spark.get_spark_type(unsupported_return_type)

@@ -470,19 +489,19 @@ def test_base_spark_executor_end_to_end_multiple_with_columns(spark_session):


 def _only_pyspark_dataframe_parameter(foo: DataFrame) -> DataFrame:
-    ...
+    pass


 def _no_pyspark_dataframe_parameter(foo: int) -> int:
-    ...
+    pass


 def _one_pyspark_dataframe_parameter(foo: DataFrame, bar: int) -> DataFrame:
-    ...
+    pass


 def _two_pyspark_dataframe_parameters(foo: DataFrame, bar: int, baz: DataFrame) -> DataFrame:
-    ...
+    pass


 @pytest.mark.parametrize(
@@ -603,7 +622,11 @@ def df_as_pandas(df: DataFrame) -> pd.DataFrame:

     nodes = dec.generate_nodes(df_as_pandas, {})
     nodes_by_names = {n.name: n for n in nodes}
-    assert set(nodes_by_names.keys()) == {"df_as_pandas.c", "df_as_pandas", "df_as_pandas._select"}
+    assert set(nodes_by_names.keys()) == {
+        "df_as_pandas.c",
+        "df_as_pandas",
+        "df_as_pandas._select",
+    }


 def test_with_columns_generate_nodes_specify_namespace():
@@ -640,7 +663,10 @@ def test__format_standard_udf():

 def test_sparkify_node():
     def foo(
-        a_from_upstream: pd.Series, b_from_upstream: pd.Series, c_from_df: pd.Series, d_fixed: int
+        a_from_upstream: pd.Series,
+        b_from_upstream: pd.Series,
+        c_from_df: pd.Series,
+        d_fixed: int,
     ) -> htypes.column[pd.Series, int]:
         return a_from_upstream + b_from_upstream + c_from_df + d_fixed

@@ -679,7 +705,10 @@ def test_pyspark_mixed_pandas_udfs_end_to_end():
     # inputs={"spark_session": spark_session},
     # )
     results = dr.execute(
-        ["processed_df_as_pandas_dataframe_with_injected_dataframe", "processed_df_as_pandas"],
+        [
+            "processed_df_as_pandas_dataframe_with_injected_dataframe",
+            "processed_df_as_pandas",
+        ],
         inputs={"spark_session": spark_session},
     )
     processed_df_as_pandas = results["processed_df_as_pandas"]
@@ -774,7 +803,11 @@ def test_create_selector_node(spark_session):
     selector_node = h_spark.with_columns.create_selector_node("foo", ["a", "b"], "select")
     assert selector_node.name == "select"
     pandas_df = pd.DataFrame(
-        {"a": [10, 10, 20, 40, 40, 50], "b": [1, 10, 50, 100, 200, 400], "c": [1, 2, 3, 4, 5, 6]}
+        {
+            "a": [10, 10, 20, 40, 40, 50],
+            "b": [1, 10, 50, 100, 200, 400],
+            "c": [1, 2, 3, 4, 5, 6],
+        }
     )
     df = spark_session.createDataFrame(pandas_df)
     transformed = selector_node(foo=df).toPandas()