@@ -7,29 +7,82 @@ import io.hasura.ndc.common.TableSchemaRow
77import io.hasura.ndc.common.TableType
88import org.jooq.impl.DSL
99
10+ fun debug (message : String ) {
11+ val logLevel = System .getenv(" HASURA_LOG_LEVEL" ) ? : " info"
12+ if (logLevel.lowercase() == " debug" ) {
13+ println (message)
14+ }
15+ }
16+
17+
1018object TrinoConfigGenerator : IConfigGenerator {
1119
1220 override fun generateConfig (
13- jdbcUrl : JdbcUrlConfig ,
21+ jdbcUrlConfig : JdbcUrlConfig ,
1422 schemas : List <String >,
1523 fullyQualifyNames : Boolean ,
1624 ): ConnectorConfiguration {
17- val jdbcUrlString = when (jdbcUrl ) {
18- is JdbcUrlConfig .Literal -> jdbcUrl .value
19- is JdbcUrlConfig .EnvVar -> System .getenv(jdbcUrl .variable)
20- ? : throw IllegalArgumentException (" Environment variable ${jdbcUrl .variable} not found" )
25+ val jdbcUrlString = when (jdbcUrlConfig ) {
26+ is JdbcUrlConfig .Literal -> jdbcUrlConfig .value
27+ is JdbcUrlConfig .EnvVar -> System .getenv(jdbcUrlConfig .variable)
28+ ? : throw IllegalArgumentException (" Environment variable ${jdbcUrlConfig .variable} not found" )
2129 }
2230
23- val ctx = DSL .using(jdbcUrlString)
31+ // Parse a Trino JDBC URL to extract the catalog and schema (if present)
32+ fun getCatalogAndSchemaFromJdbcUrl (jdbcUrl : String ): Pair <String ?, String ?> {
33+ val contents = jdbcUrl.substringAfter(" jdbc:trino://" )
34+ val parts = contents.split(" /" , limit = 3 )
35+ // Theoretical parts: ["localhost:8080", "catalog", "schema"]
36+ return when (parts.size) {
37+ 1 -> null to null // No catalog or schema
38+ 2 -> parts[1 ] to null // Only catalog, no schema
39+ 3 -> parts[1 ] to parts[2 ] // Both catalog and schema
40+ else -> error(" Unexpected JDBC URL format: $jdbcUrl . Expected format is 'jdbc:trino://host:port?user=name&password=pwd'" )
41+ }
42+ }
43+
44+ val (parsedCatalog, parsedSchema) = getCatalogAndSchemaFromJdbcUrl(jdbcUrlString)
45+ if (parsedCatalog != null || parsedSchema != null ) {
46+ error(
47+ " Trino JDBC URLs should not contain catalog or schema information. " +
48+ " Please provide an URL in the format 'jdbc:trino://host:port?user=name&password=pwd'"
49+ )
50+ }
2451
25- // "current_schema" is a special Trino value that resolves to the current schema
26- // See: https://trino.io/docs/current/functions/session.html#current_schema
27- val sql = """
28- SELECT table_name, column_name, data_type, is_nullable
29- FROM information_schema.columns
30- WHERE table_schema = current_schema
31- """ .trimIndent()
52+ // Parse the list of schemas to group the values by catalog
53+ // Results in: { catalogName: [schema1, schema2, ...] }
54+ val catalogToSchemas = schemas
55+ .groupBy { it.split(" ." ).firstOrNull() }
56+ .mapValues { (catalog, schemaEntries) ->
57+ // If there are entries that are just the catalog name with no schema part, we should include all schemas for that catalog
58+ val hasEmptyCatalogEntry = schemaEntries.any { it == catalog }
59+ if (hasEmptyCatalogEntry) {
60+ // Return empty list to signal we want all schemas
61+ emptyList()
62+ } else {
63+ // Extract schema parts
64+ schemaEntries
65+ .map { it.split(" ." ).drop(1 ).joinToString(" " ) }
66+ .filter { it.isNotEmpty() } // Filter out any empty schema names
67+ }
68+ }
3269
70+ // For each catalog given, we need to execute a SQL query to fetch the table and column information.
71+ val query = catalogToSchemas.entries.joinToString(" UNION ALL\n " ) { (catalog, catalogSchemas) ->
72+ """
73+ SELECT table_catalog, table_schema, table_name, column_name, data_type, is_nullable
74+ FROM $catalog .information_schema.columns
75+ ${
76+ if (catalogSchemas.isNotEmpty()) {
77+ debug(" Filtering schemas for catalog $catalog : $catalogSchemas " )
78+ " WHERE table_schema IN (${catalogSchemas.joinToString(" ," ) { " '$it '" }} )"
79+ } else {
80+ debug(" No specific schemas for catalog $catalog , including all schemas" )
81+ " "
82+ }
83+ }
84+ """ .trimIndent()
85+ }
3386
3487 // Take a string of the format "decimal(20)" or "decimal(20,2)" and extract the numeric precision and scale
3588 fun extractNumericPrecisionAndScale (
@@ -48,37 +101,50 @@ object TrinoConfigGenerator : IConfigGenerator {
48101 return precision to scale
49102 }
50103
51- // fetch every column, use jOOQ's fetchGroups to group them by table_name
52- val tables = ctx.resultQuery(sql).fetchGroups(" table_name" ).map { (tableName, rows) ->
53- TableSchemaRow (
54- tableName = tableName as String ,
55- tableType = TableType .TABLE ,
56- description = null ,
57- pks = emptyList(),
58- fks = emptyMap(),
59- columns = rows.map { row ->
60- val dataType = row.get(" data_type" , String ::class .java)
61- val (numericPrecision, numericScale) = when {
62- dataType.startsWith(" decimal" ) -> extractNumericPrecisionAndScale(dataType)
63- else -> null to null
64- }
65- ColumnSchemaRow (
66- name = row.get(" column_name" , String ::class .java),
67- type = row.get(" data_type" , String ::class .java),
68- nullable = row.get(" is_nullable" , String ::class .java) == " YES" ,
69- auto_increment = false ,
70- is_primarykey = false ,
71- description = null ,
72- numeric_precision = numericPrecision,
73- numeric_scale = numericScale
74- )
75- },
76- )
104+ // fetch every column, use jOOQ's fetchGroups to group them by (catalog, schema, table)
105+ debug(" Executing query to fetch table and column information..." )
106+ query.lines().forEach { line ->
107+ debug(" $line " )
77108 }
78109
110+ val tables = DSL .using(jdbcUrlString)
111+ .resultQuery(query)
112+ .fetchGroups { record ->
113+ val catalog = record.get(" table_catalog" , String ::class .java)
114+ val schema = record.get(" table_schema" , String ::class .java)
115+ val tableName = record.get(" table_name" , String ::class .java)
116+ " ${catalog} .${schema} .${tableName} "
117+ }
118+ .map { (tableName, rows) ->
119+ TableSchemaRow (
120+ tableName = tableName as String ,
121+ tableType = TableType .TABLE ,
122+ description = null ,
123+ pks = emptyList(),
124+ fks = emptyMap(),
125+ columns = rows.map { row ->
126+ val dataType = row.get(" data_type" , String ::class .java)
127+ val (numericPrecision, numericScale) = when {
128+ dataType.startsWith(" decimal" ) -> extractNumericPrecisionAndScale(dataType)
129+ else -> null to null
130+ }
131+ ColumnSchemaRow (
132+ name = row.get(" column_name" , String ::class .java),
133+ type = row.get(" data_type" , String ::class .java),
134+ nullable = row.get(" is_nullable" , String ::class .java) == " YES" ,
135+ auto_increment = false ,
136+ is_primarykey = false ,
137+ description = null ,
138+ numeric_precision = numericPrecision,
139+ numeric_scale = numericScale
140+ )
141+ },
142+ )
143+ }
144+
79145
80146 return ConnectorConfiguration (
81- jdbcUrl = jdbcUrl ,
147+ jdbcUrl = jdbcUrlConfig ,
82148 jdbcProperties = emptyMap(),
83149 tables = tables,
84150 functions = emptyList()
0 commit comments