2 changes: 1 addition & 1 deletion requirements-dev.txt
@@ -1,5 +1,5 @@
# pyspark
pyspark==3.5.3
pyspark[connect]==3.5.3
# linters
flake8==7.1.1
pylint==3.3.2
4 changes: 4 additions & 0 deletions tests/_core/test_column.py
@@ -1,3 +1,7 @@
# from typedspark import configure # isort:skip

# configure(spark_connect=True)

from dataclasses import dataclass
from typing import Annotated

3 changes: 2 additions & 1 deletion tests/_core/test_dataset.py
@@ -3,11 +3,12 @@
import pandas as pd
import pytest
from pyspark import StorageLevel
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import SparkSession
from pyspark.sql.types import LongType, StringType

from typedspark import Column, DataSet, Schema
from typedspark._core.dataset import DataSetImplements
from typedspark._core.spark_imports import DataFrame
from typedspark._utils.create_dataset import create_empty_dataset


16 changes: 12 additions & 4 deletions tests/_utils/test_register_schema_to_dataset.py
@@ -1,6 +1,5 @@
import pytest
from chispa.dataframe_comparer import assert_df_equality # type: ignore
from pyspark.errors import AnalysisException
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType

@@ -11,6 +10,7 @@
create_partially_filled_dataset,
register_schema_to_dataset,
)
from typedspark._core.spark_imports import SPARK_CONNECT, AnalysisException
from typedspark._utils.register_schema_to_dataset import register_schema_to_dataset_with_alias


@@ -40,7 +40,7 @@ def test_register_schema_to_dataset(spark: SparkSession):
df_b = create_partially_filled_dataset(spark, Job, {Job.a: [1, 2, 3]})

with pytest.raises(AnalysisException):
df_a.join(df_b, Person.a == Job.a)
df_a.join(df_b, Person.a == Job.a).show()

person = register_schema_to_dataset(df_a, Person)
job = register_schema_to_dataset(df_b, Job)
@@ -69,13 +69,21 @@ def test_register_schema_to_dataset_with_alias(spark: SparkSession):
},
)

with pytest.raises(AnalysisException):
def self_join_without_register_schema_to_dataset_with_alias():
df_a = df.alias("a")
df_b = df.alias("b")
schema_a = register_schema_to_dataset(df_a, Person)
schema_b = register_schema_to_dataset(df_b, Person)
df_a.join(df_b, schema_a.a == schema_b.b)
df_a.join(df_b, schema_a.a == schema_b.b).show()

# there seems to be a discrepancy between spark and spark connect here
if SPARK_CONNECT:
self_join_without_register_schema_to_dataset_with_alias()
else:
with pytest.raises(AnalysisException):
self_join_without_register_schema_to_dataset_with_alias()

# the following is the way it works with regular spark
df_a, schema_a = register_schema_to_dataset_with_alias(df, Person, "a")
df_b, schema_b = register_schema_to_dataset_with_alias(df, Person, "b")
joined = df_a.join(df_b, schema_a.a == schema_b.b)
27 changes: 25 additions & 2 deletions tests/conftest.py
@@ -5,12 +5,35 @@
from pyspark.sql import SparkSession


def pytest_addoption(parser):
parser.addoption(
"--spark-connect",
action="store_true",
default=False,
help="Run the unit tests using a spark-connect session.",
)


@pytest.fixture(scope="session")
def spark():
def spark(pytestconfig: pytest.Config):
"""Fixture for creating a spark session."""
os.environ["PYSPARK_PYTHON"] = sys.executable
os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable

spark = SparkSession.Builder().getOrCreate()
spark_connect = pytestconfig.getoption("--spark-connect")
if spark_connect:
# from typedspark import configure

# configure(spark_connect=True)

spark = (
SparkSession.Builder()
.config("spark.jars.packages", "org.apache.spark:spark-connect_2.12:3.5.3")
.remote("local")
.getOrCreate()
)
else:
spark = SparkSession.Builder().getOrCreate()

yield spark
spark.stop()
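
A minimal usage sketch for the new flag (the test below is illustrative and not part of this diff): running pytest --spark-connect makes the session-scoped spark fixture build a Spark Connect session against a local remote, so existing tests exercise the connect code path unchanged.

def test_example_roundtrip(spark):
    # Illustrative only: expected to behave the same on a classic or Spark Connect session.
    df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "name"])
    assert df.count() == 2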
2 changes: 2 additions & 0 deletions typedspark/__init__.py
@@ -1,5 +1,7 @@
"""Typedspark: column-wise type annotations for pyspark DataFrames."""

# from typedspark._core.spark_imports import configure # noqa: F401 # isort:skip

from typedspark._core.column import Column
from typedspark._core.column_meta import ColumnMeta
from typedspark._core.dataset import DataSet, DataSetImplements
5 changes: 3 additions & 2 deletions typedspark/_core/column.py
@@ -3,12 +3,13 @@
from logging import warn
from typing import Generic, Optional, TypeVar, Union, get_args, get_origin

from pyspark.sql import Column as SparkColumn
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import DataType

from typedspark._core.datatypes import StructType
from typedspark._core.spark_imports import Column as SparkColumn
from typedspark._core.spark_imports import DataFrame

T = TypeVar("T", bound=DataType)

4 changes: 2 additions & 2 deletions typedspark/_core/dataset.py
@@ -6,10 +6,10 @@
from typing import Callable, Generic, List, Literal, Optional, Type, TypeVar, Union, cast, overload

from pyspark import StorageLevel
from pyspark.sql import Column as SparkColumn
from pyspark.sql import DataFrame
from typing_extensions import Concatenate, ParamSpec

from typedspark._core.spark_imports import Column as SparkColumn
from typedspark._core.spark_imports import DataFrame
from typedspark._core.validate_schema import validate_schema
from typedspark._schema.schema import Schema

33 changes: 33 additions & 0 deletions typedspark/_core/spark_imports.py
@@ -0,0 +1,33 @@
SPARK_CONNECT = False

if SPARK_CONNECT:
from pyspark.errors.exceptions.connect import AnalysisException # type: ignore # noqa: F401
from pyspark.sql.connect.column import Column # type: ignore # noqa: F401
from pyspark.sql.connect.dataframe import DataFrame # type: ignore # noqa: F401
else:
from pyspark.sql import Column, DataFrame # type: ignore # noqa: F401
from pyspark.sql.utils import AnalysisException # type: ignore # noqa: F401


# import sys

# from pyspark.sql import Column, DataFrame # type: ignore # noqa: F401
# from pyspark.sql.utils import AnalysisException # type: ignore # noqa: F401

# SPARK_CONNECT = False


# def configure(spark_connect=False):
# global SPARK_CONNECT, AnalysisException, Column, DataFrame
# SPARK_CONNECT = spark_connect

# from pyspark.errors.exceptions.connect import ( # pylint: disable=redefined-outer-name
# AnalysisException,
# )
# from pyspark.sql.connect.column import Column # pylint: disable=redefined-outer-name
# from pyspark.sql.connect.dataframe import DataFrame # pylint: disable=redefined-outer-name

# sys.modules[__name__].AnalysisException = AnalysisException # type: ignore
# sys.modules[__name__].Column = Column # type: ignore
# sys.modules[__name__].DataFrame = DataFrame # type: ignore
# hoi = True
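
If the commented-out configure() hook above were enabled (it is also referenced in commented lines of typedspark/__init__.py and tests/_core/test_column.py), usage would presumably look like the sketch below. This is a hypothetical call pattern inferred from the comments, not active code in this diff; ordering matters because the swap only affects modules imported after configure() runs.

# Hypothetical usage of the configure() hook sketched in the comments above.
from typedspark import configure  # assumed re-export, currently commented out

configure(spark_connect=True)  # would rebind Column/DataFrame/AnalysisException to the connect classes

from typedspark import Column, DataSet, Schema  # imported after the swap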
2 changes: 1 addition & 1 deletion typedspark/_schema/schema.py
@@ -15,10 +15,10 @@
get_type_hints,
)

from pyspark.sql import DataFrame
from pyspark.sql.types import DataType, StructType

from typedspark._core.column import Column
from typedspark._core.spark_imports import DataFrame
from typedspark._schema.dlt_kwargs import DltKwargs
from typedspark._schema.get_schema_definition import get_schema_definition_as_string
from typedspark._schema.structfield import get_structfield
3 changes: 1 addition & 2 deletions typedspark/_transforms/rename_duplicate_columns.py
@@ -5,8 +5,7 @@
from typing import Dict, Final, Type
from uuid import uuid4

from pyspark.sql import Column as SparkColumn

from typedspark._core.spark_imports import Column as SparkColumn
from typedspark._schema.schema import Schema

ERROR_MSG: Final[
2 changes: 1 addition & 1 deletion typedspark/_transforms/structtype_column.py
@@ -2,10 +2,10 @@

from typing import Dict, Optional, Type

from pyspark.sql import Column as SparkColumn
from pyspark.sql.functions import struct

from typedspark._core.column import Column
from typedspark._core.spark_imports import Column as SparkColumn
from typedspark._schema.schema import Schema
from typedspark._transforms.utils import add_nulls_for_unspecified_columns, convert_keys_to_strings

5 changes: 2 additions & 3 deletions typedspark/_transforms/transform_to_schema.py
@@ -3,11 +3,10 @@
from functools import reduce
from typing import Dict, Optional, Type, TypeVar, Union

from pyspark.sql import Column as SparkColumn
from pyspark.sql import DataFrame

from typedspark._core.column import Column
from typedspark._core.dataset import DataSet
from typedspark._core.spark_imports import Column as SparkColumn
from typedspark._core.spark_imports import DataFrame
from typedspark._schema.schema import Schema
from typedspark._transforms.rename_duplicate_columns import RenameDuplicateColumns
from typedspark._transforms.utils import add_nulls_for_unspecified_columns, convert_keys_to_strings
2 changes: 1 addition & 1 deletion typedspark/_transforms/utils.py
@@ -2,10 +2,10 @@

from typing import Dict, List, Optional, Type

from pyspark.sql import Column as SparkColumn
from pyspark.sql.functions import lit

from typedspark._core.column import Column
from typedspark._core.spark_imports import Column as SparkColumn
from typedspark._schema.schema import Schema


3 changes: 2 additions & 1 deletion typedspark/_utils/load_table.py
@@ -3,9 +3,10 @@
import re
from typing import Dict, Optional, Tuple, Type

from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import SparkSession

from typedspark._core.dataset import DataSet
from typedspark._core.spark_imports import DataFrame
from typedspark._schema.schema import Schema
from typedspark._utils.create_dataset_from_structtype import create_schema_from_structtype
from typedspark._utils.register_schema_to_dataset import register_schema_to_dataset
6 changes: 6 additions & 0 deletions typedspark_connect/__init__.py
@@ -0,0 +1,6 @@
SPARK_CONNECT = True


def configure(spark_connect=False):
global SPARK_CONNECT
SPARK_CONNECT = spark_connect
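
A small sketch of how this experimental stub would be driven (hypothetical usage; the package currently only holds the flag and its setter):

import typedspark_connect

typedspark_connect.configure(spark_connect=False)
assert typedspark_connect.SPARK_CONNECT is False  # the module-level flag is rebound by the setter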