diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 861471d8..3263a335 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -18,7 +18,7 @@ jobs:
           java-version: 21
       - uses: vemonet/setup-spark@v1
         with:
-          spark-version: '3.5.3'
+          spark-version: '4.0.0'
           hadoop-version: '3'
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v5
diff --git a/tests/_core/test_dataset.py b/tests/_core/test_dataset.py
index b6397fb5..b8a1785c 100644
--- a/tests/_core/test_dataset.py
+++ b/tests/_core/test_dataset.py
@@ -79,7 +79,7 @@ def test_wrong_type(spark: SparkSession):
 
 def test_inherrited_functions(spark: SparkSession):
     df = create_empty_dataset(spark, A)
-
+    assert hasattr(df, "_jdf")
     df.distinct()
     cached1: DataSet[A] = df.cache()
     cached2: DataSet[A] = df.persist(StorageLevel.MEMORY_AND_DISK)
@@ -107,7 +107,7 @@ def test_schema_property_of_dataset(spark: SparkSession):
 
 def test_initialize_dataset_implements(spark: SparkSession):
     with pytest.raises(NotImplementedError):
-        DataSetImplements()
+        DataSetImplements()  # type: ignore
 
 
 def test_reduce(spark: SparkSession):
diff --git a/tests/conftest.py b/tests/conftest.py
index 22ba1940..508b20f5 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -10,7 +10,10 @@ def spark():
     """Fixture for creating a spark session."""
     os.environ["PYSPARK_PYTHON"] = sys.executable
     os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable
+    os.environ.pop("SPARK_REMOTE", None)
+    os.environ.pop("PYSPARK_CONNECT_MODE_ENABLED", None)
-    spark = SparkSession.Builder().getOrCreate()
+    SparkSession._instantiatedSession = None  # clear any existing session
+    spark = SparkSession.builder.master("local[2]").getOrCreate()
     yield spark
     spark.stop()
 
diff --git a/typedspark/_core/dataset.py b/typedspark/_core/dataset.py
index 0e6431cb..77dd1229 100644
--- a/typedspark/_core/dataset.py
+++ b/typedspark/_core/dataset.py
@@ -2,12 +2,12 @@
 
 from __future__ import annotations
 
-from copy import deepcopy
 from typing import Callable, Generic, List, Literal, Optional, Type, TypeVar, Union, cast, overload
 
 from pyspark import StorageLevel
 from pyspark.sql import Column as SparkColumn
 from pyspark.sql import DataFrame
+from pyspark.sql.types import StructType
 from typing_extensions import Concatenate, ParamSpec
 
 from typedspark._core.validate_schema import validate_schema
@@ -44,6 +44,11 @@ def birthday(df: DataSetImplements[Age, T]) -> DataSet[T]:
     _schema_annotations: Type[_Implementation]
 
+    def __new__(cls, *args, **kwargs):
+        raise NotImplementedError(
+            "DataSetImplements should solely be used as a type annotation; it is never initialized."
+        )
+
     def __init__(self):
         raise NotImplementedError(
             "DataSetImplements should solely be used as a type annotation, it is never initialized."
         )
@@ -184,6 +189,12 @@ def __new__(cls, dataframe: DataFrame) -> DataSet[_Schema]:
         be difficult to access. Subsequently, we perform schema validation, if the
         schema annotations are provided.
         """
+        try:
+            schema_snapshot: StructType = StructType.fromJson(dataframe.schema.jsonValue())
+        except Exception:
+            # last-ditch: still try the property
+            schema_snapshot = dataframe.schema  # type: ignore
+
         dataframe = cast(DataSet, dataframe)
         dataframe.__class__ = DataSet
 
@@ -194,13 +205,14 @@ def __new__(cls, dataframe: DataFrame) -> DataSet[_Schema]:
         # then we use the class' schema annotations to validate the schema and add metadata
         if hasattr(cls, "_schema_annotations"):
             dataframe._schema_annotations = cls._schema_annotations  # type: ignore
+            dataframe._schema_snapshot = schema_snapshot  # type: ignore[attr-defined]
             dataframe._validate_schema()
-            dataframe._add_schema_metadata()
 
         return dataframe  # type: ignore
 
     def __init__(self, dataframe: DataFrame):
-        pass
+        # pylint: disable=unused-argument
+        self._add_schema_metadata()
 
     def __class_getitem__(cls, item):
         """Allows us to define a schema for the ``DataSet``.
@@ -216,7 +228,7 @@ def _validate_schema(self) -> None:
         """Validates the schema of the ``DataSet`` against the schema annotations."""
         validate_schema(
             self._schema_annotations.get_structtype(),
-            deepcopy(self.schema),
+            self._schema_snapshot,  # type: ignore
             self._schema_annotations.get_schema_name(),
         )
 