Commit b91db4e

Test Spark 4.0

1 parent b657322

5 files changed: +33 −13 lines

.github/workflows/data/core/matrix.yml

Lines changed: 1 addition & 1 deletion

@@ -22,4 +22,4 @@ latest: &latest
 matrix:
   small: [*max]
   full: [*min, *max]
-  nightly: [*min, *max, *latest]
+  nightly: [*min, *latest]
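With *latest now expected to resolve to a Spark 4.0 preview, the nightly matrix drops the intermediate *max entry; *max is still covered by the small and full matrices.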

onetl/_util/scala.py

Lines changed: 3 additions & 1 deletion

@@ -11,4 +11,6 @@ def get_default_scala_version(spark_version: Version) -> Version:
     """
     if spark_version.major < 3:
         return Version("2.11")
-    return Version("2.12")
+    if spark_version.major < 4:
+        return Version("2.12")
+    return Version("2.13")
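A quick sanity check of the new mapping (an illustrative sketch, not part of the commit; it assumes the helper and onetl's own Version type are importable from onetl._util as in this file):

    from onetl._util.scala import get_default_scala_version
    from onetl._util.version import Version

    assert get_default_scala_version(Version("2.4.8")) == Version("2.11")  # Spark 2.x -> Scala 2.11
    assert get_default_scala_version(Version("3.5.1")) == Version("2.12")  # Spark 3.x -> Scala 2.12
    assert get_default_scala_version(Version("4.0.0")) == Version("2.13")  # Spark 4.x -> Scala 2.13 (new)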

onetl/connection/db_connection/jdbc_mixin/connection.py

Lines changed: 22 additions & 9 deletions

@@ -431,10 +431,11 @@ def _execute_on_driver(
         statement_args = self._get_statement_args()
         jdbc_statement = self._build_statement(statement, statement_type, jdbc_connection, statement_args)
 
-        return self._execute_statement(jdbc_statement, statement, options, callback, read_only)
+        return self._execute_statement(jdbc_connection, jdbc_statement, statement, options, callback, read_only)
 
     def _execute_statement(
         self,
+        jdbc_connection,
         jdbc_statement,
         statement: str,
         options: JDBCFetchOptions | JDBCExecuteOptions,

@@ -472,7 +473,7 @@ def _execute_statement(
         else:
             jdbc_statement.executeUpdate(statement)
 
-        return callback(jdbc_statement)
+        return callback(jdbc_connection, jdbc_statement)
 
     @staticmethod
     def _build_statement(

@@ -501,11 +502,11 @@ def _build_statement(
 
         return jdbc_connection.createStatement(*statement_args)
 
-    def _statement_to_dataframe(self, jdbc_statement) -> DataFrame:
+    def _statement_to_dataframe(self, jdbc_connection, jdbc_statement) -> DataFrame:
         result_set = jdbc_statement.getResultSet()
-        return self._resultset_to_dataframe(result_set)
+        return self._resultset_to_dataframe(jdbc_connection, result_set)
 
-    def _statement_to_optional_dataframe(self, jdbc_statement) -> DataFrame | None:
+    def _statement_to_optional_dataframe(self, jdbc_connection, jdbc_statement) -> DataFrame | None:
         """
         Returns ``org.apache.spark.sql.DataFrame`` or ``None``, if ResultSet does not contain any columns.

@@ -522,9 +523,9 @@ def _statement_to_optional_dataframe(self, jdbc_statement) -> DataFrame | None:
         if not result_column_count:
             return None
 
-        return self._resultset_to_dataframe(result_set)
+        return self._resultset_to_dataframe(jdbc_connection, result_set)
 
-    def _resultset_to_dataframe(self, result_set) -> DataFrame:
+    def _resultset_to_dataframe(self, jdbc_connection, result_set) -> DataFrame:
         """
         Converts ``java.sql.ResultSet`` to ``org.apache.spark.sql.DataFrame`` using Spark's internal methods.

@@ -545,13 +546,25 @@ def _resultset_to_dataframe(self, result_set) -> DataFrame:
 
         java_converters = self.spark._jvm.scala.collection.JavaConverters  # type: ignore
 
-        if get_spark_version(self.spark) >= Version("3.4"):
+        if get_spark_version(self.spark) >= Version("4.0"):
+            result_schema = jdbc_utils.getSchema(
+                jdbc_connection,
+                result_set,
+                jdbc_dialect,
+                False,  # noqa: WPS425
+                False,  # noqa: WPS425
+            )
+        elif get_spark_version(self.spark) >= Version("3.4"):
             # https://github.com/apache/spark/commit/2349175e1b81b0a61e1ed90c2d051c01cf78de9b
             result_schema = jdbc_utils.getSchema(result_set, jdbc_dialect, False, False)  # noqa: WPS425
         else:
             result_schema = jdbc_utils.getSchema(result_set, jdbc_dialect, False)  # noqa: WPS425
 
-        result_iterator = jdbc_utils.resultSetToRows(result_set, result_schema)
+        if get_spark_version(self.spark) >= Version("4.0"):
+            result_iterator = jdbc_utils.resultSetToRows(result_set, result_schema, jdbc_dialect)
+        else:
+            result_iterator = jdbc_utils.resultSetToRows(result_set, result_schema)
+
         result_list = java_converters.seqAsJavaListConverter(result_iterator.toSeq()).asJava()
         jdf = self.spark._jsparkSession.createDataFrame(result_list, result_schema)  # type: ignore
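The reason for threading jdbc_connection through all of these helpers: judging by this diff, Spark 4.0's internal JdbcUtils.getSchema gained the java.sql.Connection as its first argument, and resultSetToRows now takes the JdbcDialect as a third argument. A condensed sketch of the resulting dispatch (an assumption-labeled illustration, using packaging's Version as a stand-in for onetl's version helper, with jdbc_utils, jdbc_connection, result_set and jdbc_dialect standing in for the py4j JVM objects):

    from packaging.version import Version  # stand-in for onetl._util.version.Version

    def build_schema_and_rows(spark_version, jdbc_utils, jdbc_connection, result_set, jdbc_dialect):
        # Spark 4.0 added the java.sql.Connection argument to JdbcUtils.getSchema
        if spark_version >= Version("4.0"):
            schema = jdbc_utils.getSchema(jdbc_connection, result_set, jdbc_dialect, False, False)
        # Spark 3.4 added a fourth argument (see the linked apache/spark commit)
        elif spark_version >= Version("3.4"):
            schema = jdbc_utils.getSchema(result_set, jdbc_dialect, False, False)
        else:
            schema = jdbc_utils.getSchema(result_set, jdbc_dialect, False)
        # Spark 4.0 also requires the dialect when converting the ResultSet to rows
        if spark_version >= Version("4.0"):
            rows = jdbc_utils.resultSetToRows(result_set, schema, jdbc_dialect)
        else:
            rows = jdbc_utils.resultSetToRows(result_set, schema)
        return schema, rows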

onetl/connection/file_df_connection/spark_s3/connection.py

Lines changed: 6 additions & 1 deletion

@@ -246,9 +246,14 @@ def get_packages(
             # https://issues.apache.org/jira/browse/SPARK-23977
             raise ValueError(f"Spark version must be at least 3.x, got {spark_ver}")
 
+        if spark_ver.major < 4:
+            version = spark_ver.format("{0}.{1}.{2}")
+        else:
+            version = "4.0.0-preview1"
+
         scala_ver = Version(scala_version).min_digits(2) if scala_version else get_default_scala_version(spark_ver)
         # https://mvnrepository.com/artifact/org.apache.spark/spark-hadoop-cloud
-        return [f"org.apache.spark:spark-hadoop-cloud_{scala_ver.format('{0}.{1}')}:{spark_ver.format('{0}.{1}.{2}')}"]
+        return [f"org.apache.spark:spark-hadoop-cloud_{scala_ver.format('{0}.{1}')}:{version}"]
 
     @slot
     def path_from_string(self, path: os.PathLike | str) -> RemotePath:
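The net effect on the returned Maven coordinates, as a hedged example (outputs inferred from this diff together with the scala.py change above, not taken from the commit):

    from onetl.connection import SparkS3

    SparkS3.get_packages(spark_version="3.5.1")
    # ['org.apache.spark:spark-hadoop-cloud_2.12:3.5.1']

    SparkS3.get_packages(spark_version="4.0.0")
    # ['org.apache.spark:spark-hadoop-cloud_2.13:4.0.0-preview1']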

requirements/tests/spark-latest.txt

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 numpy>=1.16
 pandas>=1.0
 pyarrow>=1.0
-pyspark
+pyspark==4.0.0.dev1
 sqlalchemy
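The explicit pin is needed because an unpinned pyspark requirement resolves to the latest stable release; pip only installs pre-releases such as 4.0.0.dev1 when they are requested explicitly (or with --pre).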
