11from __future__ import annotations
22
3- import pyspark
3+ from functools import partial
4+ from inspect import isclass
5+
46import pyspark .sql .types as pt
5- from packaging .version import parse as vparse
67
78import ibis .common .exceptions as com
89import ibis .expr .datatypes as dt
910import ibis .expr .schema as sch
1011from ibis .formats import SchemaMapper , TypeMapper
1112
12- # DayTimeIntervalType introduced in Spark 3.2 (at least) but didn't show up in
13- # PySpark until version 3.3
14- PYSPARK_33 = vparse (pyspark .__version__ ) >= vparse ("3.3" )
15- PYSPARK_35 = vparse (pyspark .__version__ ) >= vparse ("3.5" )
16-
17-
1813_from_pyspark_dtypes = {
1914 pt .BinaryType : dt .Binary ,
2015 pt .BooleanType : dt .Boolean ,
2722 pt .NullType : dt .Null ,
2823 pt .ShortType : dt .Int16 ,
2924 pt .StringType : dt .String ,
30- pt .TimestampType : dt .Timestamp ,
3125}
3226
33- _to_pyspark_dtypes = {v : k for k , v in _from_pyspark_dtypes .items ()}
# Feature-detect timezone-naive timestamps: `TimestampNTZType` exists only in
# newer PySpark releases.  Where it is available, plain `TimestampType` is the
# timezone-aware variant (normalized to UTC); where it is not, `TimestampType`
# is the only timestamp and maps to a naive ibis Timestamp.
try:
    ntz_type = pt.TimestampNTZType
except AttributeError:
    # Old PySpark: no NTZ type, so TimestampType is timezone-naive.
    _from_pyspark_dtypes[pt.TimestampType] = dt.Timestamp
else:
    _from_pyspark_dtypes[ntz_type] = dt.Timestamp
    _from_pyspark_dtypes[pt.TimestampType] = partial(dt.Timestamp, timezone="UTC")

# Invert the mapping for the ibis -> PySpark direction.  Timestamp entries are
# excluded here (some map via `partial`, and timestamps need timezone-sensitive
# handling in `PySparkType.from_ibis` instead of a plain table lookup).
_to_pyspark_dtypes = {}
for spark_type, ibis_type in _from_pyspark_dtypes.items():
    if (
        isclass(ibis_type)
        and not issubclass(ibis_type, dt.Timestamp)
        and not isinstance(ibis_type, partial)
    ):
        _to_pyspark_dtypes[ibis_type] = spark_type
# PySpark has no native JSON or UUID types; both round-trip through strings.
_to_pyspark_dtypes[dt.JSON] = pt.StringType
_to_pyspark_dtypes[dt.UUID] = pt.StringType
3641
3742
38- if PYSPARK_33 :
39- _pyspark_interval_units = {
40- pt .DayTimeIntervalType .SECOND : "s" ,
41- pt .DayTimeIntervalType .MINUTE : "m" ,
42- pt .DayTimeIntervalType .HOUR : "h" ,
43- pt .DayTimeIntervalType .DAY : "D" ,
44- }
45-
46-
4743class PySparkType (TypeMapper ):
4844 @classmethod
4945 def to_ibis (cls , typ , nullable = True ):
5046 """Convert a pyspark type to an ibis type."""
47+ from ibis .backends .pyspark import SUPPORTS_TIMESTAMP_NTZ
48+
5149 if isinstance (typ , pt .DecimalType ):
5250 return dt .Decimal (typ .precision , typ .scale , nullable = nullable )
5351 elif isinstance (typ , pt .ArrayType ):
5452 return dt .Array (cls .to_ibis (typ .elementType ), nullable = nullable )
5553 elif isinstance (typ , pt .MapType ):
5654 return dt .Map (
57- cls .to_ibis (typ .keyType ),
58- cls .to_ibis (typ .valueType ),
59- nullable = nullable ,
55+ cls .to_ibis (typ .keyType ), cls .to_ibis (typ .valueType ), nullable = nullable
6056 )
6157 elif isinstance (typ , pt .StructType ):
6258 fields = {f .name : cls .to_ibis (f .dataType ) for f in typ .fields }
6359
6460 return dt .Struct (fields , nullable = nullable )
65- elif PYSPARK_33 and isinstance (typ , pt .DayTimeIntervalType ):
61+ elif isinstance (typ , pt .DayTimeIntervalType ):
62+ pyspark_interval_units = {
63+ pt .DayTimeIntervalType .SECOND : "s" ,
64+ pt .DayTimeIntervalType .MINUTE : "m" ,
65+ pt .DayTimeIntervalType .HOUR : "h" ,
66+ pt .DayTimeIntervalType .DAY : "D" ,
67+ }
68+
6669 if (
6770 typ .startField == typ .endField
68- and typ .startField in _pyspark_interval_units
71+ and typ .startField in pyspark_interval_units
6972 ):
70- unit = _pyspark_interval_units [typ .startField ]
73+ unit = pyspark_interval_units [typ .startField ]
7174 return dt .Interval (unit , nullable = nullable )
7275 else :
7376 raise com .IbisTypeError (f"{ typ !r} couldn't be converted to Interval" )
74- elif PYSPARK_35 and isinstance (typ , pt .TimestampNTZType ):
75- return dt .Timestamp (nullable = nullable )
77+ elif isinstance (typ , pt .TimestampNTZType ):
78+ if SUPPORTS_TIMESTAMP_NTZ :
79+ return dt .Timestamp (nullable = nullable )
80+ raise com .UnsupportedBackendType (
81+ "PySpark<3.4 doesn't properly support timestamps without a timezone"
82+ )
7683 elif isinstance (typ , pt .UserDefinedType ):
7784 return cls .to_ibis (typ .sqlType (), nullable = nullable )
7885 else :
@@ -85,6 +92,8 @@ def to_ibis(cls, typ, nullable=True):
8592
8693 @classmethod
8794 def from_ibis (cls , dtype ):
95+ from ibis .backends .pyspark import SUPPORTS_TIMESTAMP_NTZ
96+
8897 if dtype .is_decimal ():
8998 return pt .DecimalType (dtype .precision , dtype .scale )
9099 elif dtype .is_array ():
@@ -97,11 +106,21 @@ def from_ibis(cls, dtype):
97106 value_contains_null = dtype .value_type .nullable
98107 return pt .MapType (key_type , value_type , value_contains_null )
99108 elif dtype .is_struct ():
100- fields = [
101- pt .StructField (n , cls .from_ibis (t ), t .nullable )
102- for n , t in dtype .fields .items ()
103- ]
104- return pt .StructType (fields )
109+ return pt .StructType (
110+ [
111+ pt .StructField (field , cls .from_ibis (dtype ), dtype .nullable )
112+ for field , dtype in dtype .fields .items ()
113+ ]
114+ )
115+ elif dtype .is_timestamp ():
116+ if dtype .timezone is not None :
117+ return pt .TimestampType ()
118+ else :
119+ if not SUPPORTS_TIMESTAMP_NTZ :
120+ raise com .UnsupportedBackendType (
121+ "PySpark<3.4 doesn't properly support timestamps without a timezone"
122+ )
123+ return pt .TimestampNTZType ()
105124 else :
106125 try :
107126 return _to_pyspark_dtypes [type (dtype )]()
@@ -114,11 +133,7 @@ def from_ibis(cls, dtype):
114133class PySparkSchema (SchemaMapper ):
115134 @classmethod
116135 def from_ibis (cls , schema ):
117- fields = [
118- pt .StructField (name , PySparkType .from_ibis (dtype ), dtype .nullable )
119- for name , dtype in schema .items ()
120- ]
121- return pt .StructType (fields )
136+ return PySparkType .from_ibis (schema .as_struct ())
122137
123138 @classmethod
124139 def to_ibis (cls , schema ):
0 commit comments