@@ -431,10 +431,11 @@ def _execute_on_driver(
         statement_args = self._get_statement_args()
         jdbc_statement = self._build_statement(statement, statement_type, jdbc_connection, statement_args)

-        return self._execute_statement(jdbc_statement, statement, options, callback, read_only)
+        return self._execute_statement(jdbc_connection, jdbc_statement, statement, options, callback, read_only)

     def _execute_statement(
         self,
+        jdbc_connection,
         jdbc_statement,
         statement: str,
         options: JDBCFetchOptions | JDBCExecuteOptions,
@@ -472,7 +473,7 @@ def _execute_statement(
         else:
             jdbc_statement.executeUpdate(statement)

-        return callback(jdbc_statement)
+        return callback(jdbc_connection, jdbc_statement)

     @staticmethod
     def _build_statement(
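Context on the change above, to save readers a trip to the last hunk: the live `java.sql.Connection` has to be threaded through `_execute_statement` because every callback it invokes now receives the connection alongside the statement. A minimal sketch of the new callback contract follows; `inspect_warnings` is a hypothetical callback written for illustration, not part of this change:

```python
# Hypothetical callback under the new contract: callbacks passed to
# _execute_statement receive (jdbc_connection, jdbc_statement), both Py4J
# proxies for java.sql.Connection and java.sql.Statement respectively.
def inspect_warnings(jdbc_connection, jdbc_statement):
    # getWarnings() exists on both java.sql.Connection and java.sql.Statement
    # and can be called after the statement has been executed.
    return jdbc_connection.getWarnings(), jdbc_statement.getWarnings()
```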
@@ -501,11 +502,11 @@ def _build_statement(

         return jdbc_connection.createStatement(*statement_args)

-    def _statement_to_dataframe(self, jdbc_statement) -> DataFrame:
+    def _statement_to_dataframe(self, jdbc_connection, jdbc_statement) -> DataFrame:
         result_set = jdbc_statement.getResultSet()
-        return self._resultset_to_dataframe(result_set)
+        return self._resultset_to_dataframe(jdbc_connection, result_set)

-    def _statement_to_optional_dataframe(self, jdbc_statement) -> DataFrame | None:
+    def _statement_to_optional_dataframe(self, jdbc_connection, jdbc_statement) -> DataFrame | None:
         """
         Returns ``org.apache.spark.sql.DataFrame`` or ``None``, if ResultSet does not contain any columns.

@@ -522,9 +523,9 @@ def _statement_to_optional_dataframe(self, jdbc_statement) -> DataFrame | None:
         if not result_column_count:
             return None

-        return self._resultset_to_dataframe(result_set)
+        return self._resultset_to_dataframe(jdbc_connection, result_set)

-    def _resultset_to_dataframe(self, result_set) -> DataFrame:
+    def _resultset_to_dataframe(self, jdbc_connection, result_set) -> DataFrame:
         """
         Converts ``java.sql.ResultSet`` to ``org.apache.spark.sql.DataFrame`` using Spark's internal methods.

@@ -545,13 +546,25 @@ def _resultset_to_dataframe(self, result_set) -> DataFrame:

         java_converters = self.spark._jvm.scala.collection.JavaConverters  # type: ignore

-        if get_spark_version(self.spark) >= Version("3.4"):
+        if get_spark_version(self.spark) >= Version("4.0"):
+            result_schema = jdbc_utils.getSchema(
+                jdbc_connection,
+                result_set,
+                jdbc_dialect,
+                False,  # noqa: WPS425
+                False,  # noqa: WPS425
+            )
+        elif get_spark_version(self.spark) >= Version("3.4"):
             # https://github.com/apache/spark/commit/2349175e1b81b0a61e1ed90c2d051c01cf78de9b
             result_schema = jdbc_utils.getSchema(result_set, jdbc_dialect, False, False)  # noqa: WPS425
         else:
             result_schema = jdbc_utils.getSchema(result_set, jdbc_dialect, False)  # noqa: WPS425

-        result_iterator = jdbc_utils.resultSetToRows(result_set, result_schema)
+        if get_spark_version(self.spark) >= Version("4.0"):
+            result_iterator = jdbc_utils.resultSetToRows(result_set, result_schema, jdbc_dialect)
+        else:
+            result_iterator = jdbc_utils.resultSetToRows(result_set, result_schema)
+
         result_list = java_converters.seqAsJavaListConverter(result_iterator.toSeq()).asJava()
         jdf = self.spark._jsparkSession.createDataFrame(result_list, result_schema)  # type: ignore
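For reference, here is the Spark-version dispatch from the last hunk in isolation. This is a hedged sketch rather than code from the repository: `get_result_schema` and `get_result_rows` are hypothetical helper names, `Version` is assumed to be `packaging.version.Version`, and the call shapes simply mirror the diff — in Spark 4.0 `JdbcUtils.getSchema` gained a leading `Connection` argument and `resultSetToRows` gained a trailing `JdbcDialect` argument, while Spark 3.4 had already added a fourth boolean to `getSchema`.

```python
from packaging.version import Version


def get_result_schema(spark_version: Version, jdbc_utils, jdbc_connection, result_set, jdbc_dialect):
    # Hypothetical helper; the call shapes mirror the diff above.
    if spark_version >= Version("4.0"):
        # Spark 4.0+: getSchema takes the Connection as its first argument.
        return jdbc_utils.getSchema(jdbc_connection, result_set, jdbc_dialect, False, False)
    if spark_version >= Version("3.4"):
        # Spark 3.4+: a second boolean flag was added.
        return jdbc_utils.getSchema(result_set, jdbc_dialect, False, False)
    # Older Spark versions take a single boolean flag.
    return jdbc_utils.getSchema(result_set, jdbc_dialect, False)


def get_result_rows(spark_version: Version, jdbc_utils, result_set, result_schema, jdbc_dialect):
    # Hypothetical helper; in Spark 4.0+ resultSetToRows also takes the dialect.
    if spark_version >= Version("4.0"):
        return jdbc_utils.resultSetToRows(result_set, result_schema, jdbc_dialect)
    return jdbc_utils.resultSetToRows(result_set, result_schema)
```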