@@ -29,17 +29,16 @@ import org.neo4j.spark.converter.CypherToSparkTypeConverter.timeType
2929import org .neo4j .spark .converter .SparkToCypherTypeConverter .mapping
3030import org .neo4j .spark .service .SchemaService .normalizedClassName
3131import org .neo4j .spark .util .Neo4jImplicits .EntityImplicits
32+ import org .neo4j .spark .util .Neo4jOptions
3233
3334import scala .collection .JavaConverters ._
3435
/**
 * Converts a type descriptor from a source type system into a destination one.
 *
 * @tparam SOURCE_TYPE      the type-descriptor representation on the source side
 *                          (e.g. a Cypher type name, or a Spark [[DataType]])
 * @tparam DESTINATION_TYPE the type-descriptor representation on the destination side
 */
trait TypeConverter[SOURCE_TYPE, DESTINATION_TYPE] {

  /**
   * Converts `sourceType` into the destination representation.
   *
   * @param sourceType the source type descriptor
   * @param value      an optional sample value; some conversions (e.g. entities,
   *                   heterogeneous maps/arrays) need a concrete value to refine the result
   */
  def convert(sourceType: SOURCE_TYPE, value: Any = null): DESTINATION_TYPE
}
4039
4140object CypherToSparkTypeConverter {
42- def apply (): CypherToSparkTypeConverter = new CypherToSparkTypeConverter ()
41+ def apply (options : Neo4jOptions ): CypherToSparkTypeConverter = new CypherToSparkTypeConverter (options )
4342
4443 private val cleanTerms : String = " Unmodifiable|Internal|Iso|2D|3D|Offset"
4544
@@ -66,66 +65,78 @@ object CypherToSparkTypeConverter {
6665 ))
6766}
6867
/**
 * Maps Cypher/Neo4j type names (as reported by the driver or by APOC meta procedures)
 * to Spark SQL [[DataType]]s.
 *
 * @param options connector options; `legacyTypeConversionEnabled` restores the
 *                pre-TimestampNTZ behaviour where the zoned/local distinction of
 *                temporal types is ignored.
 */
class CypherToSparkTypeConverter(options: Neo4jOptions) extends TypeConverter[String, DataType] {

