Skip to content

Commit 1ac19c3

Browse files
authored
feat: allow users to enable the legacy type conversion (#836)
Users can now force the legacy conversion logic for existing supported types, restoring the behavior that existed before #816 and #796. This concerns timestamps, durations/intervals, and byte arrays.
1 parent 194d900 commit 1ac19c3

File tree

8 files changed

+351
-117
lines changed

8 files changed

+351
-117
lines changed

common/src/main/scala/org/neo4j/spark/converter/DataConverter.scala

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import org.neo4j.driver.types.IsoDuration
3333
import org.neo4j.driver.types.Node
3434
import org.neo4j.driver.types.Relationship
3535
import org.neo4j.spark.service.SchemaService
36+
import org.neo4j.spark.util.Neo4jOptions
3637
import org.neo4j.spark.util.Neo4jUtil
3738

3839
import java.time._
@@ -54,9 +55,9 @@ trait DataConverter[T] {
5455
}
5556

5657
object SparkToNeo4jDataConverter {
57-
def apply(): SparkToNeo4jDataConverter = new SparkToNeo4jDataConverter()
58+
def apply(options: Neo4jOptions): SparkToNeo4jDataConverter = new SparkToNeo4jDataConverter(options)
5859

59-
def dayTimeMicrosToNeo4jDuration(micros: Long): Value = {
60+
private def dayTimeMicrosToNeo4jDuration(micros: Long): Value = {
6061
val oneSecondInMicros = 1000000L
6162
val oneDayInMicros = 24 * 3600 * oneSecondInMicros
6263
val numberDays = Math.floorDiv(micros, oneDayInMicros)
@@ -67,30 +68,35 @@ object SparkToNeo4jDataConverter {
6768
}
6869

6970
// while Neo4j supports years, this driver version's API does not expose it.
70-
def yearMonthIntervalToNeo4jDuration(months: Int): Value = {
71+
private def yearMonthIntervalToNeo4jDuration(months: Int): Value = {
7172
Values.isoDuration(months.toLong, 0L, 0L, 0)
7273
}
7374
}
7475

75-
class SparkToNeo4jDataConverter extends DataConverter[Value] {
76+
class SparkToNeo4jDataConverter(options: Neo4jOptions) extends DataConverter[Value] {
7677

7778
override def convert(value: Any, dataType: DataType): Value = {
7879
value match {
79-
case date: java.sql.Date => convert(date.toLocalDate, dataType)
80-
case timestamp: java.sql.Timestamp => convert(timestamp.toInstant.atZone(ZoneOffset.UTC), dataType)
80+
case date: java.sql.Date => convert(date.toLocalDate, dataType)
81+
case timestamp: java.sql.Timestamp =>
82+
if (options.legacyTypeConversionEnabled) {
83+
convert(timestamp.toLocalDateTime, dataType)
84+
} else {
85+
convert(timestamp.toInstant.atZone(ZoneOffset.UTC), dataType)
86+
}
8187
case intValue: Int if dataType == DataTypes.DateType =>
8288
convert(
8389
DateTimeUtils
8490
.toJavaDate(intValue),
8591
dataType
8692
)
87-
case intValue: Int if dataType.isInstanceOf[YearMonthIntervalType] =>
93+
case intValue: Int if dataType.isInstanceOf[YearMonthIntervalType] && !options.legacyTypeConversionEnabled =>
8894
SparkToNeo4jDataConverter.yearMonthIntervalToNeo4jDuration(intValue)
8995
case longValue: Long if dataType == DataTypes.TimestampType =>
9096
convert(DateTimeUtils.toJavaTimestamp(longValue), dataType)
91-
case longValue: Long if dataType == DataTypes.TimestampNTZType =>
97+
case longValue: Long if dataType == DataTypes.TimestampNTZType && !options.legacyTypeConversionEnabled =>
9298
convert(DateTimeUtils.microsToLocalDateTime(longValue), dataType)
93-
case longValue: Long if dataType.isInstanceOf[DayTimeIntervalType] =>
99+
case longValue: Long if dataType.isInstanceOf[DayTimeIntervalType] && !options.legacyTypeConversionEnabled =>
94100
SparkToNeo4jDataConverter.dayTimeMicrosToNeo4jDuration(longValue)
95101
case unsafeRow: UnsafeRow => {
96102
val structType = extractStructType(dataType)
@@ -138,7 +144,7 @@ class SparkToNeo4jDataConverter extends DataConverter[Value] {
138144
case arrayType: ArrayType => arrayType.elementType
139145
case _ => dataType
140146
}
141-
if (sparkType == DataTypes.ByteType) {
147+
if (sparkType == DataTypes.ByteType && !options.legacyTypeConversionEnabled) {
142148
Values.value(unsafeArray.toByteArray)
143149
} else {
144150
val javaList = unsafeArray.toSeq[AnyRef](sparkType)
@@ -166,10 +172,10 @@ class SparkToNeo4jDataConverter extends DataConverter[Value] {
166172
}
167173

168174
object Neo4jToSparkDataConverter {
169-
def apply(): Neo4jToSparkDataConverter = new Neo4jToSparkDataConverter()
175+
def apply(options: Neo4jOptions): Neo4jToSparkDataConverter = new Neo4jToSparkDataConverter(options)
170176
}
171177

172-
class Neo4jToSparkDataConverter extends DataConverter[Any] {
178+
class Neo4jToSparkDataConverter(options: Neo4jOptions) extends DataConverter[Any] {
173179

174180
override def convert(value: Any, dataType: DataType): Any = {
175181
if (dataType != null && dataType == DataTypes.StringType && value != null && !value.isInstanceOf[String]) {
@@ -217,8 +223,14 @@ class Neo4jToSparkDataConverter extends DataConverter[Any] {
217223
))
218224
}
219225
case zt: ZonedDateTime => DateTimeUtils.instantToMicros(zt.toInstant)
220-
case dt: LocalDateTime => DateTimeUtils.localDateTimeToMicros(dt)
221-
case d: LocalDate => d.toEpochDay.toInt
226+
case dt: LocalDateTime => {
227+
if (options.legacyTypeConversionEnabled) {
228+
DateTimeUtils.instantToMicros(dt.toInstant(ZoneOffset.UTC))
229+
} else {
230+
DateTimeUtils.localDateTimeToMicros(dt)
231+
}
232+
}
233+
case d: LocalDate => d.toEpochDay.toInt
222234
case lt: LocalTime => {
223235
InternalRow.fromSeq(Seq(
224236
UTF8String.fromString(SchemaService.TIME_TYPE_LOCAL),

common/src/main/scala/org/neo4j/spark/converter/TypeConverter.scala

Lines changed: 94 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,16 @@ import org.neo4j.spark.converter.CypherToSparkTypeConverter.timeType
2929
import org.neo4j.spark.converter.SparkToCypherTypeConverter.mapping
3030
import org.neo4j.spark.service.SchemaService.normalizedClassName
3131
import org.neo4j.spark.util.Neo4jImplicits.EntityImplicits
32+
import org.neo4j.spark.util.Neo4jOptions
3233

3334
import scala.collection.JavaConverters._
3435

3536
trait TypeConverter[SOURCE_TYPE, DESTINATION_TYPE] {
36-
3737
def convert(sourceType: SOURCE_TYPE, value: Any = null): DESTINATION_TYPE
38-
3938
}
4039

4140
object CypherToSparkTypeConverter {
42-
def apply(): CypherToSparkTypeConverter = new CypherToSparkTypeConverter()
41+
def apply(options: Neo4jOptions): CypherToSparkTypeConverter = new CypherToSparkTypeConverter(options)
4342

4443
private val cleanTerms: String = "Unmodifiable|Internal|Iso|2D|3D|Offset"
4544

@@ -66,66 +65,78 @@ object CypherToSparkTypeConverter {
6665
))
6766
}
6867

69-
class CypherToSparkTypeConverter extends TypeConverter[String, DataType] {
68+
class CypherToSparkTypeConverter(options: Neo4jOptions) extends TypeConverter[String, DataType] {
7069

71-
override def convert(sourceType: String, value: Any = null): DataType = sourceType
72-
.replaceAll(cleanTerms, "") match {
73-
case "Node" | "Relationship" => if (value != null) value.asInstanceOf[Entity].toStruct else DataTypes.NullType
74-
case "NodeArray" | "RelationshipArray" =>
75-
if (value != null) DataTypes.createArrayType(value.asInstanceOf[Entity].toStruct) else DataTypes.NullType
76-
case "Boolean" => DataTypes.BooleanType
77-
case "Long" => DataTypes.LongType
78-
case "Double" => DataTypes.DoubleType
79-
case "Point" => pointType
80-
case "DateTime" | "ZonedDateTime" => DataTypes.TimestampType
81-
case "LocalDateTime" => DataTypes.TimestampNTZType
82-
case "Time" | "LocalTime" => timeType
83-
case "Date" | "LocalDate" => DataTypes.DateType
84-
case "Duration" => durationType
85-
case "ByteArray" => DataTypes.BinaryType
86-
case "Map" => {
87-
val valueType = if (value == null) {
88-
DataTypes.NullType
89-
} else {
90-
val map = value.asInstanceOf[java.util.Map[String, AnyRef]].asScala
91-
val types = map.values
92-
.map(normalizedClassName)
93-
.toSet
94-
if (types.size == 1) convert(types.head, map.values.head) else DataTypes.StringType
95-
}
96-
DataTypes.createMapType(DataTypes.StringType, valueType)
70+
override def convert(sourceType: String, value: Any = null): DataType = {
71+
var cleanedSourceType = sourceType.replaceAll(cleanTerms, "")
72+
if (options.legacyTypeConversionEnabled) {
73+
cleanedSourceType = cleanedSourceType.replaceAll("Local|Zoned", "")
9774
}
98-
case "Array" => {
99-
val valueType = if (value == null) {
100-
DataTypes.NullType
101-
} else {
102-
val list = value.asInstanceOf[java.util.List[AnyRef]].asScala
103-
val types = list
104-
.map(normalizedClassName)
105-
.toSet
106-
if (types.size == 1) convert(types.head, list.head) else DataTypes.StringType
75+
cleanedSourceType match {
76+
case "Node" | "Relationship" =>
77+
if (value != null) value.asInstanceOf[Entity].toStruct(options) else DataTypes.NullType
78+
case "NodeArray" | "RelationshipArray" =>
79+
if (value != null) DataTypes.createArrayType(value.asInstanceOf[Entity].toStruct(options))
80+
else DataTypes.NullType
81+
case "Boolean" => DataTypes.BooleanType
82+
case "Long" => DataTypes.LongType
83+
case "Double" => DataTypes.DoubleType
84+
case "Point" => pointType
85+
case "DateTime" | "ZonedDateTime" => DataTypes.TimestampType
86+
case "LocalDateTime" =>
87+
if (options.legacyTypeConversionEnabled) {
88+
DataTypes.TimestampType
89+
} else {
90+
DataTypes.TimestampNTZType
91+
}
92+
case "Time" | "LocalTime" => timeType
93+
case "Date" | "LocalDate" => DataTypes.DateType
94+
case "Duration" => durationType
95+
case "ByteArray" => DataTypes.BinaryType
96+
case "Map" => {
97+
val valueType = if (value == null) {
98+
DataTypes.NullType
99+
} else {
100+
val map = value.asInstanceOf[java.util.Map[String, AnyRef]].asScala
101+
val types = map.values
102+
.map(value => normalizedClassName(value, options))
103+
.toSet
104+
if (types.size == 1) convert(types.head, map.values.head) else DataTypes.StringType
105+
}
106+
DataTypes.createMapType(DataTypes.StringType, valueType)
107107
}
108-
DataTypes.createArrayType(valueType)
108+
case "Array" => {
109+
val valueType = if (value == null) {
110+
DataTypes.NullType
111+
} else {
112+
val list = value.asInstanceOf[java.util.List[AnyRef]].asScala
113+
val types = list
114+
.map(value => normalizedClassName(value, options))
115+
.toSet
116+
if (types.size == 1) convert(types.head, list.head) else DataTypes.StringType
117+
}
118+
DataTypes.createArrayType(valueType)
119+
}
120+
// These are from APOC
121+
case "StringArray" => DataTypes.createArrayType(DataTypes.StringType)
122+
case "LongArray" => DataTypes.createArrayType(DataTypes.LongType)
123+
case "DoubleArray" => DataTypes.createArrayType(DataTypes.DoubleType)
124+
case "BooleanArray" => DataTypes.createArrayType(DataTypes.BooleanType)
125+
case "PointArray" => DataTypes.createArrayType(pointType)
126+
case "DateTimeArray" | "ZonedDateTimeArray" => DataTypes.createArrayType(DataTypes.TimestampType)
127+
case "TimeArray" | "LocalTimeArray" => DataTypes.createArrayType(timeType)
128+
case "DateArray" | "LocalDateArray" => DataTypes.createArrayType(DataTypes.DateType)
129+
case "DurationArray" => DataTypes.createArrayType(durationType)
130+
// Default is String
131+
case _ => DataTypes.StringType
109132
}
110-
// These are from APOC
111-
case "StringArray" => DataTypes.createArrayType(DataTypes.StringType)
112-
case "LongArray" => DataTypes.createArrayType(DataTypes.LongType)
113-
case "DoubleArray" => DataTypes.createArrayType(DataTypes.DoubleType)
114-
case "BooleanArray" => DataTypes.createArrayType(DataTypes.BooleanType)
115-
case "PointArray" => DataTypes.createArrayType(pointType)
116-
case "DateTimeArray" | "ZonedDateTimeArray" => DataTypes.createArrayType(DataTypes.TimestampType)
117-
case "TimeArray" | "LocalTimeArray" => DataTypes.createArrayType(timeType)
118-
case "DateArray" | "LocalDateArray" => DataTypes.createArrayType(DataTypes.DateType)
119-
case "DurationArray" => DataTypes.createArrayType(durationType)
120-
// Default is String
121-
case _ => DataTypes.StringType
122133
}
123134
}
124135

125136
object SparkToCypherTypeConverter {
126-
def apply(): SparkToCypherTypeConverter = new SparkToCypherTypeConverter()
137+
def apply(options: Neo4jOptions): SparkToCypherTypeConverter = new SparkToCypherTypeConverter(options)
127138

128-
private val mapping: Map[DataType, String] = Map(
139+
private val baseMappings: Map[DataType, String] = Map(
129140
DataTypes.BooleanType -> "BOOLEAN",
130141
DataTypes.StringType -> "STRING",
131142
DecimalType.SYSTEM_DEFAULT -> "STRING",
@@ -136,34 +147,50 @@ object SparkToCypherTypeConverter {
136147
DataTypes.FloatType -> "FLOAT",
137148
DataTypes.DoubleType -> "FLOAT",
138149
DataTypes.DateType -> "DATE",
139-
DataTypes.TimestampType -> "ZONED DATETIME",
140-
DataTypes.TimestampNTZType -> "LOCAL DATETIME",
141-
DayTimeIntervalType() -> "DURATION",
142-
YearMonthIntervalType() -> "DURATION",
143150
durationType -> "DURATION",
144151
pointType -> "POINT",
145152
// Cypher graph entities do not allow null values in arrays
146153
DataTypes.createArrayType(DataTypes.BooleanType, false) -> "LIST<BOOLEAN NOT NULL>",
147154
DataTypes.createArrayType(DataTypes.StringType, false) -> "LIST<STRING NOT NULL>",
148155
DataTypes.createArrayType(DecimalType.SYSTEM_DEFAULT, false) -> "LIST<STRING NOT NULL>",
149-
DataTypes.createArrayType(DataTypes.ByteType, false) -> "ByteArray",
150156
DataTypes.createArrayType(DataTypes.ShortType, false) -> "LIST<INTEGER NOT NULL>",
151157
DataTypes.createArrayType(DataTypes.IntegerType, false) -> "LIST<INTEGER NOT NULL>",
152158
DataTypes.createArrayType(DataTypes.LongType, false) -> "LIST<INTEGER NOT NULL>",
153159
DataTypes.createArrayType(DataTypes.FloatType, false) -> "LIST<FLOAT NOT NULL>",
154160
DataTypes.createArrayType(DataTypes.DoubleType, false) -> "LIST<FLOAT NOT NULL>",
155161
DataTypes.createArrayType(DataTypes.DateType, false) -> "LIST<DATE NOT NULL>",
156-
DataTypes.createArrayType(DataTypes.TimestampType, false) -> "LIST<ZONED DATETIME NOT NULL>",
157-
DataTypes.createArrayType(DataTypes.TimestampNTZType, false) -> "LIST<LOCAL DATETIME NOT NULL>",
158-
DataTypes.createArrayType(DayTimeIntervalType(), false) -> "LIST<DURATION NOT NULL>",
159-
DataTypes.createArrayType(DayTimeIntervalType(), true) -> "LIST<DURATION NOT NULL>",
160-
DataTypes.createArrayType(YearMonthIntervalType(), false) -> "LIST<DURATION NOT NULL>",
161-
DataTypes.createArrayType(YearMonthIntervalType(), true) -> "LIST<DURATION NOT NULL>",
162162
DataTypes.createArrayType(durationType, false) -> "LIST<DURATION NOT NULL>",
163163
DataTypes.createArrayType(pointType, false) -> "LIST<POINT NOT NULL>"
164164
)
165+
166+
private def mapping(sourceType: DataType, options: Neo4jOptions): String = {
167+
val mappings = sourceTypeMappings(options)
168+
mappings(sourceType)
169+
}
170+
171+
private def sourceTypeMappings(options: Neo4jOptions): Map[DataType, String] = {
172+
var result = baseMappings
173+
if (options.legacyTypeConversionEnabled) {
174+
result += (DataTypes.TimestampType -> "LOCAL DATETIME")
175+
result += (DataTypes.createArrayType(DataTypes.TimestampType, false) -> "LIST<LOCAL DATETIME NOT NULL>")
176+
result += (DataTypes.createArrayType(DataTypes.TimestampType, true) -> "LIST<LOCAL DATETIME NOT NULL>")
177+
} else {
178+
result += (DataTypes.TimestampType -> "ZONED DATETIME")
179+
result += (DataTypes.TimestampNTZType -> "LOCAL DATETIME")
180+
result += (DayTimeIntervalType() -> "DURATION")
181+
result += (YearMonthIntervalType() -> "DURATION")
182+
result += (DataTypes.createArrayType(DataTypes.ByteType, false) -> "ByteArray")
183+
result += (DataTypes.createArrayType(DataTypes.TimestampType, false) -> "LIST<ZONED DATETIME NOT NULL>")
184+
result += (DataTypes.createArrayType(DataTypes.TimestampNTZType, false) -> "LIST<LOCAL DATETIME NOT NULL>")
185+
result += (DataTypes.createArrayType(DayTimeIntervalType(), false) -> "LIST<DURATION NOT NULL>")
186+
result += (DataTypes.createArrayType(DayTimeIntervalType(), true) -> "LIST<DURATION NOT NULL>")
187+
result += (DataTypes.createArrayType(YearMonthIntervalType(), false) -> "LIST<DURATION NOT NULL>")
188+
result += (DataTypes.createArrayType(YearMonthIntervalType(), true) -> "LIST<DURATION NOT NULL>")
189+
}
190+
result
191+
}
165192
}
166193

167-
class SparkToCypherTypeConverter extends TypeConverter[DataType, String] {
168-
override def convert(sourceType: DataType, value: Any): String = mapping(sourceType)
194+
class SparkToCypherTypeConverter(options: Neo4jOptions) extends TypeConverter[DataType, String] {
195+
override def convert(sourceType: DataType, value: Any): String = mapping(sourceType, options)
169196
}

common/src/main/scala/org/neo4j/spark/service/MappingService.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class Neo4jWriteMappingStrategy(private val options: Neo4jOptions)
4646
extends Neo4jMappingStrategy[InternalRow, Option[java.util.Map[String, AnyRef]]]
4747
with Logging {
4848

49-
private val dataConverter = SparkToNeo4jDataConverter()
49+
private val dataConverter = SparkToNeo4jDataConverter(options)
5050

5151
override def node(row: InternalRow, schema: StructType): Option[java.util.Map[String, AnyRef]] = {
5252
val rowMap: java.util.Map[String, Object] = new java.util.HashMap[String, Object]
@@ -212,7 +212,7 @@ class Neo4jWriteMappingStrategy(private val options: Neo4jOptions)
212212
class Neo4jReadMappingStrategy(private val options: Neo4jOptions, requiredColumns: StructType)
213213
extends Neo4jMappingStrategy[Record, InternalRow] {
214214

215-
private val dataConverter = Neo4jToSparkDataConverter()
215+
private val dataConverter = Neo4jToSparkDataConverter(options)
216216

217217
override def node(record: Record, schema: StructType): InternalRow = {
218218
if (requiredColumns.nonEmpty) {

common/src/main/scala/org/neo4j/spark/service/SchemaService.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,9 @@ class SchemaService(
6969

7070
private val sessionTransactionConfig = options.toNeo4jTransactionConfig
7171

72-
private val cypherToSparkTypeConverter = CypherToSparkTypeConverter()
72+
private val cypherToSparkTypeConverter = CypherToSparkTypeConverter(options)
7373

74-
private val sparkToCypherTypeConverter = SparkToCypherTypeConverter()
74+
private val sparkToCypherTypeConverter = SparkToCypherTypeConverter(options)
7575

7676
private def structForNode(labels: Seq[String] = options.nodeMetadata.labels) = {
7777
val structFields: mutable.Buffer[StructField] = (try {
@@ -158,9 +158,9 @@ class SchemaService(
158158
case SchemaStrategy.SAMPLE => {
159159
val types = t._2.map(value => {
160160
if (options.query.queryType == QueryType.QUERY) {
161-
normalizedClassName(value)
161+
normalizedClassName(value, options)
162162
} else {
163-
normalizedClassNameFromGraphEntity(value)
163+
normalizedClassNameFromGraphEntity(value, options)
164164
}
165165
}).toSet
166166

@@ -1027,8 +1027,8 @@ object SchemaService {
10271027

10281028
val DURATION_TYPE = "duration"
10291029

1030-
def normalizedClassName(value: AnyRef): String = value match {
1031-
case binary: Array[Byte] => "ByteArray"
1030+
def normalizedClassName(value: AnyRef, options: Neo4jOptions): String = value match {
1031+
case binary: Array[Byte] => if (options.legacyTypeConversionEnabled) value.getClass.getSimpleName else "ByteArray"
10321032
case list: java.util.List[_] => "Array"
10331033
case map: java.util.Map[String, _] => "Map"
10341034
case null => "String"
@@ -1037,8 +1037,8 @@ object SchemaService {
10371037

10381038
// from nodes and relationships we cannot have maps as properties and elements in lists are the same type
10391039
// special treatment for ByteArray required (pattern matching on Array != List)
1040-
def normalizedClassNameFromGraphEntity(value: AnyRef): String = value match {
1041-
case binary: Array[Byte] => "ByteArray"
1040+
def normalizedClassNameFromGraphEntity(value: AnyRef, options: Neo4jOptions): String = value match {
1041+
case binary: Array[Byte] => if (options.legacyTypeConversionEnabled) value.getClass.getSimpleName else "ByteArray"
10421042
case list: java.util.List[_] => s"${list.get(0).getClass.getSimpleName}Array"
10431043
case null => "String"
10441044
case _ => value.getClass.getSimpleName

0 commit comments

Comments
 (0)