From bab2eccade932c5d811e16c3aa0fea3b34023667 Mon Sep 17 00:00:00 2001 From: tokoko Date: Tue, 24 Mar 2026 02:19:48 +0400 Subject: [PATCH 1/4] feat: get_objects with single query --- go/connection.go | 4 +- go/connection_getobjects.go | 375 ++++++++++++++++-------------------- go/mysql.go | 5 +- 3 files changed, 169 insertions(+), 215 deletions(-) diff --git a/go/connection.go b/go/connection.go index 9ed89df..4dbc733 100644 --- a/go/connection.go +++ b/go/connection.go @@ -38,7 +38,7 @@ const ( // GetCurrentCatalog implements driverbase.CurrentNamespacer. func (c *mysqlConnectionImpl) GetCurrentCatalog() (string, error) { var database string - err := c.Db.QueryRowContext(context.Background(), "SELECT DATABASE()").Scan(&database) + err := c.Conn.QueryRowContext(context.Background(), "SELECT DATABASE()").Scan(&database) if err != nil { return "", c.ErrorHelper.WrapIO(err, "failed to get current database") } @@ -55,7 +55,7 @@ func (c *mysqlConnectionImpl) GetCurrentDbSchema() (string, error) { // SetCurrentCatalog implements driverbase.CurrentNamespacer. func (c *mysqlConnectionImpl) SetCurrentCatalog(catalog string) error { - _, err := c.Db.ExecContext(context.Background(), "USE "+quoteIdentifier(catalog)) + _, err := c.Conn.ExecContext(context.Background(), "USE "+quoteIdentifier(catalog)) return err } diff --git a/go/connection_getobjects.go b/go/connection_getobjects.go index cacfd65..46d61d7 100644 --- a/go/connection_getobjects.go +++ b/go/connection_getobjects.go @@ -22,149 +22,41 @@ import ( "strings" "github.com/adbc-drivers/driverbase-go/driverbase" + "github.com/apache/arrow-adbc/go/adbc" + "github.com/apache/arrow-go/v18/arrow/array" ) -func (c *mysqlConnectionImpl) GetCatalogs(ctx context.Context, catalogFilter *string) (catalogs []string, err error) { - // In MySQL JDBC, getCatalogs() returns database names (catalogs are databases) - // Build query using strings.Builder - var queryBuilder strings.Builder - queryBuilder.WriteString("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA") - args := []any{} - - if catalogFilter != nil { - queryBuilder.WriteString(" WHERE SCHEMA_NAME LIKE ?") - args = append(args, *catalogFilter) - } - - queryBuilder.WriteString(" ORDER BY SCHEMA_NAME") - - rows, err := c.Db.QueryContext(ctx, queryBuilder.String(), args...) - if err != nil { - return nil, c.ErrorHelper.WrapIO(err, "failed to query catalogs") - } - defer func() { - err = errors.Join(err, rows.Close()) - }() - - catalogs = make([]string, 0) - for rows.Next() { - var catalog string - if err := rows.Scan(&catalog); err != nil { - return nil, c.ErrorHelper.WrapIO(err, "failed to scan catalog") - } - catalogs = append(catalogs, catalog) - } - - if err := rows.Err(); err != nil { - return nil, c.ErrorHelper.WrapIO(err, "error during catalog iteration") - } - - return catalogs, err -} - -func (c *mysqlConnectionImpl) GetDBSchemasForCatalog(ctx context.Context, catalog string, schemaFilter *string) (schemas []string, err error) { - // In MySQL JDBC, getSchemas() returns empty - schemas are not supported - // For ADBC GetObjects, we return empty string as schema to maintain the hierarchy - // This allows: catalog (db name) -> schema ("") -> tables - - // Apply schema filter - only empty string matches our single schema - if schemaFilter != nil { - matches, err := filepath.Match(*schemaFilter, "") +func (c *mysqlConnectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (array.RecordReader, error) { + // MySQL has no real schema concept; we model it as a single empty-string schema. + // If the caller filters on a schema that doesn't match "", return catalogs only. + includeSchemas := true + if dbSchema != nil { + matches, err := filepath.Match(*dbSchema, "") if err != nil { return nil, c.ErrorHelper.WrapInvalidArgument(err, "invalid schema filter pattern") } if !matches { - return []string{}, nil // Schema filter doesn't match empty string + includeSchemas = false } } - // Return empty string as the single schema for this catalog - return []string{""}, nil -} - -func (c *mysqlConnectionImpl) GetTablesForDBSchema(ctx context.Context, catalog string, schema string, tableFilter *string, columnFilter *string, includeColumns bool) (tables []driverbase.TableInfo, err error) { - if includeColumns { - return c.getTablesWithColumns(ctx, catalog, schema, tableFilter, columnFilter) + // Determine effective depth: if schema filter doesn't match, cap at catalogs. + effectiveDepth := depth + if !includeSchemas && effectiveDepth != adbc.ObjectDepthCatalogs { + effectiveDepth = adbc.ObjectDepthCatalogs } - return c.getTablesOnly(ctx, catalog, schema, tableFilter) -} -// getTablesOnly retrieves table information without columns -func (c *mysqlConnectionImpl) getTablesOnly(ctx context.Context, catalog string, schema string, tableFilter *string) (tables []driverbase.TableInfo, err error) { - // In MySQL JDBC, catalog is the database name and schema should be empty - if schema != "" { - return []driverbase.TableInfo{}, nil // No tables for non-empty schemas - } + // Build a single query: SCHEMATA LEFT JOIN TABLES LEFT JOIN COLUMNS. + // Deeper levels are disabled with AND 1=0 in the join condition. + includeTables := effectiveDepth == adbc.ObjectDepthTables || effectiveDepth == adbc.ObjectDepthColumns + includeColumns := effectiveDepth == adbc.ObjectDepthColumns - // Build query using strings.Builder var queryBuilder strings.Builder - queryBuilder.WriteString(` - SELECT - TABLE_NAME, - TABLE_TYPE - FROM INFORMATION_SCHEMA.TABLES - WHERE TABLE_SCHEMA = ?`) - - args := []any{catalog} - - if tableFilter != nil { - queryBuilder.WriteString(` AND TABLE_NAME LIKE ?`) - args = append(args, *tableFilter) - } - - queryBuilder.WriteString(` ORDER BY TABLE_NAME`) - - rows, err := c.Db.QueryContext(ctx, queryBuilder.String(), args...) - if err != nil { - return nil, c.ErrorHelper.WrapIO(err, "failed to query tables for catalog %s", catalog) - } - defer func() { - err = errors.Join(err, rows.Close()) - }() - - tables = make([]driverbase.TableInfo, 0) - for rows.Next() { - var tableName, tableType string - if err := rows.Scan(&tableName, &tableType); err != nil { - return nil, c.ErrorHelper.WrapIO(err, "failed to scan table info") - } - - tables = append(tables, driverbase.TableInfo{ - TableName: tableName, - TableType: tableType, - }) - } - - if err := rows.Err(); err != nil { - return nil, c.ErrorHelper.WrapIO(err, "error during table iteration") - } - - return tables, err -} - -// getTablesWithColumns retrieves complete table and column information -func (c *mysqlConnectionImpl) getTablesWithColumns(ctx context.Context, catalog string, schema string, tableFilter *string, columnFilter *string) (tables []driverbase.TableInfo, err error) { - // In MySQL JDBC, catalog is the database name and schema should be empty - if schema != "" { - return []driverbase.TableInfo{}, nil // No tables for non-empty schemas - } - - type tableColumn struct { - TableName string - TableType string - OrdinalPosition int32 - ColumnName string - ColumnComment sql.NullString - DataType string - ColumnType string - IsNullable string - ColumnDefault sql.NullString - } + args := []any{} - // Build query using strings.Builder - var queryBuilder strings.Builder queryBuilder.WriteString(` SELECT + s.SCHEMA_NAME, t.TABLE_NAME, t.TABLE_TYPE, c.ORDINAL_POSITION, @@ -174,123 +66,188 @@ func (c *mysqlConnectionImpl) getTablesWithColumns(ctx context.Context, catalog c.COLUMN_TYPE, c.IS_NULLABLE, c.COLUMN_DEFAULT - FROM INFORMATION_SCHEMA.TABLES t - INNER JOIN INFORMATION_SCHEMA.COLUMNS c - ON t.TABLE_SCHEMA = c.TABLE_SCHEMA - AND t.TABLE_NAME = c.TABLE_NAME - WHERE t.TABLE_SCHEMA = ?`) + FROM INFORMATION_SCHEMA.SCHEMATA s + LEFT JOIN INFORMATION_SCHEMA.TABLES t + ON s.SCHEMA_NAME = t.TABLE_SCHEMA`) + + if !includeTables { + queryBuilder.WriteString(` AND 1=0`) + } else { + if tableName != nil { + queryBuilder.WriteString(` AND t.TABLE_NAME LIKE ?`) + args = append(args, *tableName) + } + if len(tableType) > 0 { + queryBuilder.WriteString(` AND t.TABLE_TYPE IN (` + placeholders(len(tableType)) + `)`) + for _, tt := range tableType { + args = append(args, tt) + } + } + } - args := []any{catalog} + queryBuilder.WriteString(` + LEFT JOIN INFORMATION_SCHEMA.COLUMNS c + ON t.TABLE_SCHEMA = c.TABLE_SCHEMA + AND t.TABLE_NAME = c.TABLE_NAME`) - if tableFilter != nil { - queryBuilder.WriteString(` AND t.TABLE_NAME LIKE ?`) - args = append(args, *tableFilter) - } - if columnFilter != nil { + if !includeColumns { + queryBuilder.WriteString(` AND 1=0`) + } else if columnName != nil { queryBuilder.WriteString(` AND c.COLUMN_NAME LIKE ?`) - args = append(args, *columnFilter) + args = append(args, *columnName) + } + + if catalog != nil { + queryBuilder.WriteString(` WHERE s.SCHEMA_NAME LIKE ?`) + args = append(args, *catalog) } - queryBuilder.WriteString(` ORDER BY t.TABLE_NAME, c.ORDINAL_POSITION`) + queryBuilder.WriteString(` ORDER BY s.SCHEMA_NAME, t.TABLE_NAME, c.ORDINAL_POSITION`) - rows, err := c.Db.QueryContext(ctx, queryBuilder.String(), args...) + rows, err := c.Conn.QueryContext(ctx, queryBuilder.String(), args...) if err != nil { - return nil, c.ErrorHelper.WrapIO(err, "failed to query tables with columns for catalog %s", catalog) + return nil, c.ErrorHelper.WrapIO(err, "failed to query objects") } defer func() { err = errors.Join(err, rows.Close()) }() - tables = make([]driverbase.TableInfo, 0) + // Group rows into the GetObjectsInfo hierarchy. + var infos []driverbase.GetObjectsInfo + var currentInfo *driverbase.GetObjectsInfo var currentTable *driverbase.TableInfo + includeDbSchemas := effectiveDepth != adbc.ObjectDepthCatalogs + for rows.Next() { - var tc tableColumn + var ( + schema string + tblName sql.NullString + tblType sql.NullString + ordinalPosition sql.NullInt32 + colName sql.NullString + colComment sql.NullString + dataType sql.NullString + colType sql.NullString + isNullable sql.NullString + colDefault sql.NullString + ) if err := rows.Scan( - &tc.TableName, &tc.TableType, - &tc.OrdinalPosition, &tc.ColumnName, &tc.ColumnComment, - &tc.DataType, &tc.ColumnType, &tc.IsNullable, &tc.ColumnDefault, + &schema, &tblName, &tblType, + &ordinalPosition, &colName, &colComment, + &dataType, &colType, &isNullable, &colDefault, ); err != nil { - return nil, c.ErrorHelper.WrapIO(err, "failed to scan table with columns") + return nil, c.ErrorHelper.WrapIO(err, "failed to scan objects row") } - // Check if we need to create a new table entry - if currentTable == nil || currentTable.TableName != tc.TableName { - tables = append(tables, driverbase.TableInfo{ - TableName: tc.TableName, - TableType: tc.TableType, - }) - currentTable = &tables[len(tables)-1] + // New catalog? + if currentInfo == nil || *currentInfo.CatalogName != schema { + info := driverbase.GetObjectsInfo{CatalogName: driverbase.Nullable(schema)} + if includeDbSchemas { + info.CatalogDbSchemas = []driverbase.DBSchemaInfo{{DbSchemaName: driverbase.Nullable("")}} + } + infos = append(infos, info) + currentInfo = &infos[len(infos)-1] + currentTable = nil } - // Process column data - var radix sql.NullInt16 - var nullable sql.NullInt16 - - // Build the full type name including UNSIGNED if applicable - // Only check integer types to avoid false positives with enum/set value lists - xdbcTypeName := tc.DataType - switch strings.ToUpper(tc.DataType) { - case "TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT": - if strings.Contains(strings.ToUpper(tc.ColumnType), "UNSIGNED") { - xdbcTypeName = tc.DataType + " UNSIGNED" - } + if !tblName.Valid { + continue } - // Set numeric precision radix (MySQL doesn't store this directly) - dataType := strings.ToUpper(tc.DataType) - switch dataType { - // Binary radix (base 2) - case "BIT": - radix = sql.NullInt16{Int16: 2, Valid: true} + // New table? + tables := ¤tInfo.CatalogDbSchemas[0].DbSchemaTables + if currentTable == nil || currentTable.TableName != tblName.String { + *tables = append(*tables, driverbase.TableInfo{ + TableName: tblName.String, + TableType: tblType.String, + }) + currentTable = &(*tables)[len(*tables)-1] + } - // Decimal radix (base 10) - integer types - case "TINYINT", "SMALLINT", "MEDIUMINT", "INT", "INTEGER", "BIGINT": - radix = sql.NullInt16{Int16: 10, Valid: true} + if !colName.Valid { + continue + } - // Decimal radix (base 10) - decimal/numeric types - case "DECIMAL", "DEC", "NUMERIC", "FIXED": - radix = sql.NullInt16{Int16: 10, Valid: true} + currentTable.TableColumns = append(currentTable.TableColumns, + buildColumnInfo(dataType.String, colType.String, colName.String, isNullable.String, + ordinalPosition.Int32, colComment, colDefault)) + } - // Decimal radix (base 10) - floating point types - case "FLOAT", "DOUBLE", "DOUBLE PRECISION", "REAL": - radix = sql.NullInt16{Int16: 10, Valid: true} + if err := rows.Err(); err != nil { + return nil, c.ErrorHelper.WrapIO(err, "error during objects iteration") + } - // Decimal radix (base 10) - year type - case "YEAR": - radix = sql.NullInt16{Int16: 10, Valid: true} + return buildResult(c, infos) +} - // No radix for non-numeric types - default: - radix = sql.NullInt16{Valid: false} +// buildColumnInfo constructs a ColumnInfo from raw MySQL column metadata. +func buildColumnInfo(dataType, columnType, columnName, isNullable string, ordinalPosition int32, columnComment, columnDefault sql.NullString) driverbase.ColumnInfo { + var radix sql.NullInt16 + var nullable sql.NullInt16 + + // Build the full type name including UNSIGNED if applicable + // Only check integer types to avoid false positives with enum/set value lists + xdbcTypeName := dataType + switch strings.ToUpper(dataType) { + case "TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT": + if strings.Contains(strings.ToUpper(columnType), "UNSIGNED") { + xdbcTypeName = dataType + " UNSIGNED" } + } - // Set nullable information - switch tc.IsNullable { - case "YES": - nullable = sql.NullInt16{Int16: int16(driverbase.XdbcColumnNullable), Valid: true} - case "NO": - nullable = sql.NullInt16{Int16: int16(driverbase.XdbcColumnNoNulls), Valid: true} - } + // Set numeric precision radix (MySQL doesn't store this directly) + switch strings.ToUpper(dataType) { + case "BIT": + radix = sql.NullInt16{Int16: 2, Valid: true} + case "TINYINT", "SMALLINT", "MEDIUMINT", "INT", "INTEGER", "BIGINT", + "DECIMAL", "DEC", "NUMERIC", "FIXED", + "FLOAT", "DOUBLE", "DOUBLE PRECISION", "REAL", + "YEAR": + radix = sql.NullInt16{Int16: 10, Valid: true} + default: + radix = sql.NullInt16{Valid: false} + } + + // Set nullable information + switch isNullable { + case "YES": + nullable = sql.NullInt16{Int16: int16(driverbase.XdbcColumnNullable), Valid: true} + case "NO": + nullable = sql.NullInt16{Int16: int16(driverbase.XdbcColumnNoNulls), Valid: true} + } + + return driverbase.ColumnInfo{ + ColumnName: columnName, + OrdinalPosition: &ordinalPosition, + Remarks: driverbase.NullStringToPtr(columnComment), + XdbcTypeName: &xdbcTypeName, + XdbcNumPrecRadix: driverbase.NullInt16ToPtr(radix), + XdbcNullable: driverbase.NullInt16ToPtr(nullable), + XdbcIsNullable: &isNullable, + XdbcColumnDef: driverbase.NullStringToPtr(columnDefault), + } +} - currentTable.TableColumns = append(currentTable.TableColumns, driverbase.ColumnInfo{ - ColumnName: tc.ColumnName, - OrdinalPosition: &tc.OrdinalPosition, - Remarks: driverbase.NullStringToPtr(tc.ColumnComment), - XdbcTypeName: &xdbcTypeName, - XdbcNumPrecRadix: driverbase.NullInt16ToPtr(radix), - XdbcNullable: driverbase.NullInt16ToPtr(nullable), - XdbcIsNullable: &tc.IsNullable, - XdbcColumnDef: driverbase.NullStringToPtr(tc.ColumnDefault), - }) +// placeholders returns a comma-separated string of n question marks. +func placeholders(n int) string { + if n <= 0 { + return "" } + return strings.Repeat("?,", n-1) + "?" +} - if err := rows.Err(); err != nil { - return nil, c.ErrorHelper.WrapIO(err, "error during table with columns iteration") +// buildResult feeds GetObjectsInfo entries into BuildGetObjectsRecordReader. +func buildResult(c *mysqlConnectionImpl, infos []driverbase.GetObjectsInfo) (array.RecordReader, error) { + ch := make(chan driverbase.GetObjectsInfo, len(infos)) + for _, info := range infos { + ch <- info } + close(ch) - // TODO: Add constraint and foreign key metadata support + errCh := make(chan error, 1) + close(errCh) - return tables, err + return driverbase.BuildGetObjectsRecordReader(c.Alloc, ch, errCh) } diff --git a/go/mysql.go b/go/mysql.go index cde57d2..22f0185 100644 --- a/go/mysql.go +++ b/go/mysql.go @@ -260,7 +260,7 @@ func (m *mySQLTypeConverter) ConvertArrowToGo(arrowArray arrow.Array, index int, } } -// mysqlConnectionImpl extends sqlwrapper connection with DbObjectsEnumerator +// mysqlConnectionImpl extends sqlwrapper connection with MySQL-specific functionality type mysqlConnectionImpl struct { *sqlwrapper.ConnectionImplBase // Embed sqlwrapper connection for all standard functionality @@ -270,9 +270,6 @@ type mysqlConnectionImpl struct { // implements BulkIngester interface var _ sqlwrapper.BulkIngester = (*mysqlConnectionImpl)(nil) -// implements DbObjectsEnumerator interface -var _ driverbase.DbObjectsEnumerator = (*mysqlConnectionImpl)(nil) - // implements CurrentNameSpacer interface var _ driverbase.CurrentNamespacer = (*mysqlConnectionImpl)(nil) From 8d05be5ef2a7dc5d3bd9266d3130dfb10e53049b Mon Sep 17 00:00:00 2001 From: tokoko Date: Tue, 24 Mar 2026 03:11:36 +0400 Subject: [PATCH 2/4] fix: handle empty results --- go/connection_getobjects.go | 89 ++++++++++++++++++++++--------------- 1 file changed, 52 insertions(+), 37 deletions(-) diff --git a/go/connection_getobjects.go b/go/connection_getobjects.go index 46d61d7..725f1d9 100644 --- a/go/connection_getobjects.go +++ b/go/connection_getobjects.go @@ -18,7 +18,6 @@ import ( "context" "database/sql" "errors" - "path/filepath" "strings" "github.com/adbc-drivers/driverbase-go/driverbase" @@ -26,37 +25,26 @@ import ( "github.com/apache/arrow-go/v18/arrow/array" ) +// GetObjects implements adbc.Connection by running a single query on the +// session connection (c.Conn) so that session-scoped objects like temporary +// tables are visible in the results. +// +// The query joins SCHEMATA, a synthetic schema subquery, TABLES, and COLUMNS. +// Levels beyond the requested depth are disabled with AND 1=0 in the join +// condition. All filters (catalog, schema, table, column, table type) are +// applied in SQL. func (c *mysqlConnectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (array.RecordReader, error) { - // MySQL has no real schema concept; we model it as a single empty-string schema. - // If the caller filters on a schema that doesn't match "", return catalogs only. - includeSchemas := true - if dbSchema != nil { - matches, err := filepath.Match(*dbSchema, "") - if err != nil { - return nil, c.ErrorHelper.WrapInvalidArgument(err, "invalid schema filter pattern") - } - if !matches { - includeSchemas = false - } - } - - // Determine effective depth: if schema filter doesn't match, cap at catalogs. - effectiveDepth := depth - if !includeSchemas && effectiveDepth != adbc.ObjectDepthCatalogs { - effectiveDepth = adbc.ObjectDepthCatalogs - } - - // Build a single query: SCHEMATA LEFT JOIN TABLES LEFT JOIN COLUMNS. - // Deeper levels are disabled with AND 1=0 in the join condition. - includeTables := effectiveDepth == adbc.ObjectDepthTables || effectiveDepth == adbc.ObjectDepthColumns - includeColumns := effectiveDepth == adbc.ObjectDepthColumns + includeSchemas := depth != adbc.ObjectDepthCatalogs + includeTables := depth == adbc.ObjectDepthTables || depth == adbc.ObjectDepthColumns + includeColumns := depth == adbc.ObjectDepthColumns var queryBuilder strings.Builder args := []any{} queryBuilder.WriteString(` SELECT - s.SCHEMA_NAME, + s.SCHEMA_NAME AS CATALOG_NAME, + sch.DB_SCHEMA_NAME, t.TABLE_NAME, t.TABLE_TYPE, c.ORDINAL_POSITION, @@ -66,9 +54,27 @@ func (c *mysqlConnectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectD c.COLUMN_TYPE, c.IS_NULLABLE, c.COLUMN_DEFAULT - FROM INFORMATION_SCHEMA.SCHEMATA s + FROM INFORMATION_SCHEMA.SCHEMATA s`) + + // MySQL has no real schema concept. We model it as a single empty-string + // schema via a LEFT JOIN with a synthetic row. The schema filter is + // applied via LIKE on this column. AND 1=0 disables the join when + // depth is catalogs-only, producing NULL for DB_SCHEMA_NAME. + queryBuilder.WriteString(` + LEFT JOIN (SELECT '' AS DB_SCHEMA_NAME) sch + ON 1=1`) + + if !includeSchemas { + queryBuilder.WriteString(` AND 1=0`) + } else if dbSchema != nil { + queryBuilder.WriteString(` AND sch.DB_SCHEMA_NAME LIKE ?`) + args = append(args, *dbSchema) + } + + queryBuilder.WriteString(` LEFT JOIN INFORMATION_SCHEMA.TABLES t - ON s.SCHEMA_NAME = t.TABLE_SCHEMA`) + ON s.SCHEMA_NAME = t.TABLE_SCHEMA + AND sch.DB_SCHEMA_NAME IS NOT NULL`) if !includeTables { queryBuilder.WriteString(` AND 1=0`) @@ -117,11 +123,10 @@ func (c *mysqlConnectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectD var currentInfo *driverbase.GetObjectsInfo var currentTable *driverbase.TableInfo - includeDbSchemas := effectiveDepth != adbc.ObjectDepthCatalogs - for rows.Next() { var ( - schema string + catalogName string + schemaName sql.NullString tblName sql.NullString tblType sql.NullString ordinalPosition sql.NullInt32 @@ -134,7 +139,7 @@ func (c *mysqlConnectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectD ) if err := rows.Scan( - &schema, &tblName, &tblType, + &catalogName, &schemaName, &tblName, &tblType, &ordinalPosition, &colName, &colComment, &dataType, &colType, &isNullable, &colDefault, ); err != nil { @@ -142,10 +147,16 @@ func (c *mysqlConnectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectD } // New catalog? - if currentInfo == nil || *currentInfo.CatalogName != schema { - info := driverbase.GetObjectsInfo{CatalogName: driverbase.Nullable(schema)} - if includeDbSchemas { - info.CatalogDbSchemas = []driverbase.DBSchemaInfo{{DbSchemaName: driverbase.Nullable("")}} + if currentInfo == nil || *currentInfo.CatalogName != catalogName { + info := driverbase.GetObjectsInfo{CatalogName: driverbase.Nullable(catalogName)} + if schemaName.Valid { + schemaInfo := driverbase.DBSchemaInfo{DbSchemaName: driverbase.Nullable(schemaName.String)} + if includeTables { + schemaInfo.DbSchemaTables = []driverbase.TableInfo{} + } + info.CatalogDbSchemas = []driverbase.DBSchemaInfo{schemaInfo} + } else if includeSchemas { + info.CatalogDbSchemas = []driverbase.DBSchemaInfo{} } infos = append(infos, info) currentInfo = &infos[len(infos)-1] @@ -159,10 +170,14 @@ func (c *mysqlConnectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectD // New table? tables := ¤tInfo.CatalogDbSchemas[0].DbSchemaTables if currentTable == nil || currentTable.TableName != tblName.String { - *tables = append(*tables, driverbase.TableInfo{ + tableInfo := driverbase.TableInfo{ TableName: tblName.String, TableType: tblType.String, - }) + } + if includeColumns { + tableInfo.TableColumns = []driverbase.ColumnInfo{} + } + *tables = append(*tables, tableInfo) currentTable = &(*tables)[len(*tables)-1] } From e979bf98ba2243144e16234055ee5541af05015e Mon Sep 17 00:00:00 2001 From: tokoko Date: Tue, 24 Mar 2026 09:41:48 +0400 Subject: [PATCH 3/4] test: add temp table test --- go/mysql_test.go | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/go/mysql_test.go b/go/mysql_test.go index 00f4533..f2247e4 100644 --- a/go/mysql_test.go +++ b/go/mysql_test.go @@ -805,6 +805,49 @@ func TestMySQLTypeTests(t *testing.T) { suite.Run(t, &MySQLTests{Quirks: quirks}) } +func (s *MySQLTestSuite) TestGetObjectsTempTable() { + tempTableName := "getobjects_temp_test" + + // Create a temporary table via bulk ingest + stmt, err := s.cnxn.NewStatement() + s.Require().NoError(err) + defer func() { s.NoError(stmt.Close()) }() + + schema := arrow.NewSchema([]arrow.Field{ + {Name: "id", Type: arrow.PrimitiveTypes.Int64, Nullable: true}, + }, nil) + + batchbldr := array.NewRecordBuilder(s.mem, schema) + defer batchbldr.Release() + batchbldr.Field(0).(*array.Int64Builder).Append(1) + batch := batchbldr.NewRecordBatch() + defer batch.Release() + + s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, tempTableName)) + s.Require().NoError(stmt.SetOption(adbc.OptionValueIngestTemporary, adbc.OptionValueEnabled)) + s.Require().NoError(stmt.Bind(s.ctx, batch)) + _, err = stmt.ExecuteUpdate(s.ctx) + s.Require().NoError(err) + + // Verify the temp table is queryable on this connection + s.Require().NoError(stmt.SetSqlQuery("SELECT COUNT(*) FROM `" + tempTableName + "`")) + rdr, _, err := stmt.ExecuteQuery(s.ctx) + s.Require().NoError(err) + s.Require().True(rdr.Next()) + count := rdr.RecordBatch().Column(0).(*array.Int64).Value(0) + s.EqualValues(1, count) + rdr.Release() + + // GetObjects should not error even with a temp table on the session. + // On MariaDB, the temp table would appear in the results. + // On MySQL, temp tables are not in INFORMATION_SCHEMA.TABLES. + objRdr, err := s.cnxn.GetObjects(s.ctx, adbc.ObjectDepthTables, nil, nil, &tempTableName, nil, nil) + s.Require().NoError(err) + defer objRdr.Release() + + s.Require().True(objRdr.Next()) +} + func TestMySQLIntegrationSuite(t *testing.T) { suite.Run(t, new(MySQLTestSuite)) } From 61c4ca73a8a3746245c917cfc64696bb3bb75a09 Mon Sep 17 00:00:00 2001 From: tokoko Date: Thu, 26 Mar 2026 18:05:04 +0400 Subject: [PATCH 4/4] fix: add ClearPending calls, remove temp table test --- go/connection.go | 12 +++++++++++ go/connection_getobjects.go | 4 ++++ go/mysql_test.go | 43 ------------------------------------- 3 files changed, 16 insertions(+), 43 deletions(-) diff --git a/go/connection.go b/go/connection.go index 4dbc733..0218650 100644 --- a/go/connection.go +++ b/go/connection.go @@ -37,6 +37,9 @@ const ( // GetCurrentCatalog implements driverbase.CurrentNamespacer. func (c *mysqlConnectionImpl) GetCurrentCatalog() (string, error) { + if err := c.ClearPending(); err != nil { + return "", err + } var database string err := c.Conn.QueryRowContext(context.Background(), "SELECT DATABASE()").Scan(&database) if err != nil { @@ -55,6 +58,9 @@ func (c *mysqlConnectionImpl) GetCurrentDbSchema() (string, error) { // SetCurrentCatalog implements driverbase.CurrentNamespacer. func (c *mysqlConnectionImpl) SetCurrentCatalog(catalog string) error { + if err := c.ClearPending(); err != nil { + return err + } _, err := c.Conn.ExecContext(context.Background(), "USE "+quoteIdentifier(catalog)) return err } @@ -68,6 +74,9 @@ func (c *mysqlConnectionImpl) SetCurrentDbSchema(schema string) error { } func (c *mysqlConnectionImpl) PrepareDriverInfo(ctx context.Context, infoCodes []adbc.InfoCode) error { + if err := c.ClearPending(); err != nil { + return err + } if c.version == "" { var version, comment string if err := c.Conn.QueryRowContext(ctx, "SELECT @@version, @@version_comment").Scan(&version, &comment); err != nil { @@ -80,6 +89,9 @@ func (c *mysqlConnectionImpl) PrepareDriverInfo(ctx context.Context, infoCodes [ // GetTableSchema returns the Arrow schema for a MySQL table func (c *mysqlConnectionImpl) GetTableSchema(ctx context.Context, catalog *string, dbSchema *string, tableName string) (schema *arrow.Schema, err error) { + if err := c.ClearPending(); err != nil { + return nil, err + } // Struct to capture MySQL column information type tableColumn struct { OrdinalPosition int32 diff --git a/go/connection_getobjects.go b/go/connection_getobjects.go index 725f1d9..9109344 100644 --- a/go/connection_getobjects.go +++ b/go/connection_getobjects.go @@ -34,6 +34,10 @@ import ( // condition. All filters (catalog, schema, table, column, table type) are // applied in SQL. func (c *mysqlConnectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (array.RecordReader, error) { + if err := c.ClearPending(); err != nil { + return nil, err + } + includeSchemas := depth != adbc.ObjectDepthCatalogs includeTables := depth == adbc.ObjectDepthTables || depth == adbc.ObjectDepthColumns includeColumns := depth == adbc.ObjectDepthColumns diff --git a/go/mysql_test.go b/go/mysql_test.go index f2247e4..00f4533 100644 --- a/go/mysql_test.go +++ b/go/mysql_test.go @@ -805,49 +805,6 @@ func TestMySQLTypeTests(t *testing.T) { suite.Run(t, &MySQLTests{Quirks: quirks}) } -func (s *MySQLTestSuite) TestGetObjectsTempTable() { - tempTableName := "getobjects_temp_test" - - // Create a temporary table via bulk ingest - stmt, err := s.cnxn.NewStatement() - s.Require().NoError(err) - defer func() { s.NoError(stmt.Close()) }() - - schema := arrow.NewSchema([]arrow.Field{ - {Name: "id", Type: arrow.PrimitiveTypes.Int64, Nullable: true}, - }, nil) - - batchbldr := array.NewRecordBuilder(s.mem, schema) - defer batchbldr.Release() - batchbldr.Field(0).(*array.Int64Builder).Append(1) - batch := batchbldr.NewRecordBatch() - defer batch.Release() - - s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, tempTableName)) - s.Require().NoError(stmt.SetOption(adbc.OptionValueIngestTemporary, adbc.OptionValueEnabled)) - s.Require().NoError(stmt.Bind(s.ctx, batch)) - _, err = stmt.ExecuteUpdate(s.ctx) - s.Require().NoError(err) - - // Verify the temp table is queryable on this connection - s.Require().NoError(stmt.SetSqlQuery("SELECT COUNT(*) FROM `" + tempTableName + "`")) - rdr, _, err := stmt.ExecuteQuery(s.ctx) - s.Require().NoError(err) - s.Require().True(rdr.Next()) - count := rdr.RecordBatch().Column(0).(*array.Int64).Value(0) - s.EqualValues(1, count) - rdr.Release() - - // GetObjects should not error even with a temp table on the session. - // On MariaDB, the temp table would appear in the results. - // On MySQL, temp tables are not in INFORMATION_SCHEMA.TABLES. - objRdr, err := s.cnxn.GetObjects(s.ctx, adbc.ObjectDepthTables, nil, nil, &tempTableName, nil, nil) - s.Require().NoError(err) - defer objRdr.Release() - - s.Require().True(objRdr.Next()) -} - func TestMySQLIntegrationSuite(t *testing.T) { suite.Run(t, new(MySQLTestSuite)) }