
Commit dd21d0a

Test Spark 4.0
1 parent b657322 commit dd21d0a

9 files changed: +61 -19 lines changed

.github/workflows/data/core/matrix.yml

Lines changed: 1 addition & 1 deletion
@@ -22,4 +22,4 @@ latest: &latest
 matrix:
   small: [*max]
   full: [*min, *max]
-  nightly: [*min, *max, *latest]
+  nightly: [*min, *latest]

onetl/_util/scala.py

Lines changed: 4 additions & 2 deletions
@@ -9,6 +9,8 @@ def get_default_scala_version(spark_version: Version) -> Version:
     """
     Get default Scala version for specific Spark version
     """
-    if spark_version.major < 3:
+    if spark_version.major == 2:
         return Version("2.11")
-    return Version("2.12")
+    if spark_version.major == 3:
+        return Version("2.12")
+    return Version("2.13")

onetl/connection/db_connection/jdbc_mixin/connection.py

Lines changed: 22 additions & 9 deletions
@@ -431,10 +431,11 @@ def _execute_on_driver(
         statement_args = self._get_statement_args()
         jdbc_statement = self._build_statement(statement, statement_type, jdbc_connection, statement_args)

-        return self._execute_statement(jdbc_statement, statement, options, callback, read_only)
+        return self._execute_statement(jdbc_connection, jdbc_statement, statement, options, callback, read_only)

     def _execute_statement(
         self,
+        jdbc_connection,
         jdbc_statement,
         statement: str,
         options: JDBCFetchOptions | JDBCExecuteOptions,
@@ -472,7 +473,7 @@ def _execute_statement(
         else:
             jdbc_statement.executeUpdate(statement)

-        return callback(jdbc_statement)
+        return callback(jdbc_connection, jdbc_statement)

     @staticmethod
     def _build_statement(
@@ -501,11 +502,11 @@ def _build_statement(

         return jdbc_connection.createStatement(*statement_args)

-    def _statement_to_dataframe(self, jdbc_statement) -> DataFrame:
+    def _statement_to_dataframe(self, jdbc_connection, jdbc_statement) -> DataFrame:
         result_set = jdbc_statement.getResultSet()
-        return self._resultset_to_dataframe(result_set)
+        return self._resultset_to_dataframe(jdbc_connection, result_set)

-    def _statement_to_optional_dataframe(self, jdbc_statement) -> DataFrame | None:
+    def _statement_to_optional_dataframe(self, jdbc_connection, jdbc_statement) -> DataFrame | None:
         """
         Returns ``org.apache.spark.sql.DataFrame`` or ``None``, if ResultSet does not contain any columns.
@@ -522,9 +523,9 @@ def _statement_to_optional_dataframe(self, jdbc_statement) -> DataFrame | None:
         if not result_column_count:
             return None

-        return self._resultset_to_dataframe(result_set)
+        return self._resultset_to_dataframe(jdbc_connection, result_set)

-    def _resultset_to_dataframe(self, result_set) -> DataFrame:
+    def _resultset_to_dataframe(self, jdbc_connection, result_set) -> DataFrame:
         """
         Converts ``java.sql.ResultSet`` to ``org.apache.spark.sql.DataFrame`` using Spark's internal methods.
@@ -545,13 +546,25 @@ def _resultset_to_dataframe(self, result_set) -> DataFrame:

         java_converters = self.spark._jvm.scala.collection.JavaConverters  # type: ignore

-        if get_spark_version(self.spark) >= Version("3.4"):
+        if get_spark_version(self.spark) >= Version("4.0"):
+            result_schema = jdbc_utils.getSchema(
+                jdbc_connection,
+                result_set,
+                jdbc_dialect,
+                False,  # noqa: WPS425
+                False,  # noqa: WPS425
+            )
+        elif get_spark_version(self.spark) >= Version("3.4"):
             # https://github.com/apache/spark/commit/2349175e1b81b0a61e1ed90c2d051c01cf78de9b
             result_schema = jdbc_utils.getSchema(result_set, jdbc_dialect, False, False)  # noqa: WPS425
         else:
             result_schema = jdbc_utils.getSchema(result_set, jdbc_dialect, False)  # noqa: WPS425

-        result_iterator = jdbc_utils.resultSetToRows(result_set, result_schema)
+        if get_spark_version(self.spark) >= Version("4.0"):
+            result_iterator = jdbc_utils.resultSetToRows(result_set, result_schema, jdbc_dialect)
+        else:
+            result_iterator = jdbc_utils.resultSetToRows(result_set, result_schema)
+
         result_list = java_converters.seqAsJavaListConverter(result_iterator.toSeq()).asJava()
         jdf = self.spark._jsparkSession.createDataFrame(result_list, result_schema)  # type: ignore
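
The reason for threading jdbc_connection through every private helper: as the branches above show, in Spark 4.0 the internal JdbcUtils.getSchema takes the JDBC connection as its first argument, and resultSetToRows additionally takes the dialect. A hypothetical refactoring sketch (not part of the commit) that concentrates the version switch in one helper:

    def _get_result_schema(self, jdbc_connection, result_set, jdbc_dialect, jdbc_utils):
        # Hypothetical helper, not in the commit: pick the JdbcUtils.getSchema
        # overload matching the running Spark version.
        spark_version = get_spark_version(self.spark)
        if spark_version >= Version("4.0"):
            # Spark 4.0 added the JDBC connection as the first argument
            return jdbc_utils.getSchema(jdbc_connection, result_set, jdbc_dialect, False, False)
        if spark_version >= Version("3.4"):
            # Spark 3.4 added a fourth boolean flag (see the linked Spark commit)
            return jdbc_utils.getSchema(result_set, jdbc_dialect, False, False)
        return jdbc_utils.getSchema(result_set, jdbc_dialect, False)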

onetl/connection/db_connection/kafka/connection.py

Lines changed: 6 additions & 1 deletion
@@ -432,8 +432,13 @@ def get_packages(
             raise ValueError(f"Spark version must be at least 2.4, got {spark_ver}")

         scala_ver = Version(scala_version).min_digits(2) if scala_version else get_default_scala_version(spark_ver)
+
+        if spark_ver.major < 4:
+            version = spark_ver.format("{0}.{1}.{2}")
+        else:
+            version = "4.0.0-preview1"
         return [
-            f"org.apache.spark:spark-sql-kafka-0-10_{scala_ver.format('{0}.{1}')}:{spark_ver.format('{0}.{1}.{2}')}",
+            f"org.apache.spark:spark-sql-kafka-0-10_{scala_ver.format('{0}.{1}')}:{version}",
         ]

     def __enter__(self):
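
Given the logic above, the resolved package coordinates would look roughly like this (a sketch; Kafka.get_packages is the same public API called from tests/fixtures/spark.py below):

    from onetl.connection import Kafka

    # Spark 3.x: the connector version tracks the Spark version exactly
    Kafka.get_packages(spark_version="3.5.1")
    # ["org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.1"]

    # Spark 4.x: pinned to the preview artifact available at the time of this commit
    Kafka.get_packages(spark_version="4.0.0")
    # ["org.apache.spark:spark-sql-kafka-0-10_2.13:4.0.0-preview1"]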

onetl/connection/file_df_connection/spark_s3/connection.py

Lines changed: 6 additions & 1 deletion
@@ -246,9 +246,14 @@ def get_packages(
             # https://issues.apache.org/jira/browse/SPARK-23977
             raise ValueError(f"Spark version must be at least 3.x, got {spark_ver}")

+        if spark_ver.major < 4:
+            version = spark_ver.format("{0}.{1}.{2}")
+        else:
+            version = "4.0.0-preview1"
+
         scala_ver = Version(scala_version).min_digits(2) if scala_version else get_default_scala_version(spark_ver)
         # https://mvnrepository.com/artifact/org.apache.spark/spark-hadoop-cloud
-        return [f"org.apache.spark:spark-hadoop-cloud_{scala_ver.format('{0}.{1}')}:{spark_ver.format('{0}.{1}.{2}')}"]
+        return [f"org.apache.spark:spark-hadoop-cloud_{scala_ver.format('{0}.{1}')}:{version}"]

     @slot
     def path_from_string(self, path: os.PathLike | str) -> RemotePath:
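
The same major-version branch now appears in Kafka.get_packages, SparkS3.get_packages and Avro.get_packages below. A hypothetical shared helper (not part of this commit) illustrating the repeated pattern:

    def spark_artifact_version(spark_ver):
        # Hypothetical helper, not in the commit: Maven version for
        # org.apache.spark connector artifacts matching a given Spark version.
        if spark_ver.major < 4:
            # released Spark lines publish connectors under the exact Spark version
            return spark_ver.format("{0}.{1}.{2}")
        # only the 4.0 preview artifact was published at the time of this commit
        return "4.0.0-preview1"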

onetl/file/format/avro.py

Lines changed: 6 additions & 1 deletion
@@ -163,7 +163,12 @@ def get_packages(
         if scala_ver < Version("2.11"):
             raise ValueError(f"Scala version should be at least 2.11, got {scala_ver.format('{0}.{1}')}")

-        return [f"org.apache.spark:spark-avro_{scala_ver.format('{0}.{1}')}:{spark_ver.format('{0}.{1}.{2}')}"]
+        if spark_ver.major < 4:
+            version = spark_ver.format("{0}.{1}.{2}")
+        else:
+            version = "4.0.0-preview1"
+
+        return [f"org.apache.spark:spark-avro_{scala_ver.format('{0}.{1}')}:{version}"]

     @slot
     def check_if_supported(self, spark: SparkSession) -> None:
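
As with Kafka and SparkS3, the expected result under Spark 4.x (a sketch based on the branch above):

    from onetl.file.format import Avro

    Avro.get_packages(spark_version="4.0.0")
    # ["org.apache.spark:spark-avro_2.13:4.0.0-preview1"]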

onetl/file/format/xml.py

Lines changed: 14 additions & 2 deletions
@@ -193,6 +193,9 @@ def get_packages(  # noqa: WPS231
         )

         """
+        spark_ver = Version(spark_version)
+        if spark_ver.major >= 4:
+            return []

         if package_version:
             version = Version(package_version).min_digits(3)
@@ -202,7 +205,6 @@ def get_packages(  # noqa: WPS231
         else:
             version = Version("0.18.0").min_digits(3)

-        spark_ver = Version(spark_version)
         scala_ver = Version(scala_version).min_digits(2) if scala_version else get_default_scala_version(spark_ver)

         # Ensure compatibility with Spark and Scala versions
@@ -216,8 +218,11 @@ def get_packages(  # noqa: WPS231

     @slot
     def check_if_supported(self, spark: SparkSession) -> None:
-        java_class = "com.databricks.spark.xml.XmlReader"
+        version = get_spark_version(spark)
+        if version.major >= 4:
+            return

+        java_class = "com.databricks.spark.xml.XmlReader"
         try:
             try_import_java_class(spark, java_class)
         except Exception as e:
@@ -332,6 +337,13 @@ def parse_column(self, column: str | Column, schema: StructType) -> Column:
         | |-- name: string (nullable = true)
         | |-- age: integer (nullable = true)
         """
+        from pyspark import __version__ as spark_version
+
+        if spark_version > "4":
+            from pyspark.sql.function import from_xml  # noqa: WPS450
+
+            return from_xml(column, schema, self.dict())
+
         from pyspark.sql import Column, SparkSession  # noqa: WPS442

         spark = SparkSession._instantiatedSession  # noqa: WPS437
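
On Spark 4.0 the XML data source ships with Spark itself, so get_packages returns no extra package, check_if_supported skips the spark-xml class lookup, and parse_column delegates to the native from_xml function. A minimal usage sketch, assuming a Spark 4.0 session and pyspark.sql.functions.from_xml with the same (column, schema, options) call shape used above:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import from_xml
    from pyspark.sql.types import IntegerType, StringType, StructField, StructType

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    schema = StructType([
        StructField("name", StringType()),
        StructField("age", IntegerType()),
    ])
    df = spark.createDataFrame(
        [("<person><name>Alice</name><age>20</age></person>",)],
        ["xml_string"],
    )

    # native Spark 4.0 XML parsing, no com.databricks:spark-xml package involved
    df.select(from_xml("xml_string", schema).alias("parsed")).show(truncate=False)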

requirements/tests/spark-latest.txt

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 numpy>=1.16
 pandas>=1.0
 pyarrow>=1.0
-pyspark
+pyspark==4.0.0.dev1
 sqlalchemy

tests/fixtures/spark.py

Lines changed: 1 addition & 1 deletion
@@ -103,7 +103,7 @@ def maven_packages(request):
     # There is no MongoDB connector for Spark less than 3.2
     packages.extend(MongoDB.get_packages(spark_version=str(pyspark_version)))

-    if "excel" in markers:
+    if "excel" in markers and pyspark_version < Version("4.0"):
         # There is no Excel files support for Spark less than 3.2
         packages.extend(Excel.get_packages(spark_version=str(pyspark_version)))
