handle case sensitive databases #1300

Open · wants to merge 5 commits into base: master
2 changes: 2 additions & 0 deletions src/main/resources/reference-general.conf
@@ -235,6 +235,8 @@ rowValidatorClass = "ai.starlake.job.validator.FlatRowValidator"
env = ""
env = ${?SL_ENV}

sqlCaseSensitivity = "default"
sqlCaseSensitivity = ${?SL_SQL_CASE_SENSITIVITY} // "upper" or "lower" or "default"

sqlParameterPattern = "\\$\\{\\s*%s\\s*\\}"

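The second sqlCaseSensitivity line relies on HOCON's optional substitution: ${?SL_SQL_CASE_SENSITIVITY} overrides the "default" value only when that environment variable is set. The loading code is not shown in this excerpt, but the HOCON syntax suggests the Typesafe Config library; the following standalone sketch (not Starlake code) illustrates the override behavior under that assumption:

// Standalone sketch of the HOCON override behavior, assuming the Typesafe
// Config library; not part of this pull request.
import com.typesafe.config.ConfigFactory

object HoconOverrideSketch {
  def main(args: Array[String]): Unit = {
    val hocon =
      """
        |sqlCaseSensitivity = "default"
        |sqlCaseSensitivity = ${?SL_SQL_CASE_SENSITIVITY}
        |""".stripMargin
    // resolve() lets ${?...} substitutions fall back to environment variables;
    // when SL_SQL_CASE_SENSITIVITY is unset, the "default" value is kept.
    val config = ConfigFactory.parseString(hocon).resolve()
    println(config.getString("sqlCaseSensitivity"))
  }
}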
@@ -1,10 +1,10 @@
# This template executes individual sql jobs and requires the following dag generation options set:
#
# - stage_location: the location of the stage in snowflake [REQUIRED]
# - warehouse(COMPUTE_WH): the warehouse to use for the DAG [OPTIONAL], default to COMPUTE_WH
# - stage_location: Where the generated tasks will be stored [REQUIRED]
# - warehouse(COMPUTE_WH): warehouse where the tasks will be hosted [OPTIONAL], default to COMPUTE_WH
# - timezone(UTC): the timezone to use for the schedule [OPTIONAL], default to UTC
# - packages(croniter,python-dateutil,snowflake-snowpark-python): a list of packages to install before running the task [OPTIONAL], default to croniter,python-dateutil,snowflake-snowpark-python
# - sl_incoming_file_stage: the stage to use for incoming files [OPTIONAL]
# - sl_incoming_file_stage: the stage to use for incoming files in load tasks [REQUIRED]
# - sl_env_var: starlake variables specified as a map in json format - at least the root project path SL_ROOT should be specified [OPTIONAL]
# - retries(1): the number of retries to attempt before failing the task [OPTIONAL]
# - retry_delay(300): the delay between retries in seconds [OPTIONAL]
@@ -1,7 +1,7 @@
# This template executes individual SNOWFLAKE SQL jobs and requires the following dag generation options set:
#
# - stage_location: the location of the stage in snowflake [REQUIRED]
# - warehouse(COMPUTE_WH): the warehouse to use for the DAG [OPTIONAL], default to COMPUTE_WH
# - stage_location: Where the generated tasks will be stored [REQUIRED]
# - warehouse(COMPUTE_WH): warehouse where the tasks will be hosted [OPTIONAL], default to COMPUTE_WH
# - timezone(UTC): the timezone to use for the schedule [OPTIONAL], default to UTC
# - packages(croniter,python-dateutil,snowflake-snowpark-python): a list of packages to install before running the task [OPTIONAL], default to croniter,python-dateutil,snowflake-snowpark-python
# - sl_env_var: starlake variables specified as a map in json format - at least the root project path SL_ROOT should be specified [OPTIONAL]
5 changes: 3 additions & 2 deletions src/main/scala/ai/starlake/config/Settings.scala
@@ -787,7 +787,8 @@ object Settings extends StrictLogging {
maxInteractiveRecords: Int,
duckdbPath: Option[String],
ack: Option[String],
duckDbEnableExternalAccess: Boolean
duckDbEnableExternalAccess: Boolean,
sqlCaseSensitivity: String
// createTableIfNotExists: Boolean
) extends Serializable {

@@ -1001,7 +1002,7 @@
if (this.connections.isEmpty)
s"connectionRef must be defined. Define a connection first and set it to this newly defined connection"
else
s"connectionRef must be defined. Valid connection names are $validConnectionNames"
s"connectionRef resolves to an empty value. It must be defined. Valid connection names are $validConnectionNames"
errors = errors :+ ValidationMessage(Severity.Error, "AppConfig", msg)
} else {
this.connections.get(this.connectionRef) match {
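Once the setting reaches AppConfig as the new sqlCaseSensitivity field, the JdbcDbUtils changes further down route identifier rendering through a sqlCased helper imported from SQLUtils. The helper's implementation is not part of this excerpt, so the following is only an illustrative sketch of how such a helper could apply the setting; the object name and signature are assumptions, not the actual Starlake code:

// Illustrative sketch only, not part of this pull request: one way a
// sqlCased helper could honor the new sqlCaseSensitivity setting.
// The real SQLUtils.sqlCased used later in this diff may differ.
object CaseSensitivitySketch {
  def sqlCased(identifier: String, sqlCaseSensitivity: String): String =
    sqlCaseSensitivity.toLowerCase match {
      case "upper" => identifier.toUpperCase // force upper-case identifiers
      case "lower" => identifier.toLowerCase // force lower-case identifiers
      case _       => identifier             // "default": keep the name as written
    }
}

With a mapping like this, a statement such as CREATE SCHEMA IF NOT EXISTS ${sqlCased(domainName)} would render the schema name in the case expected by the target database.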
21 changes: 11 additions & 10 deletions src/main/scala/ai/starlake/extract/ExtractBigQuerySchema.scala
@@ -56,7 +56,6 @@ class ExtractBigQuerySchema(config: BigQueryTablesConfig)(implicit settings: Set
tablesToExtract: Map[String, List[String]]
): List[Domain] = {
val datasetNames = tablesToExtract.keys.toList
val lowercaseDatasetNames = tablesToExtract.keys.map(_.toLowerCase()).toList
val filteredDatasets =
if (datasetNames.size == 1) {
// We optimize extraction for a single dataset
@@ -69,10 +68,10 @@ class ExtractBigQuerySchema(config: BigQueryTablesConfig)(implicit settings: Set
.iterateAll()
.asScala
.filter(ds =>
datasetNames.isEmpty || lowercaseDatasetNames.contains(
ds.getDatasetId.getDataset().toLowerCase()
)
datasetNames.isEmpty || tablesToExtract.keys
.exists(_.equalsIgnoreCase(ds.getDatasetId.getDataset))
)

}
filteredDatasets.map { dataset =>
extractDataset(schemaHandler, dataset)
@@ -98,21 +97,23 @@ class ExtractBigQuerySchema(config: BigQueryTablesConfig)(implicit settings: Set
val tables =
schemaHandler.domains().find(_.finalName.equalsIgnoreCase(datasetName)) match {
case Some(domain) =>
val tablesToExclude = domain.tables.map(_.finalName.toLowerCase())
allTables.filterNot(t => tablesToExclude.contains(t.getTableId.getTable().toLowerCase()))
val tablesToExclude = domain.tables.map(_.finalName)
allTables.filterNot(t =>
tablesToExclude.exists(_.equalsIgnoreCase(t.getTableId.getTable))
)
case None => allTables
}
val schemas = tables.flatMap { bqTable =>
logger.info(s"Extracting table $datasetName.${bqTable.getTableId.getTable()}")
logger.info(s"Extracting table $datasetName.${bqTable.getTableId.getTable}")
// We get the Table again below because Tables are returned with a null definition by listTables above.
Try(bigquery.getTable(bqTable.getTableId())) match {
Try(bigquery.getTable(bqTable.getTableId)) match {
case scala.util.Success(tableWithDefinition) =>
if (tableWithDefinition.getDefinition().isInstanceOf[StandardTableDefinition])
if (tableWithDefinition.getDefinition.isInstanceOf[StandardTableDefinition])
Some(extractTable(tableWithDefinition))
else
None
case scala.util.Failure(e) =>
logger.error(s"Failed to get table ${bqTable.getTableId()}", e)
logger.error(s"Failed to get table ${bqTable.getTableId}", e)
None
}
}
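The BigQuery extraction changes above drop the pre-lowercased name lists and compare dataset and table names with equalsIgnoreCase instead; the ExtractDataJob change below applies the same pattern to schema include/exclude filters. A minimal standalone illustration of the pattern (table names invented for the example, not taken from the Starlake codebase):

// Standalone illustration of the case-insensitive filtering pattern.
object CaseInsensitiveFilterSketch {
  def main(args: Array[String]): Unit = {
    val tablesToExclude = List("Orders", "Customers")
    val allTables = List("ORDERS", "customers", "Invoices")

    // Keep only tables whose name matches none of the excluded names, ignoring case.
    val kept = allTables.filterNot(t => tablesToExclude.exists(_.equalsIgnoreCase(t)))
    println(kept) // List(Invoices)
  }
}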
4 changes: 2 additions & 2 deletions src/main/scala/ai/starlake/extract/ExtractDataJob.scala
@@ -78,8 +78,8 @@ class ExtractDataJob(schemaHandler: SchemaHandler) extends ExtractPathHelper wit
.filter { s =>
(config.includeSchemas, config.excludeSchemas) match {
case (Nil, Nil) => true
case (inc, Nil) => inc.map(_.toLowerCase).contains(s.schema.toLowerCase)
case (Nil, exc) => !exc.map(_.toLowerCase).contains(s.schema.toLowerCase)
case (inc, Nil) => inc.exists(_.equalsIgnoreCase(s.schema))
case (Nil, exc) => !exc.exists(_.equalsIgnoreCase(s.schema))
case (_, _) =>
throw new RuntimeException(
"You can't specify includeShemas and excludeSchemas at the same time"
Expand Down
151 changes: 92 additions & 59 deletions src/main/scala/ai/starlake/extract/JdbcDbUtils.scala
@@ -7,6 +7,7 @@ import ai.starlake.extract.JdbcDbUtils.{lastExportTableName, Columns}
import ai.starlake.job.Main
import ai.starlake.schema.model._
import ai.starlake.sql.SQLUtils
import ai.starlake.sql.SQLUtils.sqlCased
import ai.starlake.tests.StarlakeTestData.DomainName
import ai.starlake.utils.{SparkUtils, Utils}
import com.manticore.jsqlformatter.JSQLFormatter
@@ -217,7 +218,7 @@ object JdbcDbUtils extends LazyLogging {
}

@throws[Exception]
def createSchema(conn: SQLConnection, domainName: String): Unit = {
def createSchema(conn: SQLConnection, domainName: String)(implicit settings: Settings): Unit = {
executeUpdate(schemaCreateSQL(domainName), conn) match {
case Success(_) =>
case Failure(e) =>
Expand All @@ -227,15 +228,15 @@ object JdbcDbUtils extends LazyLogging {
}

@throws[Exception]
def schemaCreateSQL(domainName: String): String = {
s"CREATE SCHEMA IF NOT EXISTS $domainName"
def schemaCreateSQL(domainName: String)(implicit settings: Settings): String = {
s"CREATE SCHEMA IF NOT EXISTS ${sqlCased(domainName)}"
}

def buildDropTableSQL(tableName: String): String = {
s"DROP TABLE IF EXISTS $tableName"
def buildDropTableSQL(tableName: String)(implicit settings: Settings): String = {
s"DROP TABLE IF EXISTS ${sqlCased(tableName)}"
}
@throws[Exception]
def dropTable(conn: SQLConnection, tableName: String): Unit = {
def dropTable(conn: SQLConnection, tableName: String)(implicit settings: Settings): Unit = {
executeUpdate(buildDropTableSQL(tableName), conn) match {
case Success(_) =>
case Failure(e) =>
@@ -290,6 +291,34 @@ object JdbcDbUtils extends LazyLogging {
stmt.close()
result
}
def executeQueryAsTable(
query: String,
connection: SQLConnection
): List[Map[String, String]] = {
val resultTable = ListBuffer[Map[String, String]]()
val statement = connection.createStatement()
try {
// Establish the connection
val resultSet = statement.executeQuery(query)

// Get column names
val metaData = resultSet.getMetaData
val columnCount = metaData.getColumnCount
val columnNames = (1 to columnCount).map(metaData.getColumnName)

// Process the result set
while (resultSet.next()) {
val row = columnNames
.map(name => name -> Option(resultSet.getObject(name)).map(_.toString).getOrElse("null"))
.toMap
resultTable += row
}
} finally {
statement.close()
}

resultTable.toList
}

def executeQueryAsTable(
query: String,
@@ -634,10 +663,12 @@ object JdbcDbUtils extends LazyLogging {
val jdbcServer = url.split(":")(1)
val jdbcEngine = settings.appConfig.jdbcEngines.get(jdbcServer)
val jdbcTableMap =
jdbcSchema.tables
.map(tblSchema => tblSchema.name.toUpperCase -> tblSchema)
.toMap
val uppercaseTableNames = jdbcTableMap.keys.toList
CaseInsensitiveMap(
jdbcSchema.tables
.map(tblSchema => tblSchema.name -> tblSchema)
.toMap
)
val tableNamesToExtract = jdbcTableMap.keys.toList
val schemaAndTableNames =
withJDBCConnection(readOnlyConnection(connectionSettings).options) { connection =>
val databaseMetaData = connection.getMetaData()
@@ -646,56 +677,53 @@ object JdbcDbUtils extends LazyLogging {
databaseMetaData,
jdbcSchema.schema
).map { schemaName =>
val lowerCasedExcludeTables = jdbcSchema.exclude.map(_.toLowerCase)

def tablesInScopePredicate(tablesToExtract: List[String] = Nil): TableName => Boolean =
(tableName: String) => {
!lowerCasedExcludeTables.contains(
tableName.toLowerCase
) && (tablesToExtract.isEmpty || tablesToExtract.contains(tableName.toUpperCase()))
!jdbcSchema.exclude.exists(
_.equalsIgnoreCase(tableName)
) && (tablesToExtract.isEmpty ||
tablesToExtract.exists(_.equalsIgnoreCase(tableName)))
}

val sqlDefinedTables = jdbcSchema.tables.filter(_.sql.isDefined).map(_.name)
val selectedTables = uppercaseTableNames match {
case list if list.isEmpty || list.contains("*") =>
extractTableNames(
schemaName,
jdbcSchema,
sqlDefinedTables,
tablesInScopePredicate(),
connectionSettings,
databaseMetaData,
skipRemarks,
jdbcEngine,
connection
)
case list =>
val extractedTableNames =
val selectedTables =
tableNamesToExtract match {
case list if list.isEmpty || list.contains("*") =>
extractTableNames(
schemaName,
jdbcSchema,
sqlDefinedTables,
tablesInScopePredicate(list),
tablesInScopePredicate(),
connectionSettings,
databaseMetaData,
skipRemarks,
jdbcEngine,
connection
)
val notExtractedTable = list.diff(
extractedTableNames
.map { case (tableName, _) => tableName }
.map(_.toUpperCase())
.toList
)
if (notExtractedTable.nonEmpty) {
val tablesNotExtractedStr = notExtractedTable.mkString(", ")
logger.warn(
s"The following tables where not extracted for $schemaName.${jdbcSchema.schema} : $tablesNotExtractedStr"
case list =>
val extractedTableNames =
extractTableNames(
schemaName,
jdbcSchema,
sqlDefinedTables,
tablesInScopePredicate(list),
connectionSettings,
databaseMetaData,
skipRemarks,
jdbcEngine,
connection
)
val notExtractedTable = list.diff(
extractedTableNames.map { case (tableName, _) => tableName }.toList
)
}
extractedTableNames
}
if (notExtractedTable.nonEmpty) {
val tablesNotExtractedStr = notExtractedTable.mkString(", ")
logger.warn(
s"The following tables where not extracted for $schemaName.${jdbcSchema.schema} : $tablesNotExtractedStr"
)
}
extractedTableNames
}
logger.whenDebugEnabled {
selectedTables.keys.foreach(table => logger.debug(s"Selected: $table"))
}
@@ -748,36 +776,41 @@ object JdbcDbUtils extends LazyLogging {
)
)
val primaryKeys = jdbcColumnMetadata.primaryKeys
val foreignKeys: Map[TableName, TableName] = jdbcColumnMetadata.foreignKeys
val foreignKeys: CaseInsensitiveMap[TableName] =
CaseInsensitiveMap(jdbcColumnMetadata.foreignKeys)
val columns: List[Attribute] = jdbcColumnMetadata.columns
logger.whenDebugEnabled {
columns
.foreach(column => logger.debug(s"column: $tableName.${column.name}"))
}
val jdbcCurrentTable = jdbcTableMap
.get(tableName.toUpperCase)
.get(tableName)
// Limit to the columns specified by the user if any
val currentTableRequestedColumns: Map[ColumnName, Option[ColumnName]] =
jdbcCurrentTable
.map(
_.columns.map(c =>
(if (keepOriginalName) c.name.toUpperCase.trim
else c.rename.getOrElse(c.name).toUpperCase.trim) -> c.rename
)
)
.getOrElse(Map.empty)
.toMap
val currentTableRequestedColumns: CaseInsensitiveMap[Option[ColumnName]] =
CaseInsensitiveMap(
jdbcCurrentTable
.map {
_.columns.map { c =>
val key =
if (keepOriginalName) c.name.trim
else c.rename.getOrElse(c.name).trim
key -> c.rename
}
}
.getOrElse(Map.empty)
.toMap
)
val currentFilter = jdbcCurrentTable.flatMap(_.filter)
val selectedColumns: List[Attribute] =
columns
.filter(col =>
currentTableRequestedColumns.isEmpty || currentTableRequestedColumns
.contains("*") || currentTableRequestedColumns
.contains(col.name.toUpperCase())
.contains("*") || currentTableRequestedColumns.keys
.exists(_.equalsIgnoreCase(col.name))
)
.map(c =>
c.copy(
foreignKey = foreignKeys.get(c.name.toUpperCase)
foreignKey = foreignKeys.get(c.name)
)
)
logger.whenDebugEnabled {
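Several of the JdbcDbUtils changes wrap lookup maps in CaseInsensitiveMap so that a table or column declared in one case still resolves when JDBC metadata reports another. The corresponding import is not visible in this excerpt; the sketch below assumes Spark's org.apache.spark.sql.catalyst.util.CaseInsensitiveMap, which matches the call pattern in the diff:

// Sketch of the lookup behavior the JdbcDbUtils changes rely on. Assumes
// Spark's CaseInsensitiveMap; the actual import used by Starlake is not
// visible in this excerpt.
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

object CaseInsensitiveMapSketch {
  def main(args: Array[String]): Unit = {
    // Keys are matched ignoring case, so a table declared as "Customers"
    // is still found when the database reports it as "CUSTOMERS".
    val tables = CaseInsensitiveMap(Map("Customers" -> "keep", "Orders" -> "keep"))
    println(tables.get("CUSTOMERS")) // Some(keep)
    println(tables.get("orders"))    // Some(keep)
    println(tables.get("Invoices"))  // None
  }
}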