Skip to content

Commit 06308e4

Browse files
feat(go/adbc/driver/snowflake): implement WithTransporter driver option (#2558)
Relates to #2547 --------- Co-authored-by: Felipe Vianna <[email protected]>
1 parent 28a87ea commit 06308e4

File tree

2 files changed

+76
-2
lines changed

2 files changed

+76
-2
lines changed

go/adbc/driver/snowflake/driver.go

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package snowflake
2020
import (
2121
"errors"
2222
"maps"
23+
"net/http"
2324
"runtime/debug"
2425
"strings"
2526

@@ -170,22 +171,58 @@ func quoteTblName(name string) string {
170171
return "\"" + strings.ReplaceAll(name, "\"", "\"\"") + "\""
171172
}
172173

174+
type config struct {
175+
*gosnowflake.Config
176+
}
177+
178+
// Option is a function type to set custom driver configurations.
179+
//
180+
// It is intended for configurations that cannot be provided from the standard options map,
181+
// e.g. the underlying HTTP transporter.
182+
type Option func(*config) error
183+
184+
// WithTransporter sets the custom transporter to use for the Snowflake connection.
185+
// This allows to intercept HTTP requests and responses.
186+
func WithTransporter(transporter http.RoundTripper) Option {
187+
return func(cfg *config) error {
188+
cfg.Transporter = transporter
189+
return nil
190+
}
191+
}
192+
193+
// Driver is the Snowflake driver interface.
194+
//
195+
// It extends the base adbc.Driver to provide additional options
196+
// when creating the Snowflake database.
197+
type Driver interface {
198+
adbc.Driver
199+
200+
// NewDatabaseWithOptions creates a new Snowflake database with the provided options.
201+
NewDatabaseWithOptions(map[string]string, ...Option) (adbc.Database, error)
202+
}
203+
204+
var _ Driver = (*driverImpl)(nil)
205+
173206
type driverImpl struct {
174207
driverbase.DriverImplBase
175208
}
176209

177210
// NewDriver creates a new Snowflake driver using the given Arrow allocator.
178-
func NewDriver(alloc memory.Allocator) adbc.Driver {
211+
func NewDriver(alloc memory.Allocator) Driver {
179212
info := driverbase.DefaultDriverInfo("Snowflake")
180213
if infoVendorVersion != "" {
181214
if err := info.RegisterInfoCode(adbc.InfoVendorVersion, infoVendorVersion); err != nil {
182215
panic(err)
183216
}
184217
}
185-
return driverbase.NewDriver(&driverImpl{DriverImplBase: driverbase.NewDriverImplBase(info, alloc)})
218+
return &driverImpl{DriverImplBase: driverbase.NewDriverImplBase(info, alloc)}
186219
}
187220

188221
func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error) {
222+
return d.NewDatabaseWithOptions(opts)
223+
}
224+
225+
func (d *driverImpl) NewDatabaseWithOptions(opts map[string]string, optFuncs ...Option) (adbc.Database, error) {
189226
opts = maps.Clone(opts)
190227
db := &databaseImpl{
191228
DatabaseImplBase: driverbase.NewDatabaseImplBase(&d.DriverImplBase),
@@ -195,5 +232,12 @@ func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error)
195232
return nil, err
196233
}
197234

235+
cfg := &config{Config: db.cfg}
236+
for _, opt := range optFuncs {
237+
if err := opt(cfg); err != nil {
238+
return nil, err
239+
}
240+
}
241+
198242
return driverbase.NewDatabase(db), nil
199243
}

go/adbc/driver/snowflake/driver_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import (
3030
"encoding/pem"
3131
"fmt"
3232
"math"
33+
"net/http"
3334
"os"
3435
"runtime"
3536
"strconv"
@@ -350,6 +351,35 @@ func (suite *SnowflakeTests) TearDownTest() {
350351
suite.driver = nil
351352
}
352353

354+
type customTransport struct {
355+
base *http.Transport
356+
called bool
357+
}
358+
359+
func (t *customTransport) RoundTrip(r *http.Request) (*http.Response, error) {
360+
t.called = true
361+
return t.base.RoundTrip(r)
362+
}
363+
364+
func (suite *SnowflakeTests) TestNewDatabaseWithOptions() {
365+
t := suite.T()
366+
367+
drv := suite.Quirks.SetupDriver(t).(driver.Driver)
368+
369+
t.Run("WithTransporter", func(t *testing.T) {
370+
transport := &customTransport{base: gosnowflake.SnowflakeTransport}
371+
db, err := drv.NewDatabaseWithOptions(suite.Quirks.DatabaseOptions(),
372+
driver.WithTransporter(transport))
373+
suite.NoError(err)
374+
suite.NotNil(db)
375+
cnxn, err := db.Open(suite.ctx)
376+
suite.NoError(err)
377+
suite.NoError(db.Close())
378+
suite.NoError(cnxn.Close())
379+
suite.True(transport.called)
380+
})
381+
}
382+
353383
func (suite *SnowflakeTests) TestSqlIngestTimestamp() {
354384
suite.Require().NoError(suite.Quirks.DropTable(suite.cnxn, "bulk_ingest"))
355385

0 commit comments

Comments
 (0)