
Commit 01a756f

Merge pull request #156 from caraml-dev/support-array-datatype
feat: add array datatype on MaxCompute reader
2 parents: 1fcaccb + 5591b09


6 files changed: +132 −42 lines


caraml-store-pyspark/scripts/historical_feature_retrieval_job.py

Lines changed: 46 additions & 30 deletions
```diff
@@ -358,29 +358,6 @@ class Field(NamedTuple):
     name: str
     type: str
 
-    @property
-    def spark_type(self):
-        """
-        Returns Spark data type that corresponds to the field's Feast type
-        """
-        feast_to_spark_type_mapping = {
-            "bytes": "binary",
-            "string": "string",
-            "int32": "int",
-            "int64": "bigint",
-            "double": "double",
-            "float": "float",
-            "bool": "boolean",
-            "bytes_list": "array<binary>",
-            "string_list": "array<string>",
-            "int32_list": "array<int>",
-            "int64_list": "array<bigint>",
-            "double_list": "array<double>",
-            "float_list": "array<float>",
-            "bool_list": "array<boolean>",
-        }
-        return feast_to_spark_type_mapping[self.type.lower()]
-
 
 class FeatureTable(NamedTuple):
     """
@@ -463,6 +440,43 @@ def entity_selections(self):
         return [f"{self.field_mapping.get(entity, entity)} as {entity}" for entity in self.entities]
 
 
+def _spark_type(field: Field, source: Source) -> str:
+    if isinstance(source, MaxComputeSource):
+        return {
+            "bytes": "tinyint",
+            "string": "string",
+            "int32": "int",
+            "int64": "bigint",
+            "double": "double",
+            "float": "float",
+            "bool": "boolean",
+            "bytes_list": "array<tinyint>",
+            "string_list": "array<string>",
+            "int32_list": "array<int>",
+            "int64_list": "array<bigint>",
+            "double_list": "array<double>",
+            "float_list": "array<float>",
+            "bool_list": "array<boolean>",
+        }[field.type.lower()]
+    else:
+        return {
+            "bytes": "binary",
+            "string": "string",
+            "int32": "int",
+            "int64": "bigint",
+            "double": "double",
+            "float": "float",
+            "bool": "boolean",
+            "bytes_list": "array<binary>",
+            "string_list": "array<string>",
+            "int32_list": "array<int>",
+            "int64_list": "array<bigint>",
+            "double_list": "array<double>",
+            "float_list": "array<float>",
+            "bool_list": "array<boolean>",
+        }[field.type.lower()]
+
+
 def _map_column(df: DataFrame, col_mapping: Dict[str, str]):
     source_to_alias_map = {v: k for k, v in col_mapping.items()}
     projection = {}
@@ -820,15 +834,16 @@ def _read_and_verify_feature_table_df_from_source(
     feature_table_dtypes = dict(mapped_source_df.dtypes)
     for field in feature_table.entities + feature_table.features:
         column_type = feature_table_dtypes.get(field.name)
+        spark_type = _spark_type(field, source)
 
-        if column_type != field.spark_type:
-            if _type_casting_allowed(field.spark_type, column_type):
+        if column_type != spark_type:
+            if _type_casting_allowed(spark_type, column_type):
                 mapped_source_df = mapped_source_df.withColumn(
-                    field.name, col(field.name).cast(field.spark_type)
+                    field.name, col(field.name).cast(spark_type)
                 )
             else:
                 raise SchemaError(
-                    f"{field.name} should be of {field.spark_type} type, but is {column_type} instead"
+                    f"{field.name} should be of {spark_type} type, but is {column_type} instead"
                 )
 
     for timestamp_column in [
@@ -916,15 +931,16 @@ def retrieve_historical_features(
         for feature_table, source in zip(feature_tables, feature_tables_sources)
     ]
 
-    expected_entities = []
+    expected_entities: List[Field] = []
    for feature_table in feature_tables:
        expected_entities.extend(feature_table.entities)
 
     entity_dtypes = dict(entity_df.dtypes)
     for expected_entity in expected_entities:
-        if entity_dtypes.get(expected_entity.name) != expected_entity.spark_type:
+        spark_type = _spark_type(expected_entity, entity_source)
+        if entity_dtypes.get(expected_entity.name) != spark_type:
             raise SchemaError(
-                f"{expected_entity.name} ({expected_entity.spark_type}) is not present in the entity dataframe."
+                f"{expected_entity.name} ({spark_type}) is not present in the entity dataframe."
             )
 
     entity_df.cache()
```

caraml-store-spark/build.gradle

Lines changed: 6 additions & 3 deletions
```diff
@@ -60,11 +60,14 @@ dependencies {
     testImplementation 'com.github.tomakehurst:wiremock-jre8:2.26.3'
     testImplementation "com.dimafeng:testcontainers-scala-kafka_$scalaVersion:0.40.12"
     testRuntimeOnly 'com.vladsch.flexmark:flexmark-all:0.35.10'
-    implementation files('./prebuilt-jars/custom-dialect.jar')
-    compileOnly('com.aliyun.odps:odps-jdbc:3.8.2') {
+    compileOnly('com.aliyun.odps:odps-jdbc:3.10.1') {
         exclude group: 'org.antlr', module: 'antlr4-runtime'
     }
 
+    // Fixes a unit-test error: the tests previously passed because aliyun-odps-jdbc 3.8.2
+    // depends on odps-sdk-core 0.51.5, which in turn pulls in jackson-databind 2.15.2,
+    // so jackson-databind now has to be declared explicitly for the test classpath.
+    testImplementation "com.fasterxml.jackson.core:jackson-databind:2.15.2"
 }
 application {
     mainClassName = 'dev.caraml.spark.IngestionJob'
@@ -87,7 +90,7 @@ def containerRegistry = System.getenv('DOCKER_REGISTRY')
 docker {
     dependsOn shadowJar
     dockerfile file('docker/Dockerfile')
-    files shadowJar.outputs, "$rootDir/caraml-store-pyspark/scripts", "$rootDir/caraml-store-spark/prebuilt-jars/custom-dialect.jar"
+    files shadowJar.outputs, "$rootDir/caraml-store-pyspark/scripts"
     copySpec.with {
         from("$rootDir/caraml-store-pyspark") {
             include 'templates/**'
```

caraml-store-spark/docker/Dockerfile

Lines changed: 3 additions & 3 deletions
```diff
@@ -2,15 +2,16 @@ FROM --platform=linux/amd64 apache/spark-py:v3.1.3
 
 ARG GCS_CONNECTOR_VERSION=2.2.5
 ARG BQ_CONNECTOR_VERSION=0.27.1
-ARG ODPS_JDBC_CONNECTOR=3.8.2
+ARG ODPS_JDBC_CONNECTOR=3.10.1
 ARG HADOOP_ALIYUN_VERSION=3.2.0
 ARG ALIYUN_SDK_OSS_VERSION=2.8.3
 ARG JDOM_VERSION=1.1
 
 USER root
 ADD https://storage.googleapis.com/hadoop-lib/gcs/gcs-connector-hadoop2-${GCS_CONNECTOR_VERSION}.jar /opt/spark/jars
 ADD https://repo1.maven.org/maven2/com/google/cloud/spark/spark-bigquery-with-dependencies_2.12/${BQ_CONNECTOR_VERSION}/spark-bigquery-with-dependencies_2.12-${BQ_CONNECTOR_VERSION}.jar /opt/spark/jars
-ADD https://github.com/aliyun/aliyun-odps-jdbc/releases/download/v${ODPS_JDBC_CONNECTOR}/odps-jdbc-${ODPS_JDBC_CONNECTOR}-jar-with-dependencies.jar /opt/spark/jars
+# aliyun odps jdbc with dependencies
+ADD https://github.com/aliyun/aliyun-odps-jdbc/releases/download/v${ODPS_JDBC_CONNECTOR}/odps-jdbc-${ODPS_JDBC_CONNECTOR}.jar /opt/spark/jars
 ADD https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-aliyun/${HADOOP_ALIYUN_VERSION}/hadoop-aliyun-${HADOOP_ALIYUN_VERSION}.jar /opt/spark/jars
 ADD https://repo1.maven.org/maven2/com/aliyun/oss/aliyun-sdk-oss/${ALIYUN_SDK_OSS_VERSION}/aliyun-sdk-oss-${ALIYUN_SDK_OSS_VERSION}.jar /opt/spark/jars
 ADD https://repo1.maven.org/maven2/org/jdom/jdom/${JDOM_VERSION}/jdom-${JDOM_VERSION}.jar /opt/spark/jars
@@ -20,6 +21,5 @@ RUN pip install Jinja2==3.1.2
 RUN mkdir -p /dev
 
 ADD caraml-spark-application-with-dependencies.jar /opt/spark/jars
-ADD custom-dialect.jar /opt/spark/jars
 ADD templates /opt/spark/work-dir/
 ADD historical_feature_retrieval_job.py /opt/spark/work-dir
```
Binary file changed (−971 bytes); contents not shown.

caraml-store-spark/src/main/scala/dev/caraml/spark/odps/CustomDialect.scala

Lines changed: 75 additions & 4 deletions
```diff
@@ -1,9 +1,7 @@
 package dev.caraml.spark.odps
-import org.apache.spark.sql.jdbc.JdbcDialect
-import org.apache.spark.sql.jdbc.JdbcType
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType}
 import org.apache.spark.sql.types._
-import org.apache.hbase.thirdparty.org.eclipse.jetty.util.ajax.JSON
+import java.sql.SQLException
 
 class CustomDialect extends JdbcDialect {
   override def canHandle(url: String): Boolean = {
@@ -14,6 +12,79 @@ class CustomDialect extends JdbcDialect {
     s"$colName"
   }
 
+  /*
+    TODO: currently unsupported types
+    - ARRAY<DECIMAL(precision,scale)>
+    - ARRAY<VARCHAR(n)>  --> temporarily mapped to string
+    - ARRAY<CHAR(n)>     --> temporarily mapped to string
+    - ARRAY<DATE>
+    - ARRAY<DATETIME>
+    - ARRAY<TIMESTAMP>
+    - ARRAY<TIMESTAMP_NTZ>
+    - ARRAY<INTERVAL>
+
+    The type names below were obtained from https://www.alibabacloud.com/help/en/maxcompute/user-guide/maxcompute-v2-0-data-type-edition
+  */
+  private def getCommonCatalystType(typeName: String): Option[DataType] = {
+    typeName.toUpperCase() match {
+      case "TINYINT"  => Option(ByteType)
+      case "SMALLINT" => Option(ShortType)
+      case "INT"      => Option(IntegerType)
+      case "BIGINT"   => Option(LongType)
+      case "BINARY"   => Option(BinaryType)
+      case "FLOAT"    => Option(FloatType)
+      case "DOUBLE"   => Option(DoubleType)
+      // case s if s.startsWith("DECIMAL") =>
+      //   val mdat = s.stripPrefix("DECIMAL(").stripSuffix(")").split(",")
+      //   if (mdat.length == 2) {
+      //     val precision = mdat(0).toInt
+      //     val scale = mdat(1).toInt
+      //     Option(DecimalType(min(precision, DecimalType.MAX_PRECISION), min(scale, DecimalType.MAX_SCALE)))
+      //   } else {
+      //     Option(DecimalType.SYSTEM_DEFAULT)
+      //   }
+      // case s if s.startsWith("VARCHAR") => Option(VarcharType(s.stripPrefix("VARCHAR(").stripSuffix(")").toInt))
+      // case s if s.startsWith("CHAR") => Option(CharType(s.stripPrefix("CHAR(").stripSuffix(")").toInt))
+      case s if s.startsWith("VARCHAR") => Option(StringType)
+      case s if s.startsWith("CHAR")    => Option(StringType)
+      case "STRING" => Option(StringType)
+      // case "DATE" => Option(DateType)
+      // case "DATETIME" => Option(TimestampType)
+      // case "TIMESTAMP" => Option(TimestampType)
+      // case "TIMESTAMP_NTZ" => Option(TimestampType)
+      case "BOOLEAN" => Option(BooleanType)
+      // case "INTERVAL" => Option(CalendarIntervalType)
+      case _ => None
+    }
+  }
+
+  override def getCatalystType(
+      sqlType: Int,
+      typeName: String,
+      size: Int,
+      md: MetadataBuilder
+  ): Option[DataType] = {
+    sqlType match {
+      case java.sql.Types.ARRAY =>
+        val elementTypeName = typeName.toUpperCase().stripPrefix("ARRAY<").stripSuffix(">")
+        val elementType     = getCommonCatalystType(elementTypeName).map(ArrayType(_))
+
+        if (elementType.isEmpty) {
+          throw new SQLException(s"Unsupported type $typeName")
+        }
+        logDebug(
+          s"CustomDialect sqlType: $sqlType md: ${md.build().toString()} size: $size typeName: $typeName elementType: ${elementType.getOrElse(ArrayType(NullType)).elementType}"
+        )
+        elementType
+      case _ =>
+        val dataType = getCommonCatalystType(typeName.toUpperCase())
+        logDebug(
+          s"CustomDialect sqlType: $sqlType md: ${md.build().toString()} size: $size typeName: $typeName dataType: $dataType"
+        )
+        dataType
+    }
+  }
+
   override def getJDBCType(dt: DataType): Option[JdbcType] = {
     dt match {
       case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER))
```
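
As a quick illustration of what the new override does, here is a minimal sketch that probes the dialect directly. The object name and assertions are illustrative only and not part of the commit; it assumes only the `CustomDialect` class and the Spark SQL types shown in the diff above.

```scala
import java.sql.Types

import org.apache.spark.sql.types.{ArrayType, LongType, MetadataBuilder, StringType}

import dev.caraml.spark.odps.CustomDialect

// Minimal sketch (not from the commit): exercises both branches of
// CustomDialect.getCatalystType introduced in this diff.
object CustomDialectSketch {
  def main(args: Array[String]): Unit = {
    val dialect = new CustomDialect()

    // A MaxCompute ARRAY<STRING> column should surface in Spark as array<string>.
    val arrayType =
      dialect.getCatalystType(Types.ARRAY, "ARRAY<STRING>", 0, new MetadataBuilder())
    assert(arrayType.contains(ArrayType(StringType)))

    // Scalar type names fall through to getCommonCatalystType, e.g. BIGINT -> LongType.
    val scalarType =
      dialect.getCatalystType(Types.BIGINT, "BIGINT", 0, new MetadataBuilder())
    assert(scalarType.contains(LongType))

    // Unsupported array element types (e.g. ARRAY<DATE>) throw SQLException,
    // per the TODO list in the dialect.
  }
}
```

Note that unsupported array element types fail fast with an `SQLException` instead of silently returning `None`, so schema gaps surface at read time rather than as downstream casting errors.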

caraml-store-spark/src/main/scala/dev/caraml/spark/sources/maxCompute/MaxComputeReader.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,10 +1,10 @@
 package dev.caraml.spark.sources.maxCompute
 
-import dev.caraml.spark.{MaxComputeSource, MaxComputeConfig}
+import dev.caraml.spark.{MaxComputeConfig, MaxComputeSource}
 import org.joda.time.DateTime
 import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.jdbc.JdbcDialects
-import com.caraml.odps.CustomDialect
+import dev.caraml.spark.odps.CustomDialect
 import org.apache.log4j.Logger
 
 object MaxComputeReader {
```
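
The import change above points the reader at the in-repo `dev.caraml.spark.odps.CustomDialect` instead of the `com.caraml.odps` package from the removed prebuilt jar. Since the reader body is not shown in this diff, the following is a hedged sketch of how such a dialect is typically registered ahead of a JDBC read; the JDBC URL format, driver class, option names, and environment variables are assumptions for illustration, not taken from `MaxComputeReader`.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.jdbc.JdbcDialects

import dev.caraml.spark.odps.CustomDialect

// Illustrative sketch only: endpoint, project, table, and credential handling
// are placeholders, not the actual MaxComputeReader implementation.
object MaxComputeReadSketch {
  def readTable(spark: SparkSession, project: String, table: String): DataFrame = {
    // Register the custom dialect so ARRAY<...> columns map to Spark ArrayType
    // instead of failing as unsupported JDBC types.
    JdbcDialects.registerDialect(new CustomDialect())

    spark.read
      .format("jdbc")
      .option("url", s"jdbc:odps:https://service.example-region.maxcompute.aliyun.com/api?project=$project")
      .option("driver", "com.aliyun.odps.jdbc.OdpsDriver")
      .option("dbtable", table)
      .option("user", sys.env.getOrElse("ODPS_ACCESS_ID", ""))
      .option("password", sys.env.getOrElse("ODPS_ACCESS_KEY", ""))
      .load()
  }
}
```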

0 commit comments