Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion go/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ func (c *mysqlConnectionImpl) GetTableSchema(ctx context.Context, catalog *strin
OrdinalPosition int32
ColumnName string
DataType string
ColumnType string
IsNullable string
CharacterMaximumLength sql.NullInt64
NumericPrecision sql.NullInt64
Expand All @@ -95,6 +96,7 @@ func (c *mysqlConnectionImpl) GetTableSchema(ctx context.Context, catalog *strin
ORDINAL_POSITION,
COLUMN_NAME,
DATA_TYPE,
COLUMN_TYPE,
IS_NULLABLE,
CHARACTER_MAXIMUM_LENGTH,
NUMERIC_PRECISION,
Expand Down Expand Up @@ -132,6 +134,7 @@ func (c *mysqlConnectionImpl) GetTableSchema(ctx context.Context, catalog *strin
&col.OrdinalPosition,
&col.ColumnName,
&col.DataType,
&col.ColumnType,
&col.IsNullable,
&col.CharacterMaximumLength,
&col.NumericPrecision,
Expand Down Expand Up @@ -166,9 +169,19 @@ func (c *mysqlConnectionImpl) GetTableSchema(ctx context.Context, catalog *strin
scale = &col.NumericScale.Int64
}

// Use DATA_TYPE but append UNSIGNED if COLUMN_TYPE indicates it
// Only check integer types to avoid false positives with enum/set value lists
dbTypeName := col.DataType
switch strings.ToUpper(col.DataType) {
case "TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT":
if strings.Contains(strings.ToUpper(col.ColumnType), "UNSIGNED") {
dbTypeName = col.DataType + " UNSIGNED"
}
}

colType := sqlwrapper.ColumnType{
Name: col.ColumnName,
DatabaseTypeName: col.DataType,
DatabaseTypeName: dbTypeName,
Nullable: col.IsNullable == "YES",
Length: length,
Precision: precision,
Expand Down Expand Up @@ -328,6 +341,14 @@ func (c *mysqlConnectionImpl) arrowToMySQLType(arrowType arrow.DataType, nullabl
mysqlType = "INT"
case *arrow.Int64Type:
mysqlType = "BIGINT"
case *arrow.Uint8Type:
mysqlType = "TINYINT UNSIGNED"
case *arrow.Uint16Type:
mysqlType = "SMALLINT UNSIGNED"
case *arrow.Uint32Type:
mysqlType = "INT UNSIGNED"
case *arrow.Uint64Type:
mysqlType = "BIGINT UNSIGNED"
case *arrow.Float32Type:
mysqlType = "FLOAT"
case *arrow.Float64Type:
Expand Down
16 changes: 14 additions & 2 deletions go/connection_getobjects.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ func (c *mysqlConnectionImpl) getTablesWithColumns(ctx context.Context, catalog
ColumnName string
ColumnComment sql.NullString
DataType string
ColumnType string
IsNullable string
ColumnDefault sql.NullString
}
Expand All @@ -170,6 +171,7 @@ func (c *mysqlConnectionImpl) getTablesWithColumns(ctx context.Context, catalog
c.COLUMN_NAME,
c.COLUMN_COMMENT,
c.DATA_TYPE,
c.COLUMN_TYPE,
c.IS_NULLABLE,
c.COLUMN_DEFAULT
FROM INFORMATION_SCHEMA.TABLES t
Expand Down Expand Up @@ -208,7 +210,7 @@ func (c *mysqlConnectionImpl) getTablesWithColumns(ctx context.Context, catalog
if err := rows.Scan(
&tc.TableName, &tc.TableType,
&tc.OrdinalPosition, &tc.ColumnName, &tc.ColumnComment,
&tc.DataType, &tc.IsNullable, &tc.ColumnDefault,
&tc.DataType, &tc.ColumnType, &tc.IsNullable, &tc.ColumnDefault,
); err != nil {
return nil, c.ErrorHelper.WrapIO(err, "failed to scan table with columns")
}
Expand All @@ -226,6 +228,16 @@ func (c *mysqlConnectionImpl) getTablesWithColumns(ctx context.Context, catalog
var radix sql.NullInt16
var nullable sql.NullInt16

// Build the full type name including UNSIGNED if applicable
// Only check integer types to avoid false positives with enum/set value lists
xdbcTypeName := tc.DataType
switch strings.ToUpper(tc.DataType) {
case "TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT":
if strings.Contains(strings.ToUpper(tc.ColumnType), "UNSIGNED") {
xdbcTypeName = tc.DataType + " UNSIGNED"
}
}

// Set numeric precision radix (MySQL doesn't store this directly)
dataType := strings.ToUpper(tc.DataType)
switch dataType {
Expand Down Expand Up @@ -266,7 +278,7 @@ func (c *mysqlConnectionImpl) getTablesWithColumns(ctx context.Context, catalog
ColumnName: tc.ColumnName,
OrdinalPosition: &tc.OrdinalPosition,
Remarks: driverbase.NullStringToPtr(tc.ColumnComment),
XdbcTypeName: &tc.DataType,
XdbcTypeName: &xdbcTypeName,
XdbcNumPrecRadix: driverbase.NullInt16ToPtr(radix),
XdbcNullable: driverbase.NullInt16ToPtr(nullable),
XdbcIsNullable: &tc.IsNullable,
Expand Down
16 changes: 16 additions & 0 deletions go/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,27 @@ type mySQLTypeConverter struct {
sqlwrapper.DefaultTypeConverter
}

// normalizeUnsignedTypeName converts "UNSIGNED INT" -> "INT UNSIGNED" format
// The go-sql-driver/mysql returns "UNSIGNED X" but the default type converter expects "X UNSIGNED"
func normalizeUnsignedTypeName(typeName string) string {
if strings.HasPrefix(typeName, "UNSIGNED ") {
return strings.TrimPrefix(typeName, "UNSIGNED ") + " UNSIGNED"
}
return typeName
}

// ConvertRawColumnType implements TypeConverter with MySQL-specific enhancements
func (m *mySQLTypeConverter) ConvertRawColumnType(colType sqlwrapper.ColumnType) (arrow.DataType, bool, arrow.Metadata, error) {
typeName := strings.ToUpper(colType.DatabaseTypeName)
nullable := colType.Nullable

// Normalize "UNSIGNED X" to "X UNSIGNED" for the default type converter
// Only update DatabaseTypeName when reordering is needed, to preserve original casing in metadata
typeName = normalizeUnsignedTypeName(typeName)
if typeName != strings.ToUpper(colType.DatabaseTypeName) {
colType.DatabaseTypeName = typeName
}

switch typeName {
case "BIT":
// Handle BIT type as binary data
Expand Down
73 changes: 71 additions & 2 deletions go/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,11 @@ func (s *MySQLTests) TestSelect() {
point_col POINT,
polygon_col POLYGON,
geometry_col GEOMETRY,
bit_col BIT(8)
bit_col BIT(8),
utinyint_col TINYINT UNSIGNED,
usmallint_col SMALLINT UNSIGNED,
uint_col INT UNSIGNED,
ubigint_col BIGINT UNSIGNED
)
`))
_, err := s.stmt.ExecuteUpdate(s.ctx)
Expand All @@ -372,7 +376,8 @@ func (s *MySQLTests) TestSelect() {
ST_GeomFromText('POINT(1 2)'),
ST_GeomFromText('POLYGON((0 0, 0 3, 3 3, 3 0, 0 0))'),
ST_GeomFromText('LINESTRING(0 0, 1 1, 2 2)'),
b'10101010'
b'10101010',
200, 60000, 3000000000, 10000000000000000000
)
`))
_, err = s.stmt.ExecuteUpdate(s.ctx)
Expand Down Expand Up @@ -595,6 +600,70 @@ func (s *MySQLTests) TestSelect() {
}, nil),
expected: `[{"bitvalue": "qg=="}]`,
},
{
name: "unsigned_tinyint",
query: "SELECT utinyint_col AS value FROM test_types",
schema: arrow.NewSchema([]arrow.Field{
{
Name: "value",
Type: arrow.PrimitiveTypes.Uint8,
Nullable: true,
Metadata: arrow.MetadataFrom(map[string]string{
"sql.column_name": "value",
"sql.database_type_name": "TINYINT UNSIGNED",
}),
},
}, nil),
expected: `[{"value": 200}]`,
},
{
name: "unsigned_smallint",
query: "SELECT usmallint_col AS value FROM test_types",
schema: arrow.NewSchema([]arrow.Field{
{
Name: "value",
Type: arrow.PrimitiveTypes.Uint16,
Nullable: true,
Metadata: arrow.MetadataFrom(map[string]string{
"sql.column_name": "value",
"sql.database_type_name": "SMALLINT UNSIGNED",
}),
},
}, nil),
expected: `[{"value": 60000}]`,
},
{
name: "unsigned_int",
query: "SELECT uint_col AS value FROM test_types",
schema: arrow.NewSchema([]arrow.Field{
{
Name: "value",
Type: arrow.PrimitiveTypes.Uint32,
Nullable: true,
Metadata: arrow.MetadataFrom(map[string]string{
"sql.column_name": "value",
"sql.database_type_name": "INT UNSIGNED",
}),
},
}, nil),
expected: `[{"value": 3000000000}]`,
},
{
name: "unsigned_bigint",
query: "SELECT ubigint_col AS value FROM test_types",
schema: arrow.NewSchema([]arrow.Field{
{
Name: "value",
Type: arrow.PrimitiveTypes.Uint64,
Nullable: true,
Metadata: arrow.MetadataFrom(map[string]string{
"sql.column_name": "value",
"sql.database_type_name": "BIGINT UNSIGNED",
}),
},
}, nil),
expected: `[{"value": 10000000000000000000}]`,
},
} {
s.Run(testCase.name, func() {
s.NoError(s.stmt.SetSqlQuery(testCase.query))
Expand Down
Loading