@@ -7,7 +7,7 @@ plugins {

airbyteBulkConnector {
core = "load"
toolkits = listOf("load-db")
toolkits = listOf("load-db", "load-avro")
}

tasks.withType<JavaCompile>().configureEach {
@@ -45,6 +45,9 @@ dependencies {
implementation("com.google.guava:guava:32.1.1-jre")
implementation("de.siegmar:fastcsv:4.0.0")
implementation("io.micronaut.cache:micronaut-cache-caffeine:4.3.1")
implementation("org.apache.parquet:parquet-avro:1.16.0")
implementation("org.apache.avro:avro:1.12.0")
implementation("org.xerial.snappy:snappy-java:1.1.10.8")

testImplementation("io.mockk:mockk:1.14.5")
testImplementation("org.junit.jupiter:junit-jupiter-api:$junitVersion")
@@ -6,10 +6,10 @@ package io.airbyte.integrations.destination.snowflake.dataflow

import io.airbyte.cdk.load.dataflow.aggregate.Aggregate
import io.airbyte.cdk.load.dataflow.transform.RecordDTO
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeInsertBuffer
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeParquetInsertBuffer

class SnowflakeAggregate(
private val buffer: SnowflakeInsertBuffer,
private val buffer: SnowflakeParquetInsertBuffer,
) : Aggregate {
override fun accept(record: RecordDTO) {
buffer.accumulate(record.fields)
@@ -13,7 +13,7 @@ import io.airbyte.cdk.load.write.StreamStateStore
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeInsertBuffer
import io.airbyte.integrations.destination.snowflake.write.load.SnowflakeParquetInsertBuffer
import io.micronaut.cache.annotation.CacheConfig
import io.micronaut.cache.annotation.Cacheable
import jakarta.inject.Singleton
@@ -30,7 +30,7 @@ open class SnowflakeAggregateFactory(
override fun create(key: StoreKey): Aggregate {
val tableName = streamStateStore.get(key)!!.tableName
val buffer =
SnowflakeInsertBuffer(
SnowflakeParquetInsertBuffer(
tableName = tableName,
columns = getTableColumns(tableName),
snowflakeClient = snowflakeClient,
@@ -39,6 +39,8 @@ import kotlin.collections.component2
import kotlin.collections.joinToString
import kotlin.collections.map
import kotlin.collections.plus
import org.apache.avro.Schema
import org.apache.avro.SchemaBuilder

internal const val NOT_NULL = "NOT NULL"

@@ -54,7 +56,7 @@ internal val DEFAULT_COLUMNS =
),
ColumnAndType(
columnName = COLUMN_NAME_AB_META,
columnType = "${SnowflakeDataType.VARIANT.typeName} $NOT_NULL"
columnType = "${SnowflakeDataType.OBJECT.typeName} $NOT_NULL"
),
ColumnAndType(
columnName = COLUMN_NAME_AB_GENERATION_ID,
@@ -65,7 +67,7 @@
internal val RAW_DATA_COLUMN =
ColumnAndType(
columnName = COLUMN_NAME_DATA,
columnType = "${SnowflakeDataType.VARIANT.typeName} $NOT_NULL"
columnType = "${SnowflakeDataType.OBJECT.typeName} $NOT_NULL"
)

internal val RAW_COLUMNS =
@@ -186,6 +188,16 @@ class SnowflakeColumnUtils(
is UnionType -> SnowflakeDataType.VARIANT.typeName
is UnknownType -> SnowflakeDataType.VARIANT.typeName
}

fun toAvroType(snowflakeType: String): Schema =
when (SnowflakeDataType.valueOf(snowflakeType)) {
SnowflakeDataType.BOOLEAN -> SchemaBuilder.builder().booleanType()
SnowflakeDataType.FIXED,
SnowflakeDataType.NUMBER -> SchemaBuilder.builder().longType()
SnowflakeDataType.REAL,
SnowflakeDataType.FLOAT -> SchemaBuilder.builder().doubleType()
else -> SchemaBuilder.builder().stringType()
}
}

data class ColumnAndType(val columnName: String, val columnType: String) {
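
For reference, a minimal sketch of how the new toAvroType mapping behaves, assuming an existing SnowflakeColumnUtils instance and that callers pass the bare enum constant name (e.g. "NUMBER" rather than the rendered "NUMBER(38,0)"):

import org.apache.avro.Schema

// Illustrative checks only; snowflakeColumnUtils is assumed to be an already-constructed instance.
check(snowflakeColumnUtils.toAvroType("BOOLEAN").type == Schema.Type.BOOLEAN)
check(snowflakeColumnUtils.toAvroType("NUMBER").type == Schema.Type.LONG)
check(snowflakeColumnUtils.toAvroType("REAL").type == Schema.Type.DOUBLE)
check(snowflakeColumnUtils.toAvroType("VARCHAR").type == Schema.Type.STRING) // falls through to the else branch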
@@ -12,9 +12,12 @@ enum class SnowflakeDataType(val typeName: String) {
// Numeric types
NUMBER("NUMBER(38,0)"),
FLOAT("FLOAT"),
FIXED("FIXED"),
REAL("REAL"),

// String & binary types
VARCHAR("VARCHAR"),
TEXT("TEXT"),

// Boolean type
BOOLEAN("BOOLEAN"),
@@ -15,8 +15,6 @@ import io.airbyte.integrations.destination.snowflake.db.ColumnDefinition
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.CdcDeletionMode
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.write.load.CSV_FIELD_SEPARATOR
import io.airbyte.integrations.destination.snowflake.write.load.CSV_LINE_DELIMITER
import io.github.oshai.kotlinlogging.KotlinLogging
import jakarta.inject.Singleton

@@ -325,8 +323,6 @@ class SnowflakeDirectLoadSqlGenerator(
val stageName = snowflakeSqlNameUtils.fullyQualifiedStageName(tableName, true)
return """
PUT 'file://$tempFilePath' '@$stageName'
AUTO_COMPRESS = FALSE
SOURCE_COMPRESSION = GZIP
OVERWRITE = TRUE
"""
.trimIndent()
@@ -340,18 +336,13 @@
COPY INTO ${snowflakeSqlNameUtils.fullyQualifiedName(tableName)}
FROM '@$stageName'
FILE_FORMAT = (
TYPE = 'CSV'
COMPRESSION = GZIP
FIELD_DELIMITER = '$CSV_FIELD_SEPARATOR'
RECORD_DELIMITER = '$CSV_LINE_DELIMITER'
FIELD_OPTIONALLY_ENCLOSED_BY = '"'
TRIM_SPACE = TRUE
ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE
TYPE = 'PARQUET'
SNAPPY_COMPRESSION = TRUE
USE_VECTORIZED_SCANNER = TRUE
REPLACE_INVALID_CHARACTERS = TRUE
ESCAPE = NONE
ESCAPE_UNENCLOSED_FIELD = NONE
)
ON_ERROR = 'ABORT_STATEMENT'
MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE
PURGE = TRUE
files = ('$filename')
"""
@@ -0,0 +1,136 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.integrations.destination.snowflake.write.load

import com.google.common.annotations.VisibleForTesting
import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.data.Transformations
import io.airbyte.cdk.load.orchestration.db.TableName
import io.airbyte.integrations.destination.snowflake.client.SnowflakeAirbyteClient
import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleName
import io.airbyte.integrations.destination.snowflake.spec.SnowflakeConfiguration
import io.airbyte.integrations.destination.snowflake.sql.QUOTE
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils
import io.airbyte.protocol.models.Jsons
import io.github.oshai.kotlinlogging.KotlinLogging
import java.nio.file.Path
import kotlin.io.path.deleteIfExists
import kotlin.io.path.pathString
import org.apache.avro.Schema
import org.apache.avro.generic.GenericData
import org.apache.avro.generic.GenericRecord
import org.apache.hadoop.conf.Configuration
import org.apache.parquet.avro.AvroParquetWriter
import org.apache.parquet.hadoop.ParquetWriter
import org.apache.parquet.hadoop.metadata.CompressionCodecName
import org.apache.parquet.io.LocalOutputFile

private val logger = KotlinLogging.logger {}

internal const val PARQUET_FILE_PREFIX = "snowflake"
internal const val PARQUET_FILE_SUFFIX = ".parquet"

class SnowflakeParquetInsertBuffer(
private val tableName: TableName,
val columns: LinkedHashMap<String, String>,
private val snowflakeClient: SnowflakeAirbyteClient,
val snowflakeConfiguration: SnowflakeConfiguration,
private val snowflakeColumnUtils: SnowflakeColumnUtils,
) {

@VisibleForTesting internal var parquetFilePath: Path? = null
@VisibleForTesting internal var recordCount = 0
private var writer: ParquetWriter<GenericRecord>? = null
private var schema: Schema? = null

private val snowflakeRecordFormatter: SnowflakeRecordFormatter =
when (snowflakeConfiguration.legacyRawTablesOnly) {
true -> SnowflakeRawRecordFormatter(columns, snowflakeColumnUtils)
else -> SnowflakeParquetRecordFormatter(columns, snowflakeColumnUtils)
}

fun accumulate(recordFields: Map<String, AirbyteValue>) {
if (parquetFilePath == null) {
parquetFilePath =
Path.of(
System.getProperty("java.io.tmpdir"),
"$PARQUET_FILE_PREFIX${System.currentTimeMillis()}$PARQUET_FILE_SUFFIX"
)
schema = buildSchema()
writer = buildWriter(schema = schema!!, path = parquetFilePath!!)
}

val record = createRecord(recordFields)
writer?.let { w ->
w.write(record)
recordCount++
}
}

suspend fun flush() {
parquetFilePath?.let { filePath ->
try {
writer?.close()
logger.info { "Beginning insert into ${tableName.toPrettyString(quote = QUOTE)}" }
// Next, put the Parquet file into the stage
snowflakeClient.putInStage(tableName, filePath.pathString)
// Finally, copy the data from the stage into the final table
snowflakeClient.copyFromStage(tableName, filePath.fileName.toString())
logger.info {
"Finished insert of $recordCount row(s) into ${tableName.toPrettyString(quote = QUOTE)}"
}
} catch (e: Exception) {
logger.error(e) { "Unable to flush accumulated data." }
throw e
} finally {
filePath.deleteIfExists()
writer = null
recordCount = 0
}
}
}

private fun buildWriter(schema: Schema, path: Path): ParquetWriter<GenericRecord> =
AvroParquetWriter.builder<GenericRecord>(LocalOutputFile(path))
.withSchema(schema)
.withConf(Configuration())
.withCompressionCodec(CompressionCodecName.SNAPPY)
.build()

private fun buildSchema(): Schema {
val schema = mutableMapOf<String, Any>()
schema["type"] = "record"
schema["name"] = Transformations.toAvroSafeName(tableName.name)
schema["fields"] =
columns.map { (key, value) ->
if (value.equals("VARIANT", true)) {
mapOf(
"name" to Transformations.toAlphanumericAndUnderscore(key),
"type" to
mapOf(
"type" to snowflakeColumnUtils.toAvroType(value).name,
"logicalType" to "variant"
),
)
} else {
mapOf(
"name" to Transformations.toAlphanumericAndUnderscore(key),
"type" to listOf(snowflakeColumnUtils.toAvroType(value).name, "null"),
)
}
}

return Schema.Parser().parse(Jsons.serialize(schema))
}

private fun createRecord(recordFields: Map<String, AirbyteValue>): GenericRecord {
val record = GenericData.Record(schema)
val recordValues = snowflakeRecordFormatter.format(recordFields)
recordValues.forEachIndexed { index, value ->
record.put(columns.keys.toList()[index].toSnowflakeCompatibleName(), value)
}
return record
}
}
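
For illustration, a sketch of the JSON that buildSchema serializes and hands to Schema.Parser for a hypothetical table "events" with a VARIANT column "payload" and a NUMBER column "event_id" (names are invented here, and the column types are assumed to arrive as bare type names):

// Hypothetical output of buildSchema for the table and columns described above.
val exampleSchemaJson =
    """
    {
      "type": "record",
      "name": "events",
      "fields": [
        { "name": "payload", "type": { "type": "string", "logicalType": "variant" } },
        { "name": "event_id", "type": ["long", "null"] }
      ]
    }
    """.trimIndent()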
@@ -5,8 +5,18 @@
package io.airbyte.integrations.destination.snowflake.write.load

import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.data.ArrayValue
import io.airbyte.cdk.load.data.BooleanValue
import io.airbyte.cdk.load.data.DateValue
import io.airbyte.cdk.load.data.IntegerValue
import io.airbyte.cdk.load.data.NullValue
import io.airbyte.cdk.load.data.NumberValue
import io.airbyte.cdk.load.data.ObjectValue
import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.data.TimeWithTimezoneValue
import io.airbyte.cdk.load.data.TimeWithoutTimezoneValue
import io.airbyte.cdk.load.data.TimestampWithTimezoneValue
import io.airbyte.cdk.load.data.TimestampWithoutTimezoneValue
import io.airbyte.cdk.load.data.csv.toCsvValue
import io.airbyte.cdk.load.message.Meta
import io.airbyte.cdk.load.util.Jsons
@@ -15,7 +25,51 @@ import io.airbyte.integrations.destination.snowflake.db.toSnowflakeCompatibleNam
import io.airbyte.integrations.destination.snowflake.sql.SnowflakeColumnUtils

interface SnowflakeRecordFormatter {
fun format(record: Map<String, AirbyteValue>): List<Any>
fun format(record: Map<String, AirbyteValue>): List<Any?>
}

class SnowflakeParquetRecordFormatter(
private val columns: LinkedHashMap<String, String>,
val snowflakeColumnUtils: SnowflakeColumnUtils,
) : SnowflakeRecordFormatter {

private val airbyteColumnNames =
snowflakeColumnUtils.getFormattedDefaultColumnNames(false).toSet()

override fun format(record: Map<String, AirbyteValue>): List<Any?> =
columns.map { (columnName, _) ->
/*
* Meta columns are forced to uppercase for backwards compatibility with previous
* versions of the destination. Therefore, convert the column to lowercase so
* that it can match the constants, which use the lowercase version of the meta
* column names.
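* For example, a declared "_AIRBYTE_META" column is matched against the record key "_airbyte_meta".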
*/
if (airbyteColumnNames.contains(columnName)) {
convertValue(record[columnName.lowercase()])
} else {
record.keys
.find { it == columnName.toSnowflakeCompatibleName() }
?.let { convertValue(record[it]) }
}
}

private fun convertValue(value: AirbyteValue?) =
value?.let {
when (value) {
is BooleanValue -> value.value
is DateValue -> value.value.toString()
is IntegerValue -> value.value
is NumberValue -> value.value
is TimeWithTimezoneValue -> value.value.toString()
is TimeWithoutTimezoneValue -> value.value.toString()
is TimestampWithoutTimezoneValue -> value.value.toString()
is TimestampWithTimezoneValue -> value.value.toString()
is ObjectValue -> value.serializeToString()
is ArrayValue -> value.serializeToString()
is StringValue -> value.value
is NullValue -> null
}
}
}

class SnowflakeSchemaRecordFormatter(
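
A minimal usage sketch of the new SnowflakeParquetRecordFormatter, assuming columns, snowflakeColumnUtils, and a record's field map recordFields already exist:

val formatter = SnowflakeParquetRecordFormatter(columns, snowflakeColumnUtils)
// Values are returned in declared column order: temporal values rendered via toString(),
// ObjectValue/ArrayValue serialized with serializeToString(), and NullValue as null.
val row: List<Any?> = formatter.format(recordFields)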