  /**
   * Converts a Cypher type name to a Spark [[DataType]].
   *
   * @param sourceType raw type name from the driver/APOC
   * @param value      optional sample value; required to resolve entity structs and the
   *                   element type of maps/arrays (without it those resolve to NullType)
   */
  override def convert(sourceType: String, value: Any = null): DataType = {
    // Strip decorations such as "Internal", "Offset", "2D"/"3D" from the raw name.
    val stripped = sourceType.replaceAll(cleanTerms, "")
    // Legacy mode collapses zoned/local temporal names, e.g. "LocalDateTime" -> "DateTime".
    val cleanedSourceType =
      if (options.legacyTypeConversionEnabled) stripped.replaceAll("Local|Zoned", "")
      else stripped
    cleanedSourceType match {
      // Graph entities: the struct schema depends on the concrete value, so without a
      // sample value the best we can report is NullType.
      case "Node" | "Relationship" =>
        if (value != null) value.asInstanceOf[Entity].toStruct(options) else DataTypes.NullType
      case "NodeArray" | "RelationshipArray" =>
        if (value != null) DataTypes.createArrayType(value.asInstanceOf[Entity].toStruct(options))
        else DataTypes.NullType
      case "Boolean" => DataTypes.BooleanType
      case "Long" => DataTypes.LongType
      case "Double" => DataTypes.DoubleType
      case "Point" => pointType
      case "DateTime" | "ZonedDateTime" => DataTypes.TimestampType
      case "LocalDateTime" =>
        // NOTE(review): when legacy mode is on, "Local" was already stripped above, so this
        // case only matches with legacy mode off; the legacy branch here is defensive.
        if (options.legacyTypeConversionEnabled) {
          DataTypes.TimestampType
        } else {
          DataTypes.TimestampNTZType
        }
      case "Time" | "LocalTime" => timeType
      case "Date" | "LocalDate" => DataTypes.DateType
      case "Duration" => durationType
      case "ByteArray" => DataTypes.BinaryType
      case "Map" =>
        // Homogeneous maps keep the common value type; mixed types fall back to String.
        val valueType = if (value == null) {
          DataTypes.NullType
        } else {
          val map = value.asInstanceOf[java.util.Map[String, AnyRef]].asScala
          val types = map.values
            .map(v => normalizedClassName(v, options))
            .toSet
          if (types.size == 1) convert(types.head, map.values.head) else DataTypes.StringType
        }
        DataTypes.createMapType(DataTypes.StringType, valueType)
      case "Array" =>
        // Same homogeneity rule as "Map": single element type, otherwise String.
        val valueType = if (value == null) {
          DataTypes.NullType
        } else {
          val list = value.asInstanceOf[java.util.List[AnyRef]].asScala
          val types = list
            .map(v => normalizedClassName(v, options))
            .toSet
          if (types.size == 1) convert(types.head, list.head) else DataTypes.StringType
        }
        DataTypes.createArrayType(valueType)
      // These array type names are reported by APOC meta procedures.
      case "StringArray" => DataTypes.createArrayType(DataTypes.StringType)
      case "LongArray" => DataTypes.createArrayType(DataTypes.LongType)
      case "DoubleArray" => DataTypes.createArrayType(DataTypes.DoubleType)
      case "BooleanArray" => DataTypes.createArrayType(DataTypes.BooleanType)
      case "PointArray" => DataTypes.createArrayType(pointType)
      case "DateTimeArray" | "ZonedDateTimeArray" => DataTypes.createArrayType(DataTypes.TimestampType)
      case "TimeArray" | "LocalTimeArray" => DataTypes.createArrayType(timeType)
      case "DateArray" | "LocalDateArray" => DataTypes.createArrayType(DataTypes.DateType)
      case "DurationArray" => DataTypes.createArrayType(durationType)
      // Default is String
      case _ => DataTypes.StringType
    }
  }
}
124135
125136object SparkToCypherTypeConverter {
126- def apply (): SparkToCypherTypeConverter = new SparkToCypherTypeConverter ()
137+ def apply (options : Neo4jOptions ): SparkToCypherTypeConverter = new SparkToCypherTypeConverter (options )
127138
128- private val mapping : Map [DataType , String ] = Map (
139+ private val baseMappings : Map [DataType , String ] = Map (
129140 DataTypes .BooleanType -> " BOOLEAN" ,
130141 DataTypes .StringType -> " STRING" ,
131142 DecimalType .SYSTEM_DEFAULT -> " STRING" ,
@@ -136,34 +147,50 @@ object SparkToCypherTypeConverter {
136147 DataTypes .FloatType -> " FLOAT" ,
137148 DataTypes .DoubleType -> " FLOAT" ,
138149 DataTypes .DateType -> " DATE" ,
139- DataTypes .TimestampType -> " ZONED DATETIME" ,
140- DataTypes .TimestampNTZType -> " LOCAL DATETIME" ,
141- DayTimeIntervalType () -> " DURATION" ,
142- YearMonthIntervalType () -> " DURATION" ,
143150 durationType -> " DURATION" ,
144151 pointType -> " POINT" ,
145152 // Cypher graph entities do not allow null values in arrays
146153 DataTypes .createArrayType(DataTypes .BooleanType , false ) -> " LIST<BOOLEAN NOT NULL>" ,
147154 DataTypes .createArrayType(DataTypes .StringType , false ) -> " LIST<STRING NOT NULL>" ,
148155 DataTypes .createArrayType(DecimalType .SYSTEM_DEFAULT , false ) -> " LIST<STRING NOT NULL>" ,
149- DataTypes .createArrayType(DataTypes .ByteType , false ) -> " ByteArray" ,
150156 DataTypes .createArrayType(DataTypes .ShortType , false ) -> " LIST<INTEGER NOT NULL>" ,
151157 DataTypes .createArrayType(DataTypes .IntegerType , false ) -> " LIST<INTEGER NOT NULL>" ,
152158 DataTypes .createArrayType(DataTypes .LongType , false ) -> " LIST<INTEGER NOT NULL>" ,
153159 DataTypes .createArrayType(DataTypes .FloatType , false ) -> " LIST<FLOAT NOT NULL>" ,
154160 DataTypes .createArrayType(DataTypes .DoubleType , false ) -> " LIST<FLOAT NOT NULL>" ,
155161 DataTypes .createArrayType(DataTypes .DateType , false ) -> " LIST<DATE NOT NULL>" ,
156- DataTypes .createArrayType(DataTypes .TimestampType , false ) -> " LIST<ZONED DATETIME NOT NULL>" ,
157- DataTypes .createArrayType(DataTypes .TimestampNTZType , false ) -> " LIST<LOCAL DATETIME NOT NULL>" ,
158- DataTypes .createArrayType(DayTimeIntervalType (), false ) -> " LIST<DURATION NOT NULL>" ,
159- DataTypes .createArrayType(DayTimeIntervalType (), true ) -> " LIST<DURATION NOT NULL>" ,
160- DataTypes .createArrayType(YearMonthIntervalType (), false ) -> " LIST<DURATION NOT NULL>" ,
161- DataTypes .createArrayType(YearMonthIntervalType (), true ) -> " LIST<DURATION NOT NULL>" ,
162162 DataTypes .createArrayType(durationType, false ) -> " LIST<DURATION NOT NULL>" ,
163163 DataTypes .createArrayType(pointType, false ) -> " LIST<POINT NOT NULL>"
164164 )
165+
166+ private def mapping (sourceType : DataType , options : Neo4jOptions ): String = {
167+ val mappings = sourceTypeMappings(options)
168+ mappings(sourceType)
169+ }
170+
171+ private def sourceTypeMappings (options : Neo4jOptions ): Map [DataType , String ] = {
172+ var result = baseMappings
173+ if (options.legacyTypeConversionEnabled) {
174+ result += (DataTypes .TimestampType -> " LOCAL DATETIME" )
175+ result += (DataTypes .createArrayType(DataTypes .TimestampType , false ) -> " LIST<LOCAL DATETIME NOT NULL>" )
176+ result += (DataTypes .createArrayType(DataTypes .TimestampType , true ) -> " LIST<LOCAL DATETIME NOT NULL>" )
177+ } else {
178+ result += (DataTypes .TimestampType -> " ZONED DATETIME" )
179+ result += (DataTypes .TimestampNTZType -> " LOCAL DATETIME" )
180+ result += (DayTimeIntervalType () -> " DURATION" )
181+ result += (YearMonthIntervalType () -> " DURATION" )
182+ result += (DataTypes .createArrayType(DataTypes .ByteType , false ) -> " ByteArray" )
183+ result += (DataTypes .createArrayType(DataTypes .TimestampType , false ) -> " LIST<ZONED DATETIME NOT NULL>" )
184+ result += (DataTypes .createArrayType(DataTypes .TimestampNTZType , false ) -> " LIST<LOCAL DATETIME NOT NULL>" )
185+ result += (DataTypes .createArrayType(DayTimeIntervalType (), false ) -> " LIST<DURATION NOT NULL>" )
186+ result += (DataTypes .createArrayType(DayTimeIntervalType (), true ) -> " LIST<DURATION NOT NULL>" )
187+ result += (DataTypes .createArrayType(YearMonthIntervalType (), false ) -> " LIST<DURATION NOT NULL>" )
188+ result += (DataTypes .createArrayType(YearMonthIntervalType (), true ) -> " LIST<DURATION NOT NULL>" )
189+ }
190+ result
191+ }
165192}
166193
/**
 * Maps Spark SQL [[DataType]]s to Cypher type keywords (e.g. for property type constraints).
 *
 * @param options connector options; `legacyTypeConversionEnabled` selects the
 *                pre-TimestampNTZ temporal mappings
 */
class SparkToCypherTypeConverter(options: Neo4jOptions) extends TypeConverter[DataType, String] {

  /** Returns the Cypher type keyword for `sourceType`; `value` is unused here. */
  override def convert(sourceType: DataType, value: Any): String = mapping(sourceType, options)
}
0 commit comments