
Commit c8d8b19

spark connect

1 parent 0c262d9

File tree: 16 files changed, +98 −21 lines

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 # pyspark
-pyspark==3.5.3
+pyspark[connect]==3.5.3
 # linters
 flake8==7.1.1
 pylint==3.3.2
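
Note: switching to the "connect" extra is what makes the Spark Connect client modules available. A minimal sketch (not part of the commit) of the imports the extra unlocks, which the new spark_imports shim further down relies on:

    # assumes pyspark[connect]==3.5.3 is installed
    from pyspark.sql.connect.column import Column as ConnectColumn
    from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame

    print(ConnectColumn, ConnectDataFrame)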

tests/_core/test_column.py

Lines changed: 4 additions & 0 deletions
@@ -1,3 +1,7 @@
+# from typedspark import configure  # isort:skip
+
+# configure(spark_connect=True)
+
 from dataclasses import dataclass
 from typing import Annotated
 
tests/_core/test_dataset.py

Lines changed: 2 additions & 1 deletion
@@ -3,11 +3,12 @@
 import pandas as pd
 import pytest
 from pyspark import StorageLevel
-from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql import SparkSession
 from pyspark.sql.types import LongType, StringType
 
 from typedspark import Column, DataSet, Schema
 from typedspark._core.dataset import DataSetImplements
+from typedspark._core.spark_imports import DataFrame
 from typedspark._utils.create_dataset import create_empty_dataset
 
 
tests/_utils/test_register_schema_to_dataset.py

Lines changed: 12 additions & 4 deletions
@@ -1,6 +1,5 @@
 import pytest
 from chispa.dataframe_comparer import assert_df_equality  # type: ignore
-from pyspark.errors import AnalysisException
 from pyspark.sql import SparkSession
 from pyspark.sql.types import IntegerType, StringType
 
@@ -11,6 +10,7 @@
     create_partially_filled_dataset,
     register_schema_to_dataset,
 )
+from typedspark._core.spark_imports import SPARK_CONNECT, AnalysisException
 from typedspark._utils.register_schema_to_dataset import register_schema_to_dataset_with_alias
 
 
@@ -40,7 +40,7 @@ def test_register_schema_to_dataset(spark: SparkSession):
     df_b = create_partially_filled_dataset(spark, Job, {Job.a: [1, 2, 3]})
 
     with pytest.raises(AnalysisException):
-        df_a.join(df_b, Person.a == Job.a)
+        df_a.join(df_b, Person.a == Job.a).show()
 
     person = register_schema_to_dataset(df_a, Person)
     job = register_schema_to_dataset(df_b, Job)
@@ -69,13 +69,21 @@ def test_register_schema_to_dataset_with_alias(spark: SparkSession):
         },
     )
 
-    with pytest.raises(AnalysisException):
+    def self_join_without_register_schema_to_dataset_with_alias():
         df_a = df.alias("a")
         df_b = df.alias("b")
         schema_a = register_schema_to_dataset(df_a, Person)
         schema_b = register_schema_to_dataset(df_b, Person)
-        df_a.join(df_b, schema_a.a == schema_b.b)
+        df_a.join(df_b, schema_a.a == schema_b.b).show()
+
+    # there seems to be a discrepancy between spark and spark connect here
+    if SPARK_CONNECT:
+        self_join_without_register_schema_to_dataset_with_alias()
+    else:
+        with pytest.raises(AnalysisException):
+            self_join_without_register_schema_to_dataset_with_alias()
 
+    # the following is the way it works with regular spark
     df_a, schema_a = register_schema_to_dataset_with_alias(df, Person, "a")
     df_b, schema_b = register_schema_to_dataset_with_alias(df, Person, "b")
     joined = df_a.join(df_b, schema_a.a == schema_b.b)
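
Why the added .show() calls: under Spark Connect the query plan is generally only analyzed when it is sent to the server, so an invalid join does not necessarily raise at join() time the way it can with a classic session. A minimal sketch (not from the commit, with illustrative data) of that deferral:

    # illustrative only: `spark` is a Spark Connect session
    df = spark.createDataFrame([(1, "a")], ["id", "value"])

    bad = df.select("missing_column")  # typically no AnalysisException yet under Spark Connect
    bad.show()                         # the plan is analyzed server-side; the error surfaces at the action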

tests/conftest.py

Lines changed: 25 additions & 2 deletions
@@ -5,12 +5,35 @@
 from pyspark.sql import SparkSession
 
 
+def pytest_addoption(parser):
+    parser.addoption(
+        "--spark-connect",
+        action="store_true",
+        default=False,
+        help="Run the unit tests using a spark-connect session.",
+    )
+
+
 @pytest.fixture(scope="session")
-def spark():
+def spark(pytestconfig: pytest.Config):
     """Fixture for creating a spark session."""
     os.environ["PYSPARK_PYTHON"] = sys.executable
     os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable
 
-    spark = SparkSession.Builder().getOrCreate()
+    spark_connect = pytestconfig.getoption("--spark-connect")
+    if spark_connect:
+        # from typedspark import configure
+
+        # configure(spark_connect=True)
+
+        spark = (
+            SparkSession.Builder()
+            .config("spark.jars.packages", "org.apache.spark:spark-connect_2.12:3.5.3")
+            .remote("local")
+            .getOrCreate()
+        )
+    else:
+        spark = SparkSession.Builder().getOrCreate()
+
     yield spark
     spark.stop()
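
With this fixture the suite can run against a classic session (pytest) or a Spark Connect session (pytest --spark-connect). A minimal sketch (not part of the commit, test name illustrative) of a test that branches on the same option:

    import pytest


    def test_example(spark, pytestconfig: pytest.Config):
        df = spark.createDataFrame([(1, "a")], ["id", "value"])
        assert df.count() == 1

        if pytestconfig.getoption("--spark-connect"):
            # behaviour that differs under Spark Connect can be skipped or branched here
            pytest.skip("not applicable to spark-connect sessions")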

typedspark/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,5 +1,7 @@
 """Typedspark: column-wise type annotations for pyspark DataFrames."""
 
+# from typedspark._core.spark_imports import configure  # noqa: F401  # isort:skip
+
 from typedspark._core.column import Column
 from typedspark._core.column_meta import ColumnMeta
 from typedspark._core.dataset import DataSet, DataSetImplements

typedspark/_core/column.py

Lines changed: 3 additions & 2 deletions
@@ -3,12 +3,13 @@
 from logging import warn
 from typing import Generic, Optional, TypeVar, Union, get_args, get_origin
 
-from pyspark.sql import Column as SparkColumn
-from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql import SparkSession
 from pyspark.sql.functions import col
 from pyspark.sql.types import DataType
 
 from typedspark._core.datatypes import StructType
+from typedspark._core.spark_imports import Column as SparkColumn
+from typedspark._core.spark_imports import DataFrame
 
 T = TypeVar("T", bound=DataType)
 
typedspark/_core/dataset.py

Lines changed: 2 additions & 2 deletions
@@ -6,10 +6,10 @@
 from typing import Callable, Generic, List, Literal, Optional, Type, TypeVar, Union, cast, overload
 
 from pyspark import StorageLevel
-from pyspark.sql import Column as SparkColumn
-from pyspark.sql import DataFrame
 from typing_extensions import Concatenate, ParamSpec
 
+from typedspark._core.spark_imports import Column as SparkColumn
+from typedspark._core.spark_imports import DataFrame
 from typedspark._core.validate_schema import validate_schema
 from typedspark._schema.schema import Schema
 
typedspark/_core/spark_imports.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+SPARK_CONNECT = False
+
+if SPARK_CONNECT:
+    from pyspark.errors.exceptions.connect import AnalysisException  # type: ignore # noqa: F401
+    from pyspark.sql.connect.column import Column  # type: ignore # noqa: F401
+    from pyspark.sql.connect.dataframe import DataFrame  # type: ignore # noqa: F401
+else:
+    from pyspark.sql import Column, DataFrame  # type: ignore # noqa: F401
+    from pyspark.sql.utils import AnalysisException  # type: ignore # noqa: F401
+
+
+# import sys
+
+# from pyspark.sql import Column, DataFrame  # type: ignore # noqa: F401
+# from pyspark.sql.utils import AnalysisException  # type: ignore # noqa: F401
+
+# SPARK_CONNECT = False
+
+
+# def configure(spark_connect=False):
+#     global SPARK_CONNECT, AnalysisException, Column, DataFrame
+#     SPARK_CONNECT = spark_connect
+
+#     from pyspark.errors.exceptions.connect import (  # pylint: disable=redefined-outer-name
+#         AnalysisException,
+#     )
+#     from pyspark.sql.connect.column import Column  # pylint: disable=redefined-outer-name
+#     from pyspark.sql.connect.dataframe import DataFrame  # pylint: disable=redefined-outer-name
+
+#     sys.modules[__name__].AnalysisException = AnalysisException  # type: ignore
+#     sys.modules[__name__].Column = Column  # type: ignore
+#     sys.modules[__name__].DataFrame = DataFrame  # type: ignore
+#     hoi = True
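
The rest of the commit routes its DataFrame, Column, and AnalysisException imports through this shim, so a single module-level flag controls which implementations typedspark works against. A minimal sketch (not part of the commit, helper name illustrative) of the kind of runtime check the shim enables:

    from typedspark._core.spark_imports import DataFrame


    def is_dataframe(obj) -> bool:
        # resolves to pyspark.sql.DataFrame or, when SPARK_CONNECT is flipped,
        # to pyspark.sql.connect.dataframe.DataFrame
        return isinstance(obj, DataFrame)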

typedspark/_schema/schema.py

Lines changed: 1 addition & 1 deletion
@@ -15,10 +15,10 @@
     get_type_hints,
 )
 
-from pyspark.sql import DataFrame
 from pyspark.sql.types import DataType, StructType
 
 from typedspark._core.column import Column
+from typedspark._core.spark_imports import DataFrame
 from typedspark._schema.dlt_kwargs import DltKwargs
 from typedspark._schema.get_schema_definition import get_schema_definition_as_string
 from typedspark._schema.structfield import get_structfield
