Skip to content

Commit

Permalink
feat(go/adbc/driver/snowflake): implement WithTransporter driver opti…
Browse files Browse the repository at this point in the history
…on (#2558)

Relates to #2547

---------

Co-authored-by: Felipe Vianna <[email protected]>
  • Loading branch information
frbvianna and frbvianna-sap authored Feb 27, 2025
1 parent 28a87ea commit 06308e4
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 2 deletions.
48 changes: 46 additions & 2 deletions go/adbc/driver/snowflake/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package snowflake
import (
"errors"
"maps"
"net/http"
"runtime/debug"
"strings"

Expand Down Expand Up @@ -170,22 +171,58 @@ func quoteTblName(name string) string {
return "\"" + strings.ReplaceAll(name, "\"", "\"\"") + "\""
}

type config struct {
*gosnowflake.Config
}

// Option is a function type to set custom driver configurations.
//
// It is intended for configurations that cannot be provided from the standard options map,
// e.g. the underlying HTTP transporter.
type Option func(*config) error

// WithTransporter sets the custom transporter to use for the Snowflake connection.
// This allows to intercept HTTP requests and responses.
func WithTransporter(transporter http.RoundTripper) Option {
return func(cfg *config) error {
cfg.Transporter = transporter
return nil
}
}

// Driver is the Snowflake driver interface.
//
// It extends the base adbc.Driver to provide additional options
// when creating the Snowflake database.
type Driver interface {
adbc.Driver

// NewDatabaseWithOptions creates a new Snowflake database with the provided options.
NewDatabaseWithOptions(map[string]string, ...Option) (adbc.Database, error)
}

var _ Driver = (*driverImpl)(nil)

type driverImpl struct {
driverbase.DriverImplBase
}

// NewDriver creates a new Snowflake driver using the given Arrow allocator.
func NewDriver(alloc memory.Allocator) adbc.Driver {
func NewDriver(alloc memory.Allocator) Driver {
info := driverbase.DefaultDriverInfo("Snowflake")
if infoVendorVersion != "" {
if err := info.RegisterInfoCode(adbc.InfoVendorVersion, infoVendorVersion); err != nil {
panic(err)
}
}
return driverbase.NewDriver(&driverImpl{DriverImplBase: driverbase.NewDriverImplBase(info, alloc)})
return &driverImpl{DriverImplBase: driverbase.NewDriverImplBase(info, alloc)}
}

func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error) {
return d.NewDatabaseWithOptions(opts)
}

func (d *driverImpl) NewDatabaseWithOptions(opts map[string]string, optFuncs ...Option) (adbc.Database, error) {
opts = maps.Clone(opts)
db := &databaseImpl{
DatabaseImplBase: driverbase.NewDatabaseImplBase(&d.DriverImplBase),
Expand All @@ -195,5 +232,12 @@ func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error)
return nil, err
}

cfg := &config{Config: db.cfg}
for _, opt := range optFuncs {
if err := opt(cfg); err != nil {
return nil, err
}
}

return driverbase.NewDatabase(db), nil
}
30 changes: 30 additions & 0 deletions go/adbc/driver/snowflake/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"encoding/pem"
"fmt"
"math"
"net/http"
"os"
"runtime"
"strconv"
Expand Down Expand Up @@ -350,6 +351,35 @@ func (suite *SnowflakeTests) TearDownTest() {
suite.driver = nil
}

type customTransport struct {
base *http.Transport
called bool
}

func (t *customTransport) RoundTrip(r *http.Request) (*http.Response, error) {
t.called = true
return t.base.RoundTrip(r)
}

func (suite *SnowflakeTests) TestNewDatabaseWithOptions() {
t := suite.T()

drv := suite.Quirks.SetupDriver(t).(driver.Driver)

t.Run("WithTransporter", func(t *testing.T) {
transport := &customTransport{base: gosnowflake.SnowflakeTransport}
db, err := drv.NewDatabaseWithOptions(suite.Quirks.DatabaseOptions(),
driver.WithTransporter(transport))
suite.NoError(err)
suite.NotNil(db)
cnxn, err := db.Open(suite.ctx)
suite.NoError(err)
suite.NoError(db.Close())
suite.NoError(cnxn.Close())
suite.True(transport.called)
})
}

func (suite *SnowflakeTests) TestSqlIngestTimestamp() {
suite.Require().NoError(suite.Quirks.DropTable(suite.cnxn, "bulk_ingest"))

Expand Down

0 comments on commit 06308e4

Please sign in to comment.