Skip to content

Commit 7bbb0c2

Browse files
authored
Helper fn to wrap db with standardized query logger (#573)
* Helper fn to wrap db with standardized query logger * Simplify/consolidate QueryLogger wrapping * Address PR comments * Address comments
1 parent 7281540 commit 7bbb0c2

7 files changed

Lines changed: 102 additions & 29 deletions

File tree

cmds/server_cmd.go

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import (
2222
"github.com/interline-io/transitland-lib/server/meters"
2323
localmeter "github.com/interline-io/transitland-lib/server/meters/local"
2424
"github.com/interline-io/transitland-lib/tldb"
25-
"github.com/interline-io/transitland-lib/tldb/querylogger"
2625

2726
"github.com/interline-io/transitland-lib/server/finders/actions"
2827
"github.com/interline-io/transitland-lib/server/finders/dbfinder"
@@ -116,15 +115,13 @@ func (cmd *ServerCommand) Parse(args []string) error {
116115

117116
func (cmd *ServerCommand) Run(ctx context.Context) error {
118117
// Open database
119-
var db tldb.Ext
120118
dbx, err := dbutil.OpenDB(cmd.DBURL)
121119
if err != nil {
122120
return err
123121
}
124-
db = dbx
125-
if log.Logger.GetLevel() == zerolog.TraceLevel {
126-
db = &querylogger.QueryLogger{Ext: dbx, Trace: true, LongQueryDuration: time.Duration(cmd.LongQueryDuration) * time.Millisecond}
127-
}
122+
trace := log.Logger.GetLevel() == zerolog.TraceLevel
123+
longQueryDuration := time.Duration(cmd.LongQueryDuration) * time.Millisecond
124+
var db tldb.Ext = dbutil.WithQueryLogger(dbx, trace, longQueryDuration)
128125

129126
// Open redis
130127
var redisClient *redis.Client

internal/testconfig/testconfig.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ import (
1717
"github.com/interline-io/transitland-lib/server/finders/rtfinder"
1818
"github.com/interline-io/transitland-lib/server/jobs"
1919
localjobs "github.com/interline-io/transitland-lib/server/jobs/local"
20+
"github.com/interline-io/transitland-lib/server/dbutil"
2021
"github.com/interline-io/transitland-lib/server/model"
2122
"github.com/interline-io/transitland-lib/server/testutil"
2223
"github.com/interline-io/transitland-lib/testdata"
2324
"github.com/interline-io/transitland-lib/tldb"
24-
"github.com/interline-io/transitland-lib/tldb/querylogger"
2525
"google.golang.org/protobuf/proto"
2626
)
2727

@@ -40,7 +40,7 @@ type Options struct {
4040
func Config(t testing.TB, opts Options) model.Config {
4141
ctx := context.Background()
4242
db := testutil.MustOpenTestDB(t)
43-
return newTestConfig(t, ctx, &querylogger.QueryLogger{Ext: db}, opts)
43+
return newTestConfig(t, ctx, db, opts)
4444
}
4545

4646
func ConfigTx(t testing.TB, opts Options, cb func(model.Config) error) {
@@ -51,7 +51,7 @@ func ConfigTx(t testing.TB, opts Options, cb func(model.Config) error) {
5151
defer tx.Rollback()
5252

5353
// Get finders
54-
testEnv := newTestConfig(t, ctx, &querylogger.QueryLogger{Ext: tx}, opts)
54+
testEnv := newTestConfig(t, ctx, tx, opts)
5555

5656
// Commit or rollback
5757
if err := cb(testEnv); err != nil {
@@ -83,6 +83,8 @@ func DefaultRTJson() []RTJsonFile {
8383
}
8484

8585
func newTestConfig(t testing.TB, ctx context.Context, db tldb.Ext, opts Options) model.Config {
86+
db = dbutil.WithQueryLogger(db, false, 0)
87+
8688
// Default time
8789
if opts.WhenUtc == "" {
8890
opts.WhenUtc = "2022-09-01T00:00:00Z"

server/dbutil/db.go

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,20 @@ package dbutil
33
import (
44
"context"
55
"database/sql"
6-
"regexp"
76
"strings"
87
"time"
98

109
sq "github.com/irees/squirrel"
1110

1211
"github.com/interline-io/log"
12+
"github.com/interline-io/transitland-lib/internal/tags"
13+
"github.com/interline-io/transitland-lib/tldb/querylogger"
1314
"github.com/jackc/pgx/v5/pgxpool"
1415
"github.com/jackc/pgx/v5/stdlib"
1516
"github.com/jmoiron/sqlx"
1617
"github.com/jmoiron/sqlx/reflectx"
1718
)
1819

19-
var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)")
20-
var matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])")
21-
22-
func toSnakeCase(str string) string {
23-
snake := matchFirstCap.ReplaceAllString(str, "${1}_${2}")
24-
snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}")
25-
return strings.ToLower(snake)
26-
}
27-
2820
// ConfigureDB sets up common database configuration
2921
func ConfigureDB(sqlDb *sql.DB) (*sqlx.DB, error) {
3022
db := sqlx.NewDb(sqlDb, "pgx")
@@ -35,7 +27,7 @@ func ConfigureDB(sqlDb *sql.DB) (*sqlx.DB, error) {
3527
log.Error().Err(err).Msgf("could not connect to database")
3628
return nil, err
3729
}
38-
db.Mapper = reflectx.NewMapperFunc("db", toSnakeCase)
30+
db.Mapper = reflectx.NewMapperFunc("db", tags.ToSnakeCase)
3931
return db.Unsafe(), nil
4032
}
4133

@@ -72,6 +64,18 @@ func OpenDB(url string) (*sqlx.DB, error) {
7264
return ConfigureDB(db.DB)
7365
}
7466

67+
// WithQueryLogger wraps a database connection with a QueryLogger.
68+
// If the connection is already a QueryLogger, its settings are updated in place
69+
// and it is returned as-is (no additional wrapping layer).
70+
func WithQueryLogger(db querylogger.Ext, trace bool, longQueryDuration time.Duration) *querylogger.QueryLogger {
71+
if ql, ok := db.(*querylogger.QueryLogger); ok {
72+
ql.Trace = trace
73+
ql.LongQueryDuration = longQueryDuration
74+
return ql
75+
}
76+
return &querylogger.QueryLogger{Ext: db, Trace: trace, LongQueryDuration: longQueryDuration}
77+
}
78+
7579
// Select runs a query and reads results into dest.
7680
func Select(ctx context.Context, db sqlx.Ext, q sq.SelectBuilder, dest interface{}) error {
7781
q = q.PlaceholderFormat(sq.Dollar)

server/finders/dbfinder/finder_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@ import (
44
"context"
55
"testing"
66

7+
"github.com/interline-io/transitland-lib/server/dbutil"
78
"github.com/interline-io/transitland-lib/server/testutil"
8-
"github.com/interline-io/transitland-lib/tldb/querylogger"
99
"github.com/stretchr/testify/assert"
1010
)
1111

1212
func TestFinder_FindFeedVersionServiceWindow(t *testing.T) {
1313
ctx := context.Background()
1414
db := testutil.MustOpenTestDB(t)
15-
dbf := NewFinder(&querylogger.QueryLogger{Ext: db})
15+
dbf := NewFinder(dbutil.WithQueryLogger(db, false, 0))
1616
testFinder := dbf
1717

1818
fvm := map[string]int{}

tldb/postgres/postgres.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ func (adapter *PostgresAdapter) OpenDB() (*sqlx.DB, error) {
6767

6868
// Close the adapter.
6969
func (adapter *PostgresAdapter) Close() error {
70+
adapter.db = nil
7071
return nil
7172
}
7273

@@ -110,7 +111,12 @@ func (adapter *PostgresAdapter) Tx(cb func(Adapter) error) error {
110111
if err != nil {
111112
return err
112113
}
113-
if err := cb(&PostgresAdapter{DBURL: adapter.DBURL, db: &QueryLogger{Ext: tx}}); err != nil {
114+
// Re-wrap the transaction with QueryLogger only if the original connection was wrapped
115+
var txdb Ext = tx
116+
if ql, ok := adapter.db.(*QueryLogger); ok {
117+
txdb = &QueryLogger{Ext: tx, Trace: ql.Trace, LongQueryDuration: ql.LongQueryDuration}
118+
}
119+
if err := cb(&PostgresAdapter{DBURL: adapter.DBURL, db: txdb}); err != nil {
114120
if commit {
115121
if errTx := tx.Rollback(); errTx != nil {
116122
return errTx

tldb/sqlite/sqlite.go

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ type SQLiteAdapter struct {
4545

4646
// Open the database.
4747
func (adapter *SQLiteAdapter) Open() error {
48+
if adapter.db != nil {
49+
return nil
50+
}
4851
dbname := strings.Split(adapter.DBURL, "://")
4952
if len(dbname) != 2 {
5053
return causes.NewSourceUnreadableError("no database filename provided", nil)
@@ -61,7 +64,9 @@ func (adapter *SQLiteAdapter) Open() error {
6164
// Close the database.
6265
func (adapter *SQLiteAdapter) Close() error {
6366
if a, ok := adapter.db.(tldb.CanClose); ok {
64-
return a.Close()
67+
err := a.Close()
68+
adapter.db = nil
69+
return err
6570
}
6671
return nil
6772
}
@@ -100,19 +105,41 @@ func (adapter *SQLiteAdapter) Sqrl() sq.StatementBuilderType {
100105
func (adapter *SQLiteAdapter) Tx(cb func(Adapter) error) error {
101106
var err error
102107
var tx *sqlx.Tx
103-
if a, ok := adapter.db.(tldb.CanBeginx); ok {
108+
// Special check for wrapped connections
109+
commit := false
110+
switch a := adapter.db.(type) {
111+
case *sqlx.Tx:
112+
tx = a
113+
case *QueryLogger:
114+
if b, ok := a.Ext.(*sqlx.Tx); ok {
115+
tx = b
116+
}
117+
}
118+
// If we aren't already in a transaction, begin one, and commit at end
119+
if a, ok := adapter.db.(tldb.CanBeginx); tx == nil && ok {
104120
tx, err = a.Beginx()
121+
commit = true
105122
}
106123
if err != nil {
107124
return err
108125
}
109-
if errTx := cb(&SQLiteAdapter{DBURL: adapter.DBURL, db: &QueryLogger{Ext: tx}}); errTx != nil {
110-
if err3 := tx.Rollback(); err3 != nil {
111-
return err3
126+
// Re-wrap the transaction with QueryLogger only if the original connection was wrapped
127+
var txdb Ext = tx
128+
if ql, ok := adapter.db.(*QueryLogger); ok {
129+
txdb = &QueryLogger{Ext: tx, Trace: ql.Trace, LongQueryDuration: ql.LongQueryDuration}
130+
}
131+
if errTx := cb(&SQLiteAdapter{DBURL: adapter.DBURL, db: txdb}); errTx != nil {
132+
if commit {
133+
if err3 := tx.Rollback(); err3 != nil {
134+
return err3
135+
}
112136
}
113137
return errTx
114138
}
115-
return tx.Commit()
139+
if commit {
140+
return tx.Commit()
141+
}
142+
return nil
116143
}
117144

118145
// TableExists returns true if the requested table exists

tldb/sqlite/sqlite_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package sqlite
55

66
import (
77
"context"
8+
"errors"
89
"testing"
910

1011
"github.com/interline-io/transitland-lib/tldb/tldbtest"
@@ -14,3 +15,39 @@ func TestSQLiteAdapter(t *testing.T) {
1415
adapter := &SQLiteAdapter{DBURL: "sqlite3://:memory:"}
1516
tldbtest.AdapterTest(context.TODO(), t, adapter)
1617
}
18+
19+
func TestSQLiteAdapter_NestedTx(t *testing.T) {
20+
adapter := &SQLiteAdapter{DBURL: "sqlite3://:memory:"}
21+
if err := adapter.Open(); err != nil {
22+
t.Fatal(err)
23+
}
24+
if err := adapter.Create(); err != nil {
25+
t.Fatal(err)
26+
}
27+
// Outer Tx should commit; inner Tx should reuse the same transaction
28+
outerCalled := false
29+
innerCalled := false
30+
err := adapter.Tx(func(outer Adapter) error {
31+
outerCalled = true
32+
return outer.Tx(func(inner Adapter) error {
33+
innerCalled = true
34+
return nil
35+
})
36+
})
37+
if err != nil {
38+
t.Fatal(err)
39+
}
40+
if !outerCalled || !innerCalled {
41+
t.Fatal("expected both outer and inner callbacks to be called")
42+
}
43+
44+
// Inner error should propagate without double-rollback
45+
err = adapter.Tx(func(outer Adapter) error {
46+
return outer.Tx(func(inner Adapter) error {
47+
return errors.New("inner error")
48+
})
49+
})
50+
if err == nil || err.Error() != "inner error" {
51+
t.Fatalf("expected inner error, got: %v", err)
52+
}
53+
}

0 commit comments

Comments
 (0)