diff --git a/src/main/resources/reference-general.conf b/src/main/resources/reference-general.conf index b1df36c3b..225d6789b 100644 --- a/src/main/resources/reference-general.conf +++ b/src/main/resources/reference-general.conf @@ -235,6 +235,8 @@ rowValidatorClass = "ai.starlake.job.validator.FlatRowValidator" env = "" env = ${?SL_ENV} +sqlCaseSensitivity = "default" +sqlCaseSensitivity = ${?SL_SQL_CASE_SENSITIVITY} // "upper" or "lower" or "default" sqlParameterPattern = "\\$\\{\\s*%s\\s*\\}" diff --git a/src/main/resources/templates/dags/load/snowflake__scheduled_table__sql.py.j2 b/src/main/resources/templates/dags/load/snowflake__scheduled_table__sql.py.j2 index 5e6291e5d..99bb85a17 100644 --- a/src/main/resources/templates/dags/load/snowflake__scheduled_table__sql.py.j2 +++ b/src/main/resources/templates/dags/load/snowflake__scheduled_table__sql.py.j2 @@ -1,10 +1,10 @@ # This template executes individual sql jobs and requires the following dag generation options set: # -# - stage_location: the location of the stage in snowflake [REQUIRED] -# - warehouse(COMPUTE_WH): the warehouse to use for the DAG [OPTIONAL], default to COMPUTE_WH +# - stage_location: Where the generated tasks will be stored [REQUIRED] +# - warehouse(COMPUTE_WH): warehouse where the tasks will be hosted [OPTIONAL], default to COMPUTE_WH # - timezone(UTC): the timezone to use for the schedule [OPTIONAL], default to UTC # - packages(croniter,python-dateutil,snowflake-snowpark-python): a list of packages to install before running the task [OPTIONAL], default to croniter,python-dateutil,snowflake-snowpark-python -# - sl_incoming_file_stage: the stage to use for incoming files [OPTIONAL] +# - sl_incoming_file_stage: the stage to use for incoming files in load tasks [REQUIRED] # - sl_env_var: starlake variables specified as a map in json format - at least the root project path SL_ROOT should be specified [OPTIONAL] # - retries(1): the number of retries to attempt before failing the task [OPTIONAL] # - retry_delay(300): the delay between retries in seconds [OPTIONAL] diff --git a/src/main/resources/templates/dags/transform/snowflake__scheduled_task__sql.py.j2 b/src/main/resources/templates/dags/transform/snowflake__scheduled_task__sql.py.j2 index e03c5aff2..8adbf79f4 100644 --- a/src/main/resources/templates/dags/transform/snowflake__scheduled_task__sql.py.j2 +++ b/src/main/resources/templates/dags/transform/snowflake__scheduled_task__sql.py.j2 @@ -1,7 +1,7 @@ # This template executes individual SNOWFLAKE SQL jobs and requires the following dag generation options set: # -# - stage_location: the location of the stage in snowflake [REQUIRED] -# - warehouse(COMPUTE_WH): the warehouse to use for the DAG [OPTIONAL], default to COMPUTE_WH +# - stage_location: Where the generated tasks will be stored [REQUIRED] +# - warehouse(COMPUTE_WH): warehouse where the tasks will be hosted [OPTIONAL], default to COMPUTE_WH # - timezone(UTC): the timezone to use for the schedule [OPTIONAL], default to UTC # - packages(croniter,python-dateutil,snowflake-snowpark-python): a list of packages to install before running the task [OPTIONAL], default to croniter,python-dateutil,snowflake-snowpark-python # - sl_env_var: starlake variables specified as a map in json format - at least the root project path SL_ROOT should be specified [OPTIONAL] diff --git a/src/main/scala/ai/starlake/config/Settings.scala b/src/main/scala/ai/starlake/config/Settings.scala index 190eeb5b4..8a44d67aa 100644 --- a/src/main/scala/ai/starlake/config/Settings.scala +++ 
b/src/main/scala/ai/starlake/config/Settings.scala @@ -787,7 +787,8 @@ object Settings extends StrictLogging { maxInteractiveRecords: Int, duckdbPath: Option[String], ack: Option[String], - duckDbEnableExternalAccess: Boolean + duckDbEnableExternalAccess: Boolean, + sqlCaseSensitivity: String // createTableIfNotExists: Boolean ) extends Serializable { @@ -1001,7 +1002,7 @@ object Settings extends StrictLogging { if (this.connections.isEmpty) s"connectionRef must be defined. Define a connection first and set it to this newly defined connection" else - s"connectionRef must be defined. Valid connection names are $validConnectionNames" + s"connectionRef resolves to an empty value. It must be defined. Valid connection names are $validConnectionNames" errors = errors :+ ValidationMessage(Severity.Error, "AppConfig", msg) } else { this.connections.get(this.connectionRef) match { diff --git a/src/main/scala/ai/starlake/extract/ExtractBigQuerySchema.scala b/src/main/scala/ai/starlake/extract/ExtractBigQuerySchema.scala index e5b252335..283950ce3 100644 --- a/src/main/scala/ai/starlake/extract/ExtractBigQuerySchema.scala +++ b/src/main/scala/ai/starlake/extract/ExtractBigQuerySchema.scala @@ -56,7 +56,6 @@ class ExtractBigQuerySchema(config: BigQueryTablesConfig)(implicit settings: Set tablesToExtract: Map[String, List[String]] ): List[Domain] = { val datasetNames = tablesToExtract.keys.toList - val lowercaseDatasetNames = tablesToExtract.keys.map(_.toLowerCase()).toList val filteredDatasets = if (datasetNames.size == 1) { // We optimize extraction for a single dataset @@ -69,10 +68,10 @@ class ExtractBigQuerySchema(config: BigQueryTablesConfig)(implicit settings: Set .iterateAll() .asScala .filter(ds => - datasetNames.isEmpty || lowercaseDatasetNames.contains( - ds.getDatasetId.getDataset().toLowerCase() - ) + datasetNames.isEmpty || tablesToExtract.keys + .exists(_.equalsIgnoreCase(ds.getDatasetId.getDataset)) ) + } filteredDatasets.map { dataset => extractDataset(schemaHandler, dataset) @@ -98,21 +97,23 @@ class ExtractBigQuerySchema(config: BigQueryTablesConfig)(implicit settings: Set val tables = schemaHandler.domains().find(_.finalName.equalsIgnoreCase(datasetName)) match { case Some(domain) => - val tablesToExclude = domain.tables.map(_.finalName.toLowerCase()) - allTables.filterNot(t => tablesToExclude.contains(t.getTableId.getTable().toLowerCase())) + val tablesToExclude = domain.tables.map(_.finalName) + allTables.filterNot(t => + tablesToExclude.exists(_.equalsIgnoreCase(t.getTableId.getTable)) + ) case None => allTables } val schemas = tables.flatMap { bqTable => - logger.info(s"Extracting table $datasetName.${bqTable.getTableId.getTable()}") + logger.info(s"Extracting table $datasetName.${bqTable.getTableId.getTable}") // We get the Table again below because Tables are returned with a null definition by listTables above. 
- Try(bigquery.getTable(bqTable.getTableId())) match { + Try(bigquery.getTable(bqTable.getTableId)) match { case scala.util.Success(tableWithDefinition) => - if (tableWithDefinition.getDefinition().isInstanceOf[StandardTableDefinition]) + if (tableWithDefinition.getDefinition.isInstanceOf[StandardTableDefinition]) Some(extractTable(tableWithDefinition)) else None case scala.util.Failure(e) => - logger.error(s"Failed to get table ${bqTable.getTableId()}", e) + logger.error(s"Failed to get table ${bqTable.getTableId}", e) None } } diff --git a/src/main/scala/ai/starlake/extract/ExtractDataJob.scala b/src/main/scala/ai/starlake/extract/ExtractDataJob.scala index 992ed1248..c5459a319 100644 --- a/src/main/scala/ai/starlake/extract/ExtractDataJob.scala +++ b/src/main/scala/ai/starlake/extract/ExtractDataJob.scala @@ -78,8 +78,8 @@ class ExtractDataJob(schemaHandler: SchemaHandler) extends ExtractPathHelper wit .filter { s => (config.includeSchemas, config.excludeSchemas) match { case (Nil, Nil) => true - case (inc, Nil) => inc.map(_.toLowerCase).contains(s.schema.toLowerCase) - case (Nil, exc) => !exc.map(_.toLowerCase).contains(s.schema.toLowerCase) + case (inc, Nil) => inc.exists(_.equalsIgnoreCase(s.schema)) + case (Nil, exc) => !exc.exists(_.equalsIgnoreCase(s.schema)) case (_, _) => throw new RuntimeException( "You can't specify includeShemas and excludeSchemas at the same time" diff --git a/src/main/scala/ai/starlake/extract/JdbcDbUtils.scala b/src/main/scala/ai/starlake/extract/JdbcDbUtils.scala index c5ea55823..f94dc2dac 100644 --- a/src/main/scala/ai/starlake/extract/JdbcDbUtils.scala +++ b/src/main/scala/ai/starlake/extract/JdbcDbUtils.scala @@ -7,6 +7,7 @@ import ai.starlake.extract.JdbcDbUtils.{lastExportTableName, Columns} import ai.starlake.job.Main import ai.starlake.schema.model._ import ai.starlake.sql.SQLUtils +import ai.starlake.sql.SQLUtils.sqlCased import ai.starlake.tests.StarlakeTestData.DomainName import ai.starlake.utils.{SparkUtils, Utils} import com.manticore.jsqlformatter.JSQLFormatter @@ -217,7 +218,7 @@ object JdbcDbUtils extends LazyLogging { } @throws[Exception] - def createSchema(conn: SQLConnection, domainName: String): Unit = { + def createSchema(conn: SQLConnection, domainName: String)(implicit settings: Settings): Unit = { executeUpdate(schemaCreateSQL(domainName), conn) match { case Success(_) => case Failure(e) => @@ -227,15 +228,15 @@ object JdbcDbUtils extends LazyLogging { } @throws[Exception] - def schemaCreateSQL(domainName: String): String = { - s"CREATE SCHEMA IF NOT EXISTS $domainName" + def schemaCreateSQL(domainName: String)(implicit settings: Settings): String = { + s"CREATE SCHEMA IF NOT EXISTS ${sqlCased(domainName)}" } - def buildDropTableSQL(tableName: String): String = { - s"DROP TABLE IF EXISTS $tableName" + def buildDropTableSQL(tableName: String)(implicit settings: Settings): String = { + s"DROP TABLE IF EXISTS ${sqlCased(tableName)}" } @throws[Exception] - def dropTable(conn: SQLConnection, tableName: String): Unit = { + def dropTable(conn: SQLConnection, tableName: String)(implicit settings: Settings): Unit = { executeUpdate(buildDropTableSQL(tableName), conn) match { case Success(_) => case Failure(e) => @@ -290,6 +291,34 @@ object JdbcDbUtils extends LazyLogging { stmt.close() result } + def executeQueryAsTable( + query: String, + connection: SQLConnection + ): List[Map[String, String]] = { + val resultTable = ListBuffer[Map[String, String]]() + val statement = connection.createStatement() + try { + // Establish the connection + val 
resultSet = statement.executeQuery(query) + + // Get column names + val metaData = resultSet.getMetaData + val columnCount = metaData.getColumnCount + val columnNames = (1 to columnCount).map(metaData.getColumnName) + + // Process the result set + while (resultSet.next()) { + val row = columnNames + .map(name => name -> Option(resultSet.getObject(name)).map(_.toString).getOrElse("null")) + .toMap + resultTable += row + } + } finally { + statement.close() + } + + resultTable.toList + } def executeQueryAsTable( query: String, @@ -634,10 +663,12 @@ object JdbcDbUtils extends LazyLogging { val jdbcServer = url.split(":")(1) val jdbcEngine = settings.appConfig.jdbcEngines.get(jdbcServer) val jdbcTableMap = - jdbcSchema.tables - .map(tblSchema => tblSchema.name.toUpperCase -> tblSchema) - .toMap - val uppercaseTableNames = jdbcTableMap.keys.toList + CaseInsensitiveMap( + jdbcSchema.tables + .map(tblSchema => tblSchema.name -> tblSchema) + .toMap + ) + val tableNamesToExtract = jdbcTableMap.keys.toList val schemaAndTableNames = withJDBCConnection(readOnlyConnection(connectionSettings).options) { connection => val databaseMetaData = connection.getMetaData() @@ -646,56 +677,53 @@ object JdbcDbUtils extends LazyLogging { databaseMetaData, jdbcSchema.schema ).map { schemaName => - val lowerCasedExcludeTables = jdbcSchema.exclude.map(_.toLowerCase) - def tablesInScopePredicate(tablesToExtract: List[String] = Nil): TableName => Boolean = (tableName: String) => { - !lowerCasedExcludeTables.contains( - tableName.toLowerCase - ) && (tablesToExtract.isEmpty || tablesToExtract.contains(tableName.toUpperCase())) + !jdbcSchema.exclude.exists( + _.equalsIgnoreCase(tableName) + ) && (tablesToExtract.isEmpty || + tablesToExtract.exists(_.equalsIgnoreCase(tableName))) } val sqlDefinedTables = jdbcSchema.tables.filter(_.sql.isDefined).map(_.name) - val selectedTables = uppercaseTableNames match { - case list if list.isEmpty || list.contains("*") => - extractTableNames( - schemaName, - jdbcSchema, - sqlDefinedTables, - tablesInScopePredicate(), - connectionSettings, - databaseMetaData, - skipRemarks, - jdbcEngine, - connection - ) - case list => - val extractedTableNames = + val selectedTables = + tableNamesToExtract match { + case list if list.isEmpty || list.contains("*") => extractTableNames( schemaName, jdbcSchema, sqlDefinedTables, - tablesInScopePredicate(list), + tablesInScopePredicate(), connectionSettings, databaseMetaData, skipRemarks, jdbcEngine, connection ) - val notExtractedTable = list.diff( - extractedTableNames - .map { case (tableName, _) => tableName } - .map(_.toUpperCase()) - .toList - ) - if (notExtractedTable.nonEmpty) { - val tablesNotExtractedStr = notExtractedTable.mkString(", ") - logger.warn( - s"The following tables where not extracted for $schemaName.${jdbcSchema.schema} : $tablesNotExtractedStr" + case list => + val extractedTableNames = + extractTableNames( + schemaName, + jdbcSchema, + sqlDefinedTables, + tablesInScopePredicate(list), + connectionSettings, + databaseMetaData, + skipRemarks, + jdbcEngine, + connection + ) + val notExtractedTable = list.diff( + extractedTableNames.map { case (tableName, _) => tableName }.toList ) - } - extractedTableNames - } + if (notExtractedTable.nonEmpty) { + val tablesNotExtractedStr = notExtractedTable.mkString(", ") + logger.warn( + s"The following tables where not extracted for $schemaName.${jdbcSchema.schema} : $tablesNotExtractedStr" + ) + } + extractedTableNames + } logger.whenDebugEnabled { selectedTables.keys.foreach(table => 
logger.debug(s"Selected: $table")) } @@ -748,36 +776,41 @@ object JdbcDbUtils extends LazyLogging { ) ) val primaryKeys = jdbcColumnMetadata.primaryKeys - val foreignKeys: Map[TableName, TableName] = jdbcColumnMetadata.foreignKeys + val foreignKeys: CaseInsensitiveMap[TableName] = + CaseInsensitiveMap(jdbcColumnMetadata.foreignKeys) val columns: List[Attribute] = jdbcColumnMetadata.columns logger.whenDebugEnabled { columns .foreach(column => logger.debug(s"column: $tableName.${column.name}")) } val jdbcCurrentTable = jdbcTableMap - .get(tableName.toUpperCase) + .get(tableName) // Limit to the columns specified by the user if any - val currentTableRequestedColumns: Map[ColumnName, Option[ColumnName]] = - jdbcCurrentTable - .map( - _.columns.map(c => - (if (keepOriginalName) c.name.toUpperCase.trim - else c.rename.getOrElse(c.name).toUpperCase.trim) -> c.rename - ) - ) - .getOrElse(Map.empty) - .toMap + val currentTableRequestedColumns: CaseInsensitiveMap[Option[ColumnName]] = + CaseInsensitiveMap( + jdbcCurrentTable + .map { + _.columns.map { c => + val key = + if (keepOriginalName) c.name.trim + else c.rename.getOrElse(c.name).trim + key -> c.rename + } + } + .getOrElse(Map.empty) + .toMap + ) val currentFilter = jdbcCurrentTable.flatMap(_.filter) val selectedColumns: List[Attribute] = columns .filter(col => currentTableRequestedColumns.isEmpty || currentTableRequestedColumns - .contains("*") || currentTableRequestedColumns - .contains(col.name.toUpperCase()) + .contains("*") || currentTableRequestedColumns.keys + .exists(_.equalsIgnoreCase(col.name)) ) .map(c => c.copy( - foreignKey = foreignKeys.get(c.name.toUpperCase) + foreignKey = foreignKeys.get(c.name) ) ) logger.whenDebugEnabled { diff --git a/src/main/scala/ai/starlake/extract/JdbcMetadata.scala b/src/main/scala/ai/starlake/extract/JdbcMetadata.scala index f0d2f48b8..49340eb98 100644 --- a/src/main/scala/ai/starlake/extract/JdbcMetadata.scala +++ b/src/main/scala/ai/starlake/extract/JdbcMetadata.scala @@ -12,6 +12,7 @@ import ai.starlake.extract.JdbcDbUtils.{ } import ai.starlake.schema.model.Attribute import com.typesafe.scalalogging.StrictLogging +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import java.sql.{DatabaseMetaData, ResultSetMetaData} import java.sql.Types.{ @@ -55,7 +56,7 @@ sealed trait JdbcColumnMetadata extends StrictLogging { /** @return * a map of foreign key name and its pk composite name */ - def foreignKeys: Map[ColumnName, ColumnName] + def foreignKeys: CaseInsensitiveMap[ColumnName] /** @return * a list of attributes representing resource's columns @@ -237,7 +238,7 @@ final class JdbcColumnDatabaseMetadata( /** @return * a map of foreign key name and its pk composite name */ - override def foreignKeys: Map[String, String] = { + override def foreignKeys: CaseInsensitiveMap[String] = { Try { Using.resource(connectionSettings match { case d if d.isMySQLOrMariaDb() => @@ -263,7 +264,7 @@ final class JdbcColumnDatabaseMetadata( val pkColumnName = foreignKeysResultSet.getString("PKCOLUMN_NAME") val pkFinalColumnName = computeFinalColumnName(pkTableName, pkColumnName) val fkColumnName = - foreignKeysResultSet.getString("FKCOLUMN_NAME").toUpperCase + foreignKeysResultSet.getString("FKCOLUMN_NAME") val pkCompositeName = if (pkSchemaName == null) s"$pkTableName.$pkFinalColumnName" @@ -276,8 +277,8 @@ final class JdbcColumnDatabaseMetadata( } match { case Failure(exception) => logger.warn(s"Could not extract foreign keys for table $tableName") - Map.empty[String, String] - case Success(value) => value + 
CaseInsensitiveMap(Map.empty[String, String]) + case Success(value) => CaseInsensitiveMap(value) } } @@ -326,7 +327,7 @@ final class JdbcColumnDatabaseMetadata( val colRemarks = remarks.getOrElse(colName, columnsResultSet.getString("REMARKS")) val colRequired = columnsResultSet.getString("IS_NULLABLE").equals("NO") - val foreignKey = foreignKeys.get(colName.toUpperCase) + val foreignKey = foreignKeys.get(colName) // val columnPosition = columnsResultSet.getInt("ORDINAL_POSITION") Attribute( name = if (keepOriginalName) colName else finalColName, @@ -362,7 +363,7 @@ class ResultSetColumnMetadata( /** @return * a map of foreign key name and its pk composite name */ - override def foreignKeys: Map[String, String] = Map.empty + override def foreignKeys: CaseInsensitiveMap[String] = CaseInsensitiveMap(Map.empty) /** @return * a list of attributes representing resource's columns diff --git a/src/main/scala/ai/starlake/extract/TemplateParams.scala b/src/main/scala/ai/starlake/extract/TemplateParams.scala index a3ed34845..d7fbb7dea 100644 --- a/src/main/scala/ai/starlake/extract/TemplateParams.scala +++ b/src/main/scala/ai/starlake/extract/TemplateParams.scala @@ -82,7 +82,7 @@ case class TemplateParams( "full_export" -> fullExport, "audit_schema" -> auditDB.getOrElse(domainToExport) ) - ) { case (list, deltaCol) => list :+ ("delta_column" -> deltaCol.toUpperCase) } + ) { case (list, deltaCol) => list :+ ("delta_column" -> deltaCol) } .toMap ++ activeEnv } } diff --git a/src/main/scala/ai/starlake/job/ingest/loaders/BigQueryNativeLoader.scala b/src/main/scala/ai/starlake/job/ingest/loaders/BigQueryNativeLoader.scala index 6b5d665c3..2339cf183 100644 --- a/src/main/scala/ai/starlake/job/ingest/loaders/BigQueryNativeLoader.scala +++ b/src/main/scala/ai/starlake/job/ingest/loaders/BigQueryNativeLoader.scala @@ -166,6 +166,7 @@ class BigQueryNativeLoader(ingestionJob: IngestionJob, accessToken: Option[Strin output } // ignore exception but log it } else { + // One single step load val bigqueryJob = new BigQueryNativeJob(targetConfig, "") bigqueryJob .loadPathsToBQ( diff --git a/src/main/scala/ai/starlake/job/ingest/loaders/DuckDbNativeLoader.scala b/src/main/scala/ai/starlake/job/ingest/loaders/DuckDbNativeLoader.scala index ffba12207..f2c33270e 100644 --- a/src/main/scala/ai/starlake/job/ingest/loaders/DuckDbNativeLoader.scala +++ b/src/main/scala/ai/starlake/job/ingest/loaders/DuckDbNativeLoader.scala @@ -4,39 +4,19 @@ import ai.starlake.config.{CometColumns, Settings} import ai.starlake.extract.JdbcDbUtils import ai.starlake.job.ingest.IngestionJob import ai.starlake.job.transform.JdbcAutoTask -import ai.starlake.schema.handlers.{SchemaHandler, StorageHandler} +import ai.starlake.schema.handlers.StorageHandler import ai.starlake.schema.model._ import ai.starlake.sql.SQLUtils import ai.starlake.utils.{IngestionCounters, SparkUtils} -import com.typesafe.scalalogging.StrictLogging -import com.univocity.parsers.csv.{CsvFormat, CsvParser, CsvParserSettings} import org.apache.hadoop.fs.Path import org.apache.spark.sql.execution.datasources.jdbc.JdbcOptionsInWrite -import java.nio.charset.Charset -import scala.util.{Try, Using} +import scala.util.Try -class DuckDbNativeLoader(ingestionJob: IngestionJob)(implicit - val settings: Settings -) extends StrictLogging { +class DuckDbNativeLoader(ingestionJob: IngestionJob)(implicit settings: Settings) + extends NativeLoader(ingestionJob, None) { - val domain: Domain = ingestionJob.domain - - val schema: Schema = ingestionJob.schema - - val storageHandler: 
StorageHandler = ingestionJob.storageHandler - - val schemaHandler: SchemaHandler = ingestionJob.schemaHandler - - val path: List[Path] = ingestionJob.path - - val options: Map[String, String] = ingestionJob.options - - val strategy: WriteStrategy = ingestionJob.mergedMetadata.getStrategyOptions() - - lazy val mergedMetadata: Metadata = ingestionJob.mergedMetadata - - private def requireTwoSteps(schema: Schema): Boolean = { + override protected def requireTwoSteps(schema: Schema): Boolean = { // renamed attribute can be loaded directly so it's not in the condition schema .hasTransformOrIgnoreOrScriptColumns() || @@ -69,23 +49,23 @@ class DuckDbNativeLoader(ingestionJob: IngestionJob)(implicit } val unionTempTables = tempTables.map("SELECT * FROM " + _).mkString("(", " UNION ALL ", ")") - val targetTableName = s"${domain.finalName}.${schema.finalName}" - val sqlWithTransformedFields = schema.buildSqlSelectOnLoad(unionTempTables) + val targetTableName = s"${domain.finalName}.${starlakeSchema.finalName}" + val sqlWithTransformedFields = starlakeSchema.buildSqlSelectOnLoad(unionTempTables) val taskDesc = AutoTaskDesc( name = targetTableName, sql = Some(sqlWithTransformedFields), database = schemaHandler.getDatabase(domain), domain = domain.finalName, - table = schema.finalName, - presql = schema.presql, - postsql = schema.postsql, + table = starlakeSchema.finalName, + presql = starlakeSchema.presql, + postsql = starlakeSchema.postsql, sink = mergedMetadata.sink, - rls = schema.rls, - expectations = schema.expectations, - acl = schema.acl, - comment = schema.comment, - tags = schema.tags, + rls = starlakeSchema.rls, + expectations = starlakeSchema.expectations, + acl = starlakeSchema.acl, + comment = starlakeSchema.comment, + tags = starlakeSchema.tags, writeStrategy = mergedMetadata.writeStrategy, parseSQL = Some(true), connectionRef = Option(mergedMetadata.getSinkConnectionRef()) @@ -107,7 +87,7 @@ class DuckDbNativeLoader(ingestionJob: IngestionJob)(implicit ) job.run() job.updateJdbcTableSchema( - schema.targetSparkSchemaWithIgnoreAndScript(schemaHandler), + starlakeSchema.targetSparkSchemaWithIgnoreAndScript(schemaHandler), targetTableName ) @@ -118,76 +98,13 @@ class DuckDbNativeLoader(ingestionJob: IngestionJob)(implicit } } } else { - singleStepLoad(domain.finalName, schema.finalName, schemaWithMergedMetadata, path) + singleStepLoad(domain.finalName, starlakeSchema.finalName, schemaWithMergedMetadata, path) } }.map { - => List(IngestionCounters(-1, -1, -1, path.map(_.toString))) } } - private def computeEffectiveInputSchema(): Schema = { - mergedMetadata.resolveFormat() match { - case Format.DSV => - (mergedMetadata.resolveWithHeader(), path.map(_.toString).headOption) match { - case (java.lang.Boolean.TRUE, Some(sourceFile)) => - val csvHeaders = storageHandler.readAndExecute( - new Path(sourceFile), - Charset.forName(mergedMetadata.resolveEncoding()) - ) { is => - Using.resource(is) { reader => - assert( - mergedMetadata.resolveQuote().length <= 1, - "quote must be a single character" - ) - assert( - mergedMetadata.resolveEscape().length <= 1, - "quote must be a single character" - ) - val csvParserSettings = new CsvParserSettings() - val format = new CsvFormat() - format.setDelimiter(mergedMetadata.resolveSeparator()) - mergedMetadata.resolveQuote().headOption.foreach(format.setQuote) - mergedMetadata.resolveEscape().headOption.foreach(format.setQuoteEscape) - csvParserSettings.setFormat(format) - // allocate twice the declared columns. 
If fail a strange exception is thrown: https://github.com/uniVocity/univocity-parsers/issues/247 - csvParserSettings.setMaxColumns(schema.attributes.length * 2) - csvParserSettings.setNullValue(mergedMetadata.resolveNullValue()) - csvParserSettings.setHeaderExtractionEnabled(true) - csvParserSettings.setMaxCharsPerColumn(-1) - val csvParser = new CsvParser(csvParserSettings) - csvParser.beginParsing(reader) - // call this in order to get the headers even if there is no record - csvParser.parseNextRecord() - csvParser.getRecordMetadata.headers().toList - } - } - val attributesMap = schema.attributes.map(attr => attr.name -> attr).toMap - val csvAttributesInOrders = - csvHeaders.map(h => - attributesMap.getOrElse(h, Attribute(h, ignore = Some(true), required = None)) - ) - // attributes not in csv input file must not be required but we don't force them to optional. - val effectiveAttributes = - csvAttributesInOrders ++ schema.attributes.diff(csvAttributesInOrders) - if (effectiveAttributes.length > schema.attributes.length) { - logger.warn( - s"Attributes in the CSV file are bigger from the schema. " + - s"Schema will be updated to match the CSV file. " + - s"Schema: ${schema.attributes.map(_.name).mkString(",")}. " + - s"CSV: ${csvHeaders.mkString(",")}" - ) - schema.copy(attributes = effectiveAttributes.take(schema.attributes.length)) - - } else { - schema.copy(attributes = effectiveAttributes) - } - - case _ => schema - } - case _ => schema - } - } - def singleStepLoad(domain: String, table: String, schema: Schema, path: List[Path]) = { val sinkConnection = mergedMetadata.getSinkConnection() val incomingSparkSchema = schema.targetSparkSchemaWithIgnoreAndScript(schemaHandler) @@ -247,29 +164,30 @@ class DuckDbNativeLoader(ingestionJob: IngestionJob)(implicit .map { p => val ps = p.toString if (ps.startsWith("file:")) - StorageHandler.localFile(p).pathAsString - else if (ps.contains { "://" }) { - val defaultEndpoint = - ps.substring(2) match { - case "gs" => "storage.googleapis.com" - case "s3" => "s3.amazonaws.com" - case _ => "s3.amazonaws.com" - } - val endpoint = - sinkConnection.options.getOrElse("s3_endpoint", defaultEndpoint) - val keyid = - sinkConnection.options("s3_access_key_id") - val secret = - sinkConnection.options("s3_secret_access_key") - JdbcDbUtils.execute("INSTALL httpfs;", conn) - JdbcDbUtils.execute("LOAD httpfs;", conn) - JdbcDbUtils.execute(s"SET s3_endpoint='$endpoint';", conn) - JdbcDbUtils.execute(s"SET s3_access_key_id='$keyid';", conn) - JdbcDbUtils.execute(s"SET s3_secret_access_key='$secret';", conn) - ps - } else { - ps - } + if (ps.startsWith("file:")) + StorageHandler.localFile(p).pathAsString + else if (ps.contains { "://" }) { + val defaultEndpoint = + ps.substring(2) match { + case "gs" => "storage.googleapis.com" + case "s3" => "s3.amazonaws.com" + case _ => "s3.amazonaws.com" + } + val endpoint = + sinkConnection.options.getOrElse("s3_endpoint", defaultEndpoint) + val keyid = + sinkConnection.options("s3_access_key_id") + val secret = + sinkConnection.options("s3_secret_access_key") + JdbcDbUtils.execute("INSTALL httpfs;", conn) + JdbcDbUtils.execute("LOAD httpfs;", conn) + JdbcDbUtils.execute(s"SET s3_endpoint='$endpoint';", conn) + JdbcDbUtils.execute(s"SET s3_access_key_id='$keyid';", conn) + JdbcDbUtils.execute(s"SET s3_secret_access_key='$secret';", conn) + ps + } else { + ps + } } .mkString("['", "','", "']") mergedMetadata.resolveFormat() match { diff --git a/src/main/scala/ai/starlake/job/ingest/loaders/NativeLoader.scala 
b/src/main/scala/ai/starlake/job/ingest/loaders/NativeLoader.scala index 55d2ad213..5772cc7df 100644 --- a/src/main/scala/ai/starlake/job/ingest/loaders/NativeLoader.scala +++ b/src/main/scala/ai/starlake/job/ingest/loaders/NativeLoader.scala @@ -194,6 +194,8 @@ class NativeLoader(ingestionJob: IngestionJob, accessToken: Option[String])(impl csvParser.getRecordMetadata.headers().toList } } + // The result is a list of effectiveAttributes that combines the attributes from the schema + // and the CSV file, ensuring compatibility between the two. val attributesMap = starlakeSchema.attributes.map(attr => attr.name -> attr).toMap val csvAttributesInOrders = csvHeaders.map(h => @@ -342,7 +344,7 @@ class NativeLoader(ingestionJob: IngestionJob, accessToken: Option[String])(impl val stepMap = if (twoSteps) { - val (tempCreateSchemaSql, tempCreateTableSql, _) = SparkUtils.buildCreateTableSQL( + val (tempCreateSchemaSql, tempCreateTableSql, commentsSQL) = SparkUtils.buildCreateTableSQL( tempTableName, incomingSparkSchema, caseSensitive = false, @@ -358,7 +360,7 @@ class NativeLoader(ingestionJob: IngestionJob, accessToken: Option[String])(impl 0 ) - val firstSTepCreateTableSqls = List(tempCreateSchemaSql, tempCreateTableSql) + val firstSTepCreateTableSqls = List(tempCreateSchemaSql, tempCreateTableSql) ++ commentsSQL val extraFileNameColumn = s"ALTER TABLE $tempTableName ADD COLUMN ${CometColumns.cometInputFileNameColumn} STRING DEFAULT '{{sl_input_file_name}}';" val workflowStatements = this.secondStepSQL(List(tempTableName)) @@ -388,7 +390,7 @@ class NativeLoader(ingestionJob: IngestionJob, accessToken: Option[String])(impl loadTaskSQL.asJava ) } else { - val (createSchemaSql, createTableSql, _) = SparkUtils.buildCreateTableSQL( + val (createSchemaSql, createTableSql, commentsSQL) = SparkUtils.buildCreateTableSQL( targetTableName, incomingSparkSchema, caseSensitive = false, @@ -396,7 +398,7 @@ class NativeLoader(ingestionJob: IngestionJob, accessToken: Option[String])(impl options, ddlMap ) - val createTableSqls = List(createSchemaSql, createTableSql) + val createTableSqls = List(createSchemaSql, createTableSql) ++ commentsSQL val workflowStatements = this.secondStepSQL(List(targetTableName)) val loadTaskSQL = Map( @@ -419,7 +421,6 @@ class NativeLoader(ingestionJob: IngestionJob, accessToken: Option[String])(impl } val engine = settings.appConfig.jdbcEngines(engineName.toString) - val tempStage = s"starlake_load_stage_${Random.alphanumeric.take(10).mkString("")}" val commonOptionsMap = Map( "schema" -> starlakeSchema.asMap().asJava, "sink" -> sink.asMap(engine).asJava, diff --git a/src/main/scala/ai/starlake/job/ingest/loaders/SnowflakeNativeLoader.scala b/src/main/scala/ai/starlake/job/ingest/loaders/SnowflakeNativeLoader.scala index ab93ca8e7..13c8d53bf 100644 --- a/src/main/scala/ai/starlake/job/ingest/loaders/SnowflakeNativeLoader.scala +++ b/src/main/scala/ai/starlake/job/ingest/loaders/SnowflakeNativeLoader.scala @@ -334,12 +334,7 @@ class SnowflakeNativeLoader(ingestionJob: IngestionJob)(implicit settings: Setti ddlMap ) } - val columnsString = - attrsWithDDLTypes - .map { case (attr, ddlType) => - s"'$attr': '$ddlType'" - } - .mkString(", ") + val pathsAsString = path .map { p => @@ -350,6 +345,7 @@ class SnowflakeNativeLoader(ingestionJob: IngestionJob)(implicit settings: Setti logger.info(res.toString()) res = JdbcDbUtils.executeQueryAsTable(s"CREATE OR REPLACE TEMPORARY STAGE $tempStage", conn) logger.info(res.toString()) + val putSqls = pathsAsString.map(path => s"PUT $path 
@$tempStage/$domain") putSqls.map { putSql => res = JdbcDbUtils.executeQueryAsTable(putSql, conn) diff --git a/src/main/scala/ai/starlake/job/sink/bigquery/BigQueryJobBase.scala b/src/main/scala/ai/starlake/job/sink/bigquery/BigQueryJobBase.scala index 1da546868..32df658fe 100644 --- a/src/main/scala/ai/starlake/job/sink/bigquery/BigQueryJobBase.scala +++ b/src/main/scala/ai/starlake/job/sink/bigquery/BigQueryJobBase.scala @@ -23,6 +23,7 @@ import com.google.iam.v1.{Binding, Policy => IAMPolicy, SetIamPolicyRequest} import com.google.protobuf.FieldMask import com.typesafe.scalalogging.StrictLogging import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import java.io.ByteArrayInputStream import java.security.SecureRandom @@ -294,13 +295,13 @@ trait BigQueryJobBase extends StrictLogging { val tableDefinition = table.getDefinition[StandardTableDefinition] val bqSchema = tableDefinition.getSchema() val bqFields = bqSchema.getFields.asScala.toList - val attributesMap = attrs.toMap + val attributesMap = CaseInsensitiveMap(attrs.toMap) val updatedFields = bqFields.map { field => - attributesMap.get(field.getName.toLowerCase) match { + attributesMap.get(field.getName) match { case None => // Maybe an ignored field logger.info( - s"Ignore this field ${table}.${field.getName} during CLS application " + s"Ignore this field $table.$field during CLS application " ) field case Some(accessPolicy) => @@ -845,10 +846,10 @@ trait BigQueryJobBase extends StrictLogging { table ) val tableName = - tableIdPk.getDataset.toUpperCase() + "_" + tableIdPk.getTable.toUpperCase() + tableIdPk.getDataset + "_" + tableIdPk.getTable val fk = ForeignKey.newBuilder .setName( - s"FK_${datasetId.getDataset.toUpperCase()}_${tableId.getTable().toUpperCase()}_${attr.getFinalName().toUpperCase()}" + s"FK_${datasetId.getDataset}_${tableId.getTable()}_${attr.getFinalName()}" ) .setColumnReferences(List(columnReference).asJava) .setReferencedTable(tableIdPk) @@ -879,7 +880,7 @@ trait BigQueryJobBase extends StrictLogging { .newBuilder(TimePartitioning.Type.DAY) .setRequirePartitionFilter(requirePartitionFilter) val partitioned = - if (!Set("_PARTITIONTIME", "_PARTITIONDATE").contains(partitionField.toUpperCase)) + if (!Set("_PARTITIONTIME", "_PARTITIONDATE").exists(_.equalsIgnoreCase(partitionField))) partitionFilter.setField(partitionField) else partitionFilter diff --git a/src/main/scala/ai/starlake/job/sink/jdbc/JdbcConnectionLoadCmd.scala b/src/main/scala/ai/starlake/job/sink/jdbc/JdbcConnectionLoadCmd.scala index 69b5872dc..f00db6c7d 100644 --- a/src/main/scala/ai/starlake/job/sink/jdbc/JdbcConnectionLoadCmd.scala +++ b/src/main/scala/ai/starlake/job/sink/jdbc/JdbcConnectionLoadCmd.scala @@ -78,17 +78,9 @@ object JdbcConnectionLoadCmd extends Cmd[JdbcConnectionLoadConfig] { checkTablePresent(starlakeConnection, jdbcEngine, outputTable) } - // This is to make sure that the column names are uppercase on JDBC databases - // TODO: Once spark 3.3 is not supported anymore, switch to withColumnsRenamed(colsMap: Map[String, String]) - val dfWithUppercaseColumns = sourceFile.map { df => - df.columns.foldLeft(df) { case (df, colName) => - df.withColumnRenamed(colName, colName.toUpperCase()) - } - } - JdbcConnectionLoadConfig( - sourceFile = dfWithUppercaseColumns, - outputDomainAndTableName = outputTable.toUpperCase(), + sourceFile = sourceFile, + outputDomainAndTableName = outputTable, strategy = strategy, starlakeConnection.sparkDatasource().getOrElse("jdbc"), starlakeConnection.options 
diff --git a/src/main/scala/ai/starlake/job/transform/AutoTask.scala b/src/main/scala/ai/starlake/job/transform/AutoTask.scala index 451808af3..78deb9550 100644 --- a/src/main/scala/ai/starlake/job/transform/AutoTask.scala +++ b/src/main/scala/ai/starlake/job/transform/AutoTask.scala @@ -30,6 +30,7 @@ import ai.starlake.job.strategies.TransformStrategiesBuilder import ai.starlake.schema.handlers.{SchemaHandler, StorageHandler} import ai.starlake.schema.model._ import ai.starlake.sql.SQLUtils +import ai.starlake.sql.SQLUtils.sqlCased import ai.starlake.transpiler.JSQLTranspiler import ai.starlake.utils.Formatter.RichFormatter import ai.starlake.utils._ @@ -341,9 +342,9 @@ abstract class AutoTask( val scd2Columns = List(startTsCol, endTsCol) val alterTableSqls = scd2Columns.map { column => if (engineName.toString.toLowerCase() == "redshift") - s"ALTER TABLE $fullTableName ADD COLUMN $column TIMESTAMP" + s"ALTER TABLE ${sqlCased(fullTableName)} ADD COLUMN ${sqlCased(column)} TIMESTAMP" else - s"ALTER TABLE $fullTableName ADD COLUMN IF NOT EXISTS $column TIMESTAMP NULL" + s"ALTER TABLE ${sqlCased(fullTableName)} ADD COLUMN IF NOT EXISTS ${sqlCased(column)} TIMESTAMP NULL" } alterTableSqls case _ => diff --git a/src/main/scala/ai/starlake/job/transform/BigQueryAutoTask.scala b/src/main/scala/ai/starlake/job/transform/BigQueryAutoTask.scala index 03b1ac580..76a4c2aeb 100644 --- a/src/main/scala/ai/starlake/job/transform/BigQueryAutoTask.scala +++ b/src/main/scala/ai/starlake/job/transform/BigQueryAutoTask.scala @@ -143,13 +143,8 @@ class BigQueryAutoTask( sql: String, jobTimeoutMs: Option[Long] = None ): BigQueryNativeJob = { - val toUpperSql = sql.toUpperCase() - val finalSql = - if (toUpperSql.startsWith("WITH") || toUpperSql.startsWith("SELECT")) - sql // "(" + sql + ")" - else - sql - new BigQueryNativeJob(config, finalSql, this.resultPageSize, jobTimeoutMs) + + new BigQueryNativeJob(config, sql, this.resultPageSize, jobTimeoutMs) } private def runSqls(sqls: List[String]): List[Try[BigQueryJobResult]] = { @@ -507,7 +502,7 @@ class BigQueryAutoTask( val isSCD2 = strategy.getEffectiveType() == WriteStrategyType.SCD2 if ( isSCD2 && !incomingTableSchema.getFields.asScala.exists( - _.getName().toLowerCase() == settings.appConfig.scd2StartTimestamp.toLowerCase() + _.getName.equalsIgnoreCase(settings.appConfig.scd2StartTimestamp) ) ) { val startCol = Field diff --git a/src/main/scala/ai/starlake/job/transform/JdbcAutoTask.scala b/src/main/scala/ai/starlake/job/transform/JdbcAutoTask.scala index 5a18c8770..dd9ccb415 100644 --- a/src/main/scala/ai/starlake/job/transform/JdbcAutoTask.scala +++ b/src/main/scala/ai/starlake/job/transform/JdbcAutoTask.scala @@ -5,6 +5,7 @@ import ai.starlake.extract.{ExtractSchemaCmd, ExtractSchemaConfig, JdbcDbUtils} import ai.starlake.job.metrics.{ExpectationJob, JdbcExpectationAssertionHandler} import ai.starlake.schema.handlers.{SchemaHandler, StorageHandler} import ai.starlake.schema.model.{AccessControlEntry, AutoTaskDesc, Engine, WriteStrategyType} +import ai.starlake.sql.SQLUtils.sqlCased import ai.starlake.utils.Formatter.RichFormatter import ai.starlake.utils.{JdbcJobResult, JobResult, SparkUtils, Utils} import com.typesafe.scalalogging.StrictLogging @@ -91,9 +92,9 @@ class JdbcAutoTask( val scd2Columns = List(startTsCol, endTsCol) val alterTableSqls = scd2Columns.map { column => if (engineName.toString.toLowerCase() == "redshift") - s"ALTER TABLE $fullTableName ADD COLUMN $column TIMESTAMP" + s"ALTER TABLE ${sqlCased(fullTableName)} ADD COLUMN 
${sqlCased(column)} TIMESTAMP" else - s"ALTER TABLE $fullTableName ADD COLUMN IF NOT EXISTS $column TIMESTAMP NULL" + s"ALTER TABLE ${sqlCased(fullTableName)} ADD COLUMN IF NOT EXISTS ${sqlCased(column)} TIMESTAMP NULL" } alterTableSqls case _ => @@ -414,7 +415,7 @@ class JdbcAutoTask( optionsWrite, attDdl() ) - val allSqls = List(createSchema, createTable, commentSQL.getOrElse("")) + val allSqls = List(createSchema, createTable) ++ commentSQL (allSqls, false) } } diff --git a/src/main/scala/ai/starlake/job/transform/SparkAutoTask.scala b/src/main/scala/ai/starlake/job/transform/SparkAutoTask.scala index 255fb98b7..8862ba1a5 100644 --- a/src/main/scala/ai/starlake/job/transform/SparkAutoTask.scala +++ b/src/main/scala/ai/starlake/job/transform/SparkAutoTask.scala @@ -8,6 +8,7 @@ import ai.starlake.job.sink.es.{ESLoadConfig, ESLoadJob} import ai.starlake.schema.handlers.{SchemaHandler, StorageHandler} import ai.starlake.schema.model._ import ai.starlake.sql.SQLUtils +import ai.starlake.sql.SQLUtils.sqlCased import ai.starlake.utils.Formatter.RichFormatter import ai.starlake.utils._ import ai.starlake.utils.kafka.KafkaClient @@ -246,7 +247,7 @@ class SparkAutoTask( val tagsAsString = tableTagPairs.map { case (k, v) => s"'$k'='$v'" }.mkString(",") SparkUtils.sql( session, - s"CREATE SCHEMA IF NOT EXISTS ${taskDesc.domain} WITH DBPROPERTIES($tagsAsString)" + s"CREATE SCHEMA IF NOT EXISTS ${sqlCased(taskDesc.domain)} WITH DBPROPERTIES($tagsAsString)" ) } else { SparkUtils.createSchema(session, taskDesc.domain) @@ -437,7 +438,7 @@ class SparkAutoTask( val endTs = strategy.endTs.getOrElse(settings.appConfig.scd2EndTimestamp) val scd2FieldsFound = - incomingSchema.fields.exists(_.name.toLowerCase() == startTs.toLowerCase()) + incomingSchema.fields.exists(_.name.equalsIgnoreCase(startTs)) if (!scd2FieldsFound) { val incomingSchemaWithScd2 = diff --git a/src/main/scala/ai/starlake/schema/model/Domain.scala b/src/main/scala/ai/starlake/schema/model/Domain.scala index 45aabcfac..2f6b47e92 100644 --- a/src/main/scala/ai/starlake/schema/model/Domain.scala +++ b/src/main/scala/ai/starlake/schema/model/Domain.scala @@ -263,7 +263,7 @@ case class LoadDesc(version: Int, load: Domain) val filteredTables = tableNames.flatMap { tableName => this.tables.filter { table => tableNames - .exists(_.toLowerCase() == (this.finalName + "." + table.finalName).toLowerCase()) + .exists(_.equalsIgnoreCase(this.finalName + "." + table.finalName)) } } filteredTables.toList @@ -272,9 +272,7 @@ case class LoadDesc(version: Int, load: Domain) def aclTables(config: AclDependenciesConfig): List[Schema] = { val filteredTables = if (config.tables.nonEmpty) { tables.filter { table => - config.tables.exists( - _.toLowerCase() == (this.finalName + "." + table.finalName).toLowerCase() - ) + config.tables.exists(_.equalsIgnoreCase(this.finalName + "." + table.finalName)) } } else { tables @@ -294,9 +292,7 @@ case class LoadDesc(version: Int, load: Domain) def rlsTables(config: AclDependenciesConfig): Map[String, List[RowLevelSecurity]] = { val filteredTables = if (config.tables.nonEmpty) { tables.filter { table => - config.tables.exists( - _.toLowerCase() == (this.finalName + "." + table.finalName).toLowerCase() - ) + config.tables.exists(_.equalsIgnoreCase(this.finalName + "." 
+ table.finalName)) } } else { tables @@ -394,7 +390,7 @@ object Domain { ( table, incoming.tables - .find(_.name.toLowerCase() == table.name.toLowerCase()) + .find(_.name.equalsIgnoreCase(table.name)) .getOrElse(throw new Exception("Should not happen")) ) } diff --git a/src/main/scala/ai/starlake/schema/model/Metadata.scala b/src/main/scala/ai/starlake/schema/model/Metadata.scala index e07051957..094395928 100644 --- a/src/main/scala/ai/starlake/schema/model/Metadata.scala +++ b/src/main/scala/ai/starlake/schema/model/Metadata.scala @@ -25,6 +25,7 @@ import ai.starlake.schema.model.Format.DSV import ai.starlake.schema.model.Severity._ import ai.starlake.schema.model.WriteMode.APPEND import com.fasterxml.jackson.annotation.{JsonIgnore, JsonInclude} +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import scala.collection.mutable import scala.jdk.CollectionConverters._ @@ -281,10 +282,10 @@ case class Metadata( def resolveEmptyIsNull(): java.lang.Boolean = emptyIsNull.getOrElse(true).booleanValue() - def getOptions(): Map[String, String] = options.getOrElse(Map.empty) + def getOptions(): CaseInsensitiveMap[String] = CaseInsensitiveMap(options.getOrElse(Map.empty)) @JsonIgnore - def getXmlOptions(): Map[String, String] = this.getOptions() + def getXmlOptions(): CaseInsensitiveMap[String] = this.getOptions() @JsonIgnore def getXsdPath(): Option[String] = { diff --git a/src/main/scala/ai/starlake/schema/model/TypesDesc.scala b/src/main/scala/ai/starlake/schema/model/TypesDesc.scala index 4534c07af..3cc6c8a0d 100644 --- a/src/main/scala/ai/starlake/schema/model/TypesDesc.scala +++ b/src/main/scala/ai/starlake/schema/model/TypesDesc.scala @@ -212,7 +212,7 @@ case class Type( this.ddlMapping .getOrElse(Map.empty) .keys - .find(mapping => !mapping.equals(mapping.toLowerCase())) + .find(mapping => !mapping.equalsIgnoreCase(mapping)) notLowerCaseOnlyMapping match { case Some(mapping) => diff --git a/src/main/scala/ai/starlake/sql/SQLUtils.scala b/src/main/scala/ai/starlake/sql/SQLUtils.scala index 0eaa6e0c6..d462a5c19 100644 --- a/src/main/scala/ai/starlake/sql/SQLUtils.scala +++ b/src/main/scala/ai/starlake/sql/SQLUtils.scala @@ -561,4 +561,12 @@ object SQLUtils extends StrictLogging { sql } } + + def sqlCased(obj: String)(implicit settings: Settings): String = { + settings.appConfig.sqlCaseSensitivity.toLowerCase() match { + case "upper" => obj.toUpperCase + case "lower" => obj.toLowerCase + case _ => obj + } + } } diff --git a/src/main/scala/ai/starlake/sql/StarlakeJdbcDialects.scala b/src/main/scala/ai/starlake/sql/StarlakeJdbcDialects.scala index 0ccf3c6e1..f258f4f6c 100644 --- a/src/main/scala/ai/starlake/sql/StarlakeJdbcDialects.scala +++ b/src/main/scala/ai/starlake/sql/StarlakeJdbcDialects.scala @@ -5,7 +5,7 @@ import ai.starlake.extract.JdbcDbUtils.StarlakeConnectionPool import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType} -import org.apache.spark.sql.types.{BooleanType, DataType, MetadataBuilder, TimestampType} +import org.apache.spark.sql.types._ import java.sql.{Connection, Types} @@ -20,6 +20,30 @@ private object StarlakeSnowflakeDialect extends JdbcDialect with SQLConfHelper { } } +private object StarlakeBigQueryDialect extends JdbcDialect with SQLConfHelper { + override def canHandle(url: String): Boolean = url.toLowerCase.startsWith("jdbc:bigquery:") + // override def quoteIdentifier(column: String): String = column + override def 
getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case BooleanType => Some(JdbcType("BOOL", java.sql.Types.BOOLEAN)) + case IntegerType => Option(JdbcType("INT64", java.sql.Types.INTEGER)) + case LongType => Option(JdbcType("INT64", java.sql.Types.BIGINT)) + case DoubleType => Option(JdbcType("FLOAT64", java.sql.Types.DOUBLE)) + case FloatType => Option(JdbcType("FLOAT64", java.sql.Types.FLOAT)) + case ShortType => Option(JdbcType("INT64", java.sql.Types.SMALLINT)) + case ByteType => Option(JdbcType("INT64", java.sql.Types.TINYINT)) + case StringType => Option(JdbcType("STRING", java.sql.Types.CLOB)) + case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP)) + // This is a common case of timestamp without time zone. Most of the databases either only + // support TIMESTAMP type or use TIMESTAMP as an alias for TIMESTAMP WITHOUT TIME ZONE. + // Note that some dialects override this setting, e.g. as SQL Server. + case TimestampNTZType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP)) + case DateType => Option(JdbcType("DATE", java.sql.Types.DATE)) + case t: DecimalType => + Option(JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL)) + case _ => JdbcDbUtils.getCommonJDBCType(dt) + } +} + private object StarlakeDuckDbDialect extends JdbcDialect with SQLConfHelper { override def createConnectionFactory(options: JDBCOptions): Int => Connection = { @@ -57,7 +81,7 @@ private object StarlakeDuckDbDialect extends JdbcDialect with SQLConfHelper { } object StarlakeJdbcDialects { - val dialects = List(StarlakeSnowflakeDialect, StarlakeDuckDbDialect) + val dialects = List(StarlakeSnowflakeDialect, StarlakeDuckDbDialect, StarlakeBigQueryDialect) def registerDialects() = dialects.foreach { dialect => JdbcDialects.registerDialect(dialect) diff --git a/src/main/scala/ai/starlake/utils/SparkUtils.scala b/src/main/scala/ai/starlake/utils/SparkUtils.scala index 4892ebffd..29a757a04 100644 --- a/src/main/scala/ai/starlake/utils/SparkUtils.scala +++ b/src/main/scala/ai/starlake/utils/SparkUtils.scala @@ -3,6 +3,7 @@ package ai.starlake.utils import ai.starlake.config.Settings import ai.starlake.extract.JdbcDbUtils import ai.starlake.sql.SQLUtils +import ai.starlake.sql.SQLUtils.sqlCased import better.files.File import com.manticore.jsqlformatter.JSQLFormatter import com.typesafe.scalalogging.StrictLogging @@ -167,8 +168,8 @@ object SparkUtils extends StrictLogging { } } - def createSchema(session: SparkSession, domain: String): Unit = { - SparkUtils.sql(session, s"CREATE SCHEMA IF NOT EXISTS ${domain}") + def createSchema(session: SparkSession, domain: String)(implicit settings: Settings): Unit = { + SparkUtils.sql(session, s"CREATE SCHEMA IF NOT EXISTS ${sqlCased(domain)}") } def truncateTable(session: SparkSession, tableName: String): Unit = { @@ -222,7 +223,7 @@ object SparkUtils extends StrictLogging { temporaryTable: Boolean, options: JdbcOptionsInWrite, attrDdlMapping: Map[String, Map[String, String]] - )(implicit settings: Settings): (String, String, Option[String]) = { + )(implicit settings: Settings): (String, String, List[String]) = { val strSchema = schemaString( schema, @@ -239,18 +240,18 @@ object SparkUtils extends StrictLogging { strSchema.replaceAll("\"", "") val domainName = domainAndTableName.split('.').head - val createSchemaSQL = s"CREATE SCHEMA IF NOT EXISTS $domainName" + val createSchemaSQL = s"CREATE SCHEMA IF NOT EXISTS ${sqlCased(domainName)}" val temporary = if (temporaryTable) "TEMP" else "" val createTableSQL = - s"CREATE 
$temporary TABLE IF NOT EXISTS $domainAndTableName ($finalStrSchema) $createTableOptions" + s"CREATE $temporary TABLE IF NOT EXISTS ${sqlCased(domainAndTableName)} ($finalStrSchema) $createTableOptions" val commentSQL = if (options.tableComment.nonEmpty) - Some(s"COMMENT ON TABLE $domainAndTableName IS '${options.tableComment}'") + Some(s"COMMENT ON TABLE ${sqlCased(domainAndTableName)} IS '${options.tableComment}'") else None - - (createSchemaSQL, createTableSQL, commentSQL) + val attrsComments = commentsOnAttributes(domainAndTableName, schema) + (createSchemaSQL, createTableSQL, commentSQL.toList ++ attrsComments) } def isFlat(fields: StructType): Boolean = { @@ -376,7 +377,18 @@ object SparkUtils extends StrictLogging { } column } - columns.mkString(", ") + val result = columns.mkString(", ") + sqlCased(result) + } + + def commentsOnAttributes( + domainAndTableName: String, + schema: StructType + )(implicit settings: Settings): List[String] = { + schema.fields.flatMap { field => + getDescription(field) + .map(d => s"COMMENT ON COLUMN ${sqlCased(domainAndTableName + '.' + field.name)} IS '$d'") + }.toList } def sql(session: SparkSession, sql: String): DataFrame = { diff --git a/src/test/scala/ai/starlake/schema/generator/YamlSerdeSpec.scala b/src/test/scala/ai/starlake/schema/generator/YamlSerdeSpec.scala index 2ed7be3bb..dd47331f3 100644 --- a/src/test/scala/ai/starlake/schema/generator/YamlSerdeSpec.scala +++ b/src/test/scala/ai/starlake/schema/generator/YamlSerdeSpec.scala @@ -1455,7 +1455,8 @@ object YamlConfigGenerators { maxInteractiveRecords = maxInteractiveRecords, duckdbPath = duckdbPath, ack = None, - duckDbEnableExternalAccess = false + duckDbEnableExternalAccess = false, + sqlCaseSensitivity = "default" ) }