diff --git a/driver/driver.go b/driver/driver.go index 19f5001c..f7dc6fe2 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -63,7 +63,7 @@ var driverName = "sqlite3" func init() { if driverName != "" { - sql.Register(driverName, sqlite{}) + sql.Register(driverName, &SQLite{}) } } @@ -73,29 +73,37 @@ func init() { // The conn can be used to execute queries, register functions, etc. // Any error return closes the conn and passes the error to [database/sql]. func Open(dataSourceName string, init func(*sqlite3.Conn) error) (*sql.DB, error) { - c, err := newConnector(dataSourceName, init) + c, err := (&SQLite{Init: init}).OpenConnector(dataSourceName) if err != nil { return nil, err } return sql.OpenDB(c), nil } -type sqlite struct{} +type SQLite struct { -func (sqlite) Open(name string) (driver.Conn, error) { - c, err := newConnector(name, nil) + // The init function is called by the driver on new connections. + // The conn can be used to execute queries, register functions, etc. + // Any error return closes the conn and passes the error to [database/sql]. + Init func(*sqlite3.Conn) error +} + +// Open: implements [database/sql.Driver]. +func (d *SQLite) Open(name string) (driver.Conn, error) { + c, err := d.newConnector(name) if err != nil { return nil, err } return c.Connect(context.Background()) } -func (sqlite) OpenConnector(name string) (driver.Connector, error) { - return newConnector(name, nil) +// OpenConnector: implements [database/sql.DriverContext]. +func (d *SQLite) OpenConnector(name string) (driver.Connector, error) { + return d.newConnector(name) } -func newConnector(name string, init func(*sqlite3.Conn) error) (*connector, error) { - c := connector{name: name, init: init} +func (d *SQLite) newConnector(name string) (*connector, error) { + c := connector{driver: d, name: name} var txlock, timefmt string if strings.HasPrefix(name, "file:") { @@ -137,7 +145,7 @@ func newConnector(name string, init func(*sqlite3.Conn) error) (*connector, erro } type connector struct { - init func(*sqlite3.Conn) error + driver *SQLite name string txBegin string tmRead sqlite3.TimeFormat @@ -146,7 +154,7 @@ type connector struct { } func (n *connector) Driver() driver.Driver { - return sqlite{} + return n.driver } func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) { @@ -175,13 +183,13 @@ func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) { return nil, err } } - if n.init != nil { - err = n.init(c.Conn) + if n.driver.Init != nil { + err = n.driver.Init(c.Conn) if err != nil { return nil, err } } - if n.pragmas || n.init != nil { + if n.pragmas || n.driver.Init != nil { s, _, err := c.Conn.Prepare(`PRAGMA query_only`) if err != nil { return nil, err