Skip to content

Commit 2100efc

Browse files
authored
Merge pull request #67 from hasura/gavin/trino-cli-schema-support
Trino CLI: multi-catalog support
2 parents 0188ca2 + 87ead47 commit 2100efc

1 file changed

Lines changed: 106 additions & 40 deletions

File tree

ndc-cli/src/main/kotlin/io/hasura/cli/TrinoConfigGenerator.kt

Lines changed: 106 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,29 +7,82 @@ import io.hasura.ndc.common.TableSchemaRow
77
import io.hasura.ndc.common.TableType
88
import org.jooq.impl.DSL
99

10+
fun debug(message: String) {
11+
val logLevel = System.getenv("HASURA_LOG_LEVEL") ?: "info"
12+
if (logLevel.lowercase() == "debug") {
13+
println(message)
14+
}
15+
}
16+
17+
1018
object TrinoConfigGenerator : IConfigGenerator {
1119

1220
override fun generateConfig(
13-
jdbcUrl: JdbcUrlConfig,
21+
jdbcUrlConfig: JdbcUrlConfig,
1422
schemas: List<String>,
1523
fullyQualifyNames: Boolean,
1624
): ConnectorConfiguration {
17-
val jdbcUrlString = when (jdbcUrl) {
18-
is JdbcUrlConfig.Literal -> jdbcUrl.value
19-
is JdbcUrlConfig.EnvVar -> System.getenv(jdbcUrl.variable)
20-
?: throw IllegalArgumentException("Environment variable ${jdbcUrl.variable} not found")
25+
val jdbcUrlString = when (jdbcUrlConfig) {
26+
is JdbcUrlConfig.Literal -> jdbcUrlConfig.value
27+
is JdbcUrlConfig.EnvVar -> System.getenv(jdbcUrlConfig.variable)
28+
?: throw IllegalArgumentException("Environment variable ${jdbcUrlConfig.variable} not found")
2129
}
2230

23-
val ctx = DSL.using(jdbcUrlString)
31+
// Parse a Trino JDBC URL to extract the catalog and schema (if present)
32+
fun getCatalogAndSchemaFromJdbcUrl(jdbcUrl: String): Pair<String?, String?> {
33+
val contents = jdbcUrl.substringAfter("jdbc:trino://")
34+
val parts = contents.split("/", limit = 3)
35+
// Theoretical parts: ["localhost:8080", "catalog", "schema"]
36+
return when (parts.size) {
37+
1 -> null to null // No catalog or schema
38+
2 -> parts[1] to null // Only catalog, no schema
39+
3 -> parts[1] to parts[2] // Both catalog and schema
40+
else -> error("Unexpected JDBC URL format: $jdbcUrl. Expected format is 'jdbc:trino://host:port?user=name&password=pwd'")
41+
}
42+
}
43+
44+
val (parsedCatalog, parsedSchema) = getCatalogAndSchemaFromJdbcUrl(jdbcUrlString)
45+
if (parsedCatalog != null || parsedSchema != null) {
46+
error(
47+
"Trino JDBC URLs should not contain catalog or schema information. " +
48+
"Please provide an URL in the format 'jdbc:trino://host:port?user=name&password=pwd'"
49+
)
50+
}
2451

25-
// "current_schema" is a special Trino value that resolves to the current schema
26-
// See: https://trino.io/docs/current/functions/session.html#current_schema
27-
val sql = """
28-
SELECT table_name, column_name, data_type, is_nullable
29-
FROM information_schema.columns
30-
WHERE table_schema = current_schema
31-
""".trimIndent()
52+
// Parse the list of schemas to group the values by catalog
53+
// Results in: { catalogName: [schema1, schema2, ...] }
54+
val catalogToSchemas = schemas
55+
.groupBy { it.split(".").firstOrNull() }
56+
.mapValues { (catalog, schemaEntries) ->
57+
// If there are entries that are just the catalog name with no schema part, we should include all schemas for that catalog
58+
val hasEmptyCatalogEntry = schemaEntries.any { it == catalog }
59+
if (hasEmptyCatalogEntry) {
60+
// Return empty list to signal we want all schemas
61+
emptyList()
62+
} else {
63+
// Extract schema parts
64+
schemaEntries
65+
.map { it.split(".").drop(1).joinToString("") }
66+
.filter { it.isNotEmpty() } // Filter out any empty schema names
67+
}
68+
}
3269

70+
// For each catalog given, we need to execute a SQL query to fetch the table and column information.
71+
val query = catalogToSchemas.entries.joinToString("UNION ALL\n") { (catalog, catalogSchemas) ->
72+
"""
73+
SELECT table_catalog, table_schema, table_name, column_name, data_type, is_nullable
74+
FROM $catalog.information_schema.columns
75+
${
76+
if (catalogSchemas.isNotEmpty()) {
77+
debug("Filtering schemas for catalog $catalog: $catalogSchemas")
78+
"WHERE table_schema IN (${catalogSchemas.joinToString(",") { "'$it'" }})"
79+
} else {
80+
debug("No specific schemas for catalog $catalog, including all schemas")
81+
""
82+
}
83+
}
84+
""".trimIndent()
85+
}
3386

3487
// Take a string of the format "decimal(20)" or "decimal(20,2)" and extract the numeric precision and scale
3588
fun extractNumericPrecisionAndScale(
@@ -48,37 +101,50 @@ object TrinoConfigGenerator : IConfigGenerator {
48101
return precision to scale
49102
}
50103

51-
// fetch every column, use jOOQ's fetchGroups to group them by table_name
52-
val tables = ctx.resultQuery(sql).fetchGroups("table_name").map { (tableName, rows) ->
53-
TableSchemaRow(
54-
tableName = tableName as String,
55-
tableType = TableType.TABLE,
56-
description = null,
57-
pks = emptyList(),
58-
fks = emptyMap(),
59-
columns = rows.map { row ->
60-
val dataType = row.get("data_type", String::class.java)
61-
val (numericPrecision, numericScale) = when {
62-
dataType.startsWith("decimal") -> extractNumericPrecisionAndScale(dataType)
63-
else -> null to null
64-
}
65-
ColumnSchemaRow(
66-
name = row.get("column_name", String::class.java),
67-
type = row.get("data_type", String::class.java),
68-
nullable = row.get("is_nullable", String::class.java) == "YES",
69-
auto_increment = false,
70-
is_primarykey = false,
71-
description = null,
72-
numeric_precision = numericPrecision,
73-
numeric_scale = numericScale
74-
)
75-
},
76-
)
104+
// fetch every column, use jOOQ's fetchGroups to group them by (catalog, schema, table)
105+
debug("Executing query to fetch table and column information...")
106+
query.lines().forEach { line ->
107+
debug(" $line")
77108
}
78109

110+
val tables = DSL.using(jdbcUrlString)
111+
.resultQuery(query)
112+
.fetchGroups { record ->
113+
val catalog = record.get("table_catalog", String::class.java)
114+
val schema = record.get("table_schema", String::class.java)
115+
val tableName = record.get("table_name", String::class.java)
116+
"${catalog}.${schema}.${tableName}"
117+
}
118+
.map { (tableName, rows) ->
119+
TableSchemaRow(
120+
tableName = tableName as String,
121+
tableType = TableType.TABLE,
122+
description = null,
123+
pks = emptyList(),
124+
fks = emptyMap(),
125+
columns = rows.map { row ->
126+
val dataType = row.get("data_type", String::class.java)
127+
val (numericPrecision, numericScale) = when {
128+
dataType.startsWith("decimal") -> extractNumericPrecisionAndScale(dataType)
129+
else -> null to null
130+
}
131+
ColumnSchemaRow(
132+
name = row.get("column_name", String::class.java),
133+
type = row.get("data_type", String::class.java),
134+
nullable = row.get("is_nullable", String::class.java) == "YES",
135+
auto_increment = false,
136+
is_primarykey = false,
137+
description = null,
138+
numeric_precision = numericPrecision,
139+
numeric_scale = numericScale
140+
)
141+
},
142+
)
143+
}
144+
79145

80146
return ConnectorConfiguration(
81-
jdbcUrl = jdbcUrl,
147+
jdbcUrl = jdbcUrlConfig,
82148
jdbcProperties = emptyMap(),
83149
tables = tables,
84150
functions = emptyList()

0 commit comments

Comments
 (0)