Skip to content

Commit

Permalink
fix(go/adbc/driver/snowflake): fix GetObjects for VECTOR cols (#2564)
Browse files Browse the repository at this point in the history
Fixes #2544 by eliminating the null pointer dereference for unknown
column types
  • Loading branch information
zeroshade authored Feb 26, 2025
1 parent 470b209 commit 0cb90c1
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 1 deletion.
4 changes: 4 additions & 0 deletions go/adbc/driver/internal/shared_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,10 @@ const (
)

func ToXdbcDataType(dt arrow.DataType) (xdbcType XdbcDataType) {
if dt == nil {
return XdbcDataType_XDBC_UNKNOWN_TYPE
}

switch dt.ID() {
case arrow.EXTENSION:
return ToXdbcDataType(dt.(arrow.ExtensionType).StorageType())
Expand Down
9 changes: 8 additions & 1 deletion go/adbc/driver/snowflake/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,9 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth,
field := c.toArrowField(col)
xdbcDataType := internal.ToXdbcDataType(field.Type)

getObjectsCatalog.CatalogDbSchemas[i].DbSchemaTables[j].TableColumns[k].XdbcDataType = driverbase.Nullable(int16(field.Type.ID()))
if field.Type != nil {
getObjectsCatalog.CatalogDbSchemas[i].DbSchemaTables[j].TableColumns[k].XdbcDataType = driverbase.Nullable(int16(field.Type.ID()))
}
getObjectsCatalog.CatalogDbSchemas[i].DbSchemaTables[j].TableColumns[k].XdbcSqlDataType = driverbase.Nullable(int16(xdbcDataType))
}
}
Expand Down Expand Up @@ -475,6 +477,11 @@ func (c *connectionImpl) toArrowField(columnInfo driverbase.ColumnInfo) arrow.Fi
fallthrough
case "GEOMETRY":
field.Type = arrow.BinaryTypes.String
case "VECTOR":
// despite the fact that Snowflake *does* support returning data
// for VECTOR typed columns as Arrow FixedSizeLists, there's no way
// currently to retrieve enough metadata to construct the proper type
// for it
}

return field
Expand Down
67 changes: 67 additions & 0 deletions go/adbc/driver/snowflake/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2294,3 +2294,70 @@ ORDER BY start_time;
suite.False(rdr.Next())
suite.Require().NoError(rdr.Err())
}

func (suite *SnowflakeTests) TestGetObjectsVector() {
suite.Require().NoError(suite.Quirks.DropTable(suite.cnxn, "MYVECTORTABLE"))
suite.Require().NoError(suite.stmt.SetSqlQuery(`CREATE OR REPLACE TABLE myvectortable (
a VECTOR(float, 3), b VECTOR(float, 3))`))
_, err := suite.stmt.ExecuteUpdate(suite.ctx)
suite.Require().NoError(err)
suite.Require().NoError(suite.stmt.SetSqlQuery(`INSERT INTO myvectortable
SELECT [1.1,2.2,3]::VECTOR(FLOAT,3), [1,1,1]::VECTOR(FLOAT,3)`))
_, err = suite.stmt.ExecuteUpdate(suite.ctx)
suite.Require().NoError(err)

tableName := "MYVECTORTABLE"
rdr, err := suite.cnxn.GetObjects(suite.ctx, adbc.ObjectDepthColumns, nil, nil, &tableName, nil, nil)
suite.Require().NoError(err)
defer rdr.Release()

suite.Require().True(rdr.Next())
rec := rdr.Record()

for i := 0; int64(i) < rec.NumRows(); i++ {
// list<db_schema_schema>
dbSchemasList := rec.Column(1).(*array.List)
// db_schema_schema (struct)
dbSchemas := dbSchemasList.ListValues().(*array.Struct)
// list<table_schema>
dbSchemaTablesList := dbSchemas.Field(1).(*array.List)
// table_schema (struct)
dbSchemaTables := dbSchemaTablesList.ListValues().(*array.Struct)
// list<column_schema>
tableColumnsList := dbSchemaTables.Field(2).(*array.List)
// column_schema (struct)
tableColumns := tableColumnsList.ListValues().(*array.Struct)

start, end := dbSchemasList.ValueOffsets(i)
for j := start; j < end; j++ {
schemaName := dbSchemas.Field(0).(*array.String).Value(int(j))
if !strings.EqualFold(schemaName, suite.Quirks.DBSchema()) {
continue
}
tblStart, tblEnd := dbSchemaTablesList.ValueOffsets(int(j))
for k := tblStart; k < tblEnd; k++ {
tblName := dbSchemaTables.Field(0).(*array.String).Value(int(k))
if !strings.EqualFold(tblName, tableName) {
continue
}

colStart, colEnd := tableColumnsList.ValueOffsets(int(k))
suite.EqualValues(2, colEnd-colStart)

for l := colStart; l < colEnd; l++ {
colName := tableColumns.Field(0).(*array.String).Value(int(l))
ordinalPos := tableColumns.Field(1).(*array.Int32).Value(int(l))
typeName := tableColumns.Field(4).(*array.String).Value(int(l))
switch ordinalPos {
case 1:
suite.Equal("A", colName)
case 2:
suite.Equal("B", colName)
}

suite.Equal("VECTOR", typeName)
}
}
}
}
}

0 comments on commit 0cb90c1

Please sign in to comment.