diff --git a/requirements-dev.txt b/requirements-dev.txt
index 65d0b372..fd1ed74b 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,5 +1,5 @@
 # pyspark
-pyspark==3.5.3
+pyspark[connect]==3.5.3
 # linters
 flake8==7.1.1
 pylint==3.3.2
diff --git a/tests/_core/test_column.py b/tests/_core/test_column.py
index 730ff3ae..a761a9b6 100644
--- a/tests/_core/test_column.py
+++ b/tests/_core/test_column.py
@@ -1,3 +1,7 @@
+# from typedspark import configure  # isort:skip
+
+# configure(spark_connect=True)
+
 from dataclasses import dataclass
 from typing import Annotated
 
diff --git a/tests/_core/test_dataset.py b/tests/_core/test_dataset.py
index fc04c20e..2e56995b 100644
--- a/tests/_core/test_dataset.py
+++ b/tests/_core/test_dataset.py
@@ -3,11 +3,12 @@
 import pandas as pd
 import pytest
 from pyspark import StorageLevel
-from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql import SparkSession
 from pyspark.sql.types import LongType, StringType
 
 from typedspark import Column, DataSet, Schema
 from typedspark._core.dataset import DataSetImplements
+from typedspark._core.spark_imports import DataFrame
 from typedspark._utils.create_dataset import create_empty_dataset
 
 
diff --git a/tests/_utils/test_register_schema_to_dataset.py b/tests/_utils/test_register_schema_to_dataset.py
index 91bbf9b6..8d287daa 100644
--- a/tests/_utils/test_register_schema_to_dataset.py
+++ b/tests/_utils/test_register_schema_to_dataset.py
@@ -1,6 +1,5 @@
 import pytest
 from chispa.dataframe_comparer import assert_df_equality  # type: ignore
-from pyspark.errors import AnalysisException
 from pyspark.sql import SparkSession
 from pyspark.sql.types import IntegerType, StringType
 
@@ -11,6 +10,7 @@
     create_partially_filled_dataset,
     register_schema_to_dataset,
 )
+from typedspark._core.spark_imports import SPARK_CONNECT, AnalysisException
 from typedspark._utils.register_schema_to_dataset import register_schema_to_dataset_with_alias
 
 
@@ -40,7 +40,7 @@ def test_register_schema_to_dataset(spark: SparkSession):
     df_b = create_partially_filled_dataset(spark, Job, {Job.a: [1, 2, 3]})
 
     with pytest.raises(AnalysisException):
-        df_a.join(df_b, Person.a == Job.a)
+        df_a.join(df_b, Person.a == Job.a).show()
 
     person = register_schema_to_dataset(df_a, Person)
     job = register_schema_to_dataset(df_b, Job)
@@ -69,13 +69,21 @@ def test_register_schema_to_dataset_with_alias(spark: SparkSession):
         },
     )
 
-    with pytest.raises(AnalysisException):
+    def self_join_without_register_schema_to_dataset_with_alias():
         df_a = df.alias("a")
         df_b = df.alias("b")
         schema_a = register_schema_to_dataset(df_a, Person)
         schema_b = register_schema_to_dataset(df_b, Person)
-        df_a.join(df_b, schema_a.a == schema_b.b)
+        df_a.join(df_b, schema_a.a == schema_b.b).show()
+
+    # there seems to be a discrepancy between spark and spark connect here
+    if SPARK_CONNECT:
+        self_join_without_register_schema_to_dataset_with_alias()
+    else:
+        with pytest.raises(AnalysisException):
+            self_join_without_register_schema_to_dataset_with_alias()
 
+    # the following is the way it works with regular spark
     df_a, schema_a = register_schema_to_dataset_with_alias(df, Person, "a")
     df_b, schema_b = register_schema_to_dataset_with_alias(df, Person, "b")
     joined = df_a.join(df_b, schema_a.a == schema_b.b)
diff --git a/tests/conftest.py b/tests/conftest.py
index 22ba1940..2c4854a3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,12 +5,35 @@
 from pyspark.sql import SparkSession
 
 
+def pytest_addoption(parser):
+    parser.addoption(
+        "--spark-connect",
+        action="store_true",
+        default=False,
+        help="Run the unit tests using a spark-connect session.",
+    )
+
+
 @pytest.fixture(scope="session")
-def spark():
+def spark(pytestconfig: pytest.Config):
     """Fixture for creating a spark session."""
     os.environ["PYSPARK_PYTHON"] = sys.executable
     os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable
 
-    spark = SparkSession.Builder().getOrCreate()
+    spark_connect = pytestconfig.getoption("--spark-connect")
+    if spark_connect:
+        # from typedspark import configure
+
+        # configure(spark_connect=True)
+
+        spark = (
+            SparkSession.Builder()
+            .config("spark.jars.packages", "org.apache.spark:spark-connect_2.12:3.5.3")
+            .remote("local")
+            .getOrCreate()
+        )
+    else:
+        spark = SparkSession.Builder().getOrCreate()
+
     yield spark
     spark.stop()
diff --git a/typedspark/__init__.py b/typedspark/__init__.py
index 11a5bb89..47b47032 100644
--- a/typedspark/__init__.py
+++ b/typedspark/__init__.py
@@ -1,5 +1,7 @@
 """Typedspark: column-wise type annotations for pyspark DataFrames."""
 
+# from typedspark._core.spark_imports import configure  # noqa: F401  # isort:skip
+
 from typedspark._core.column import Column
 from typedspark._core.column_meta import ColumnMeta
 from typedspark._core.dataset import DataSet, DataSetImplements
diff --git a/typedspark/_core/column.py b/typedspark/_core/column.py
index 3e528439..f5bbea4f 100644
--- a/typedspark/_core/column.py
+++ b/typedspark/_core/column.py
@@ -3,12 +3,13 @@
 from logging import warn
 from typing import Generic, Optional, TypeVar, Union, get_args, get_origin
 
-from pyspark.sql import Column as SparkColumn
-from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql import SparkSession
 from pyspark.sql.functions import col
 from pyspark.sql.types import DataType
 
 from typedspark._core.datatypes import StructType
+from typedspark._core.spark_imports import Column as SparkColumn
+from typedspark._core.spark_imports import DataFrame
 
 T = TypeVar("T", bound=DataType)
 
diff --git a/typedspark/_core/dataset.py b/typedspark/_core/dataset.py
index 2b0f1210..bc0e28ba 100644
--- a/typedspark/_core/dataset.py
+++ b/typedspark/_core/dataset.py
@@ -6,10 +6,10 @@
 from typing import Callable, Generic, List, Literal, Optional, Type, TypeVar, Union, cast, overload
 
 from pyspark import StorageLevel
-from pyspark.sql import Column as SparkColumn
-from pyspark.sql import DataFrame
 from typing_extensions import Concatenate, ParamSpec
 
+from typedspark._core.spark_imports import Column as SparkColumn
+from typedspark._core.spark_imports import DataFrame
 from typedspark._core.validate_schema import validate_schema
 from typedspark._schema.schema import Schema
 
diff --git a/typedspark/_core/spark_imports.py b/typedspark/_core/spark_imports.py
new file mode 100644
index 00000000..fe85fa6b
--- /dev/null
+++ b/typedspark/_core/spark_imports.py
@@ -0,0 +1,33 @@
+SPARK_CONNECT = False
+
+if SPARK_CONNECT:
+    from pyspark.errors.exceptions.connect import AnalysisException  # type: ignore # noqa: F401
+    from pyspark.sql.connect.column import Column  # type: ignore # noqa: F401
+    from pyspark.sql.connect.dataframe import DataFrame  # type: ignore # noqa: F401
+else:
+    from pyspark.sql import Column, DataFrame  # type: ignore # noqa: F401
+    from pyspark.sql.utils import AnalysisException  # type: ignore # noqa: F401
+
+
+# import sys
+
+# from pyspark.sql import Column, DataFrame  # type: ignore # noqa: F401
+# from pyspark.sql.utils import AnalysisException  # type: ignore # noqa: F401
+
+# SPARK_CONNECT = False
+
+
+# def configure(spark_connect=False):
+#     global SPARK_CONNECT, AnalysisException, Column, DataFrame
+#     SPARK_CONNECT = spark_connect
+
+#     from pyspark.errors.exceptions.connect import (  # pylint: disable=redefined-outer-name
+#         AnalysisException,
+#     )
+#     from pyspark.sql.connect.column import Column  # pylint: disable=redefined-outer-name
+#     from pyspark.sql.connect.dataframe import DataFrame  # pylint: disable=redefined-outer-name
+
+#     sys.modules[__name__].AnalysisException = AnalysisException  # type: ignore
+#     sys.modules[__name__].Column = Column  # type: ignore
+#     sys.modules[__name__].DataFrame = DataFrame  # type: ignore
+#     hoi = True
diff --git a/typedspark/_schema/schema.py b/typedspark/_schema/schema.py
index f42e930a..b6764c7c 100644
--- a/typedspark/_schema/schema.py
+++ b/typedspark/_schema/schema.py
@@ -15,10 +15,10 @@
     get_type_hints,
 )
 
-from pyspark.sql import DataFrame
 from pyspark.sql.types import DataType, StructType
 
 from typedspark._core.column import Column
+from typedspark._core.spark_imports import DataFrame
 from typedspark._schema.dlt_kwargs import DltKwargs
 from typedspark._schema.get_schema_definition import get_schema_definition_as_string
 from typedspark._schema.structfield import get_structfield
diff --git a/typedspark/_transforms/rename_duplicate_columns.py b/typedspark/_transforms/rename_duplicate_columns.py
index a112a61e..fff6d7c9 100644
--- a/typedspark/_transforms/rename_duplicate_columns.py
+++ b/typedspark/_transforms/rename_duplicate_columns.py
@@ -5,8 +5,7 @@
 from typing import Dict, Final, Type
 from uuid import uuid4
 
-from pyspark.sql import Column as SparkColumn
-
+from typedspark._core.spark_imports import Column as SparkColumn
 from typedspark._schema.schema import Schema
 
 ERROR_MSG: Final[
diff --git a/typedspark/_transforms/structtype_column.py b/typedspark/_transforms/structtype_column.py
index b4c1dafa..e11ee35c 100644
--- a/typedspark/_transforms/structtype_column.py
+++ b/typedspark/_transforms/structtype_column.py
@@ -2,10 +2,10 @@
 
 from typing import Dict, Optional, Type
 
-from pyspark.sql import Column as SparkColumn
 from pyspark.sql.functions import struct
 
 from typedspark._core.column import Column
+from typedspark._core.spark_imports import Column as SparkColumn
 from typedspark._schema.schema import Schema
 from typedspark._transforms.utils import add_nulls_for_unspecified_columns, convert_keys_to_strings
 
diff --git a/typedspark/_transforms/transform_to_schema.py b/typedspark/_transforms/transform_to_schema.py
index 2cb23fc1..82d044b8 100644
--- a/typedspark/_transforms/transform_to_schema.py
+++ b/typedspark/_transforms/transform_to_schema.py
@@ -3,11 +3,10 @@
 from functools import reduce
 from typing import Dict, Optional, Type, TypeVar, Union
 
-from pyspark.sql import Column as SparkColumn
-from pyspark.sql import DataFrame
-
 from typedspark._core.column import Column
 from typedspark._core.dataset import DataSet
+from typedspark._core.spark_imports import Column as SparkColumn
+from typedspark._core.spark_imports import DataFrame
 from typedspark._schema.schema import Schema
 from typedspark._transforms.rename_duplicate_columns import RenameDuplicateColumns
 from typedspark._transforms.utils import add_nulls_for_unspecified_columns, convert_keys_to_strings
diff --git a/typedspark/_transforms/utils.py b/typedspark/_transforms/utils.py
index 0c0bd9d3..9270cfa9 100644
--- a/typedspark/_transforms/utils.py
+++ b/typedspark/_transforms/utils.py
@@ -2,10 +2,10 @@
 
 from typing import Dict, List, Optional, Type
 
-from pyspark.sql import Column as SparkColumn
 from pyspark.sql.functions import lit
 
 from typedspark._core.column import Column
+from typedspark._core.spark_imports import Column as SparkColumn
 from typedspark._schema.schema import Schema
 
 
diff --git a/typedspark/_utils/load_table.py b/typedspark/_utils/load_table.py
index ab578442..bc64c99d 100644
--- a/typedspark/_utils/load_table.py
+++ b/typedspark/_utils/load_table.py
@@ -3,9 +3,10 @@
 import re
 from typing import Dict, Optional, Tuple, Type
 
-from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql import SparkSession
 
 from typedspark._core.dataset import DataSet
+from typedspark._core.spark_imports import DataFrame
 from typedspark._schema.schema import Schema
 from typedspark._utils.create_dataset_from_structtype import create_schema_from_structtype
 from typedspark._utils.register_schema_to_dataset import register_schema_to_dataset
diff --git a/typedspark_connect/__init__.py b/typedspark_connect/__init__.py
new file mode 100644
index 00000000..a130d280
--- /dev/null
+++ b/typedspark_connect/__init__.py
@@ -0,0 +1,6 @@
+SPARK_CONNECT = True
+
+
+def configure(spark_connect=False):
+    global SPARK_CONNECT
+    SPARK_CONNECT = spark_connect
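
Usage note (not part of the change set): with requirements-dev.txt now pulling in pyspark[connect]==3.5.3 and the new --spark-connect option in tests/conftest.py, the suite can be exercised against a Spark Connect backend via `pytest --spark-connect`; the fixture then fetches org.apache.spark:spark-connect_2.12:3.5.3 and starts a local remote session. The sketch below is a hypothetical illustration of how downstream code is expected to consume the new typedspark/_core/spark_imports.py indirection; the describe_backend helper is an invented example, not part of the diff.

    # Hypothetical sketch, assuming the new spark_imports module is importable.
    # Every typedspark module now pulls Column/DataFrame/AnalysisException from
    # this single indirection point, so flipping SPARK_CONNECT swaps all
    # consumers between the classic and the Spark Connect class hierarchies.
    from typedspark._core.spark_imports import SPARK_CONNECT, DataFrame


    def describe_backend(df: DataFrame) -> str:
        # DataFrame resolves to pyspark.sql.DataFrame when SPARK_CONNECT is False,
        # and to pyspark.sql.connect.dataframe.DataFrame when it is True.
        backend = "spark-connect" if SPARK_CONNECT else "classic spark"
        return f"{type(df).__name__} ({backend})"

Note that SPARK_CONNECT is currently a hard-coded constant in spark_imports.py; the commented-out configure() functions in spark_imports.py and typedspark_connect/__init__.py sketch a possible runtime toggle instead.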