11from __future__ import annotations
22
3+ from functools import partial
4+ from inspect import isclass
5+
36import pyspark
47import pyspark .sql .types as pt
58from packaging .version import parse as vparse
2730 pt .NullType : dt .Null ,
2831 pt .ShortType : dt .Int16 ,
2932 pt .StringType : dt .String ,
30- pt .TimestampType : dt .Timestamp ,
3133}
3234
33- _to_pyspark_dtypes = {v : k for k , v in _from_pyspark_dtypes .items ()}
35+ try :
36+ _from_pyspark_dtypes [pt .TimestampNTZType ] = dt .Timestamp
37+ except AttributeError :
38+ _from_pyspark_dtypes [pt .TimestampType ] = dt .Timestamp
39+ else :
40+ _from_pyspark_dtypes [pt .TimestampType ] = partial (dt .Timestamp , timezone = "UTC" )
41+
42+ _to_pyspark_dtypes = {
43+ v : k
44+ for k , v in _from_pyspark_dtypes .items ()
45+ if isclass (v ) and not issubclass (v , dt .Timestamp ) and not isinstance (v , partial )
46+ }
3447_to_pyspark_dtypes [dt .JSON ] = pt .StringType
3548_to_pyspark_dtypes [dt .UUID ] = pt .StringType
3649
@@ -54,9 +67,7 @@ def to_ibis(cls, typ, nullable=True):
5467 return dt .Array (cls .to_ibis (typ .elementType ), nullable = nullable )
5568 elif isinstance (typ , pt .MapType ):
5669 return dt .Map (
57- cls .to_ibis (typ .keyType ),
58- cls .to_ibis (typ .valueType ),
59- nullable = nullable ,
70+ cls .to_ibis (typ .keyType ), cls .to_ibis (typ .valueType ), nullable = nullable
6071 )
6172 elif isinstance (typ , pt .StructType ):
6273 fields = {f .name : cls .to_ibis (f .dataType ) for f in typ .fields }
@@ -97,11 +108,17 @@ def from_ibis(cls, dtype):
97108 value_contains_null = dtype .value_type .nullable
98109 return pt .MapType (key_type , value_type , value_contains_null )
99110 elif dtype .is_struct ():
100- fields = [
101- pt .StructField (n , cls .from_ibis (t ), t .nullable )
102- for n , t in dtype .fields .items ()
103- ]
104- return pt .StructType (fields )
111+ return pt .StructType (
112+ [
113+ pt .StructField (field , cls .from_ibis (dtype ), dtype .nullable )
114+ for field , dtype in dtype .fields .items ()
115+ ]
116+ )
117+ elif dtype .is_timestamp ():
118+ if dtype .timezone is not None :
119+ return pt .TimestampType ()
120+ else :
121+ return pt .TimestampNTZType ()
105122 else :
106123 try :
107124 return _to_pyspark_dtypes [type (dtype )]()
@@ -114,11 +131,7 @@ def from_ibis(cls, dtype):
114131class PySparkSchema (SchemaMapper ):
115132 @classmethod
116133 def from_ibis (cls , schema ):
117- fields = [
118- pt .StructField (name , PySparkType .from_ibis (dtype ), dtype .nullable )
119- for name , dtype in schema .items ()
120- ]
121- return pt .StructType (fields )
134+ return PySparkType .from_ibis (schema .as_struct ())
122135
123136 @classmethod
124137 def to_ibis (cls , schema ):
0 commit comments