Skip to content

Commit 59fcbec

Browse files
refactor(sources): deduplicate identical RunSQL patterns across pgx abstraction drivers
1 parent b805a1f commit 59fcbec

File tree

6 files changed

+67
-124
lines changed

6 files changed

+67
-124
lines changed

internal/sources/alloydbpg/alloydb_pg.go

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,7 @@ import (
2323
"cloud.google.com/go/alloydbconn"
2424
"github.com/goccy/go-yaml"
2525
"github.com/googleapis/genai-toolbox/internal/sources"
26-
"github.com/googleapis/genai-toolbox/internal/sqlcommenter"
2726
"github.com/googleapis/genai-toolbox/internal/util"
28-
"github.com/googleapis/genai-toolbox/internal/util/orderedmap"
2927
"github.com/jackc/pgx/v5/pgxpool"
3028
"go.opentelemetry.io/otel/trace"
3129
)
@@ -104,35 +102,7 @@ func (s *Source) PostgresPool() *pgxpool.Pool {
104102
}
105103

106104
func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) {
107-
// Inject the database driver into the context for SQLCommenter
108-
ctx = sqlcommenter.WithDBDriver(ctx, "pgx")
109-
// Decorate the statement with SQLCommenter metadata from the context
110-
statement = sqlcommenter.AppendComment(ctx, statement)
111-
112-
results, err := s.Pool.Query(ctx, statement, params...)
113-
if err != nil {
114-
return nil, fmt.Errorf("unable to execute query: %w", err)
115-
}
116-
defer results.Close()
117-
118-
fields := results.FieldDescriptions()
119-
var out []any
120-
for results.Next() {
121-
v, err := results.Values()
122-
if err != nil {
123-
return nil, fmt.Errorf("unable to parse row: %w", err)
124-
}
125-
row := orderedmap.Row{}
126-
for i, f := range fields {
127-
row.Add(f.Name, v[i])
128-
}
129-
out = append(out, row)
130-
}
131-
// this will catch actual query execution errors
132-
if err := results.Err(); err != nil {
133-
return nil, fmt.Errorf("unable to execute query: %w", err)
134-
}
135-
return out, nil
105+
return sources.RunSQLWithPgxQueryer(ctx, s.PostgresPool(), statement, params, "pgx")
136106
}
137107

138108
func getOpts(ipType, userAgent string, useIAM bool) ([]alloydbconn.Option, error) {

internal/sources/cloudsqlpg/cloud_sql_pg.go

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@ import (
2222
"cloud.google.com/go/cloudsqlconn"
2323
"github.com/goccy/go-yaml"
2424
"github.com/googleapis/genai-toolbox/internal/sources"
25-
"github.com/googleapis/genai-toolbox/internal/sqlcommenter"
2625
"github.com/googleapis/genai-toolbox/internal/util"
27-
"github.com/googleapis/genai-toolbox/internal/util/orderedmap"
2826
"github.com/jackc/pgx/v5/pgxpool"
2927
"go.opentelemetry.io/otel/trace"
3028
)
@@ -102,35 +100,7 @@ func (s *Source) PostgresPool() *pgxpool.Pool {
102100
}
103101

104102
func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) {
105-
// Inject the database driver into the context for SQLCommenter
106-
ctx = sqlcommenter.WithDBDriver(ctx, "pgx")
107-
// Decorate the statement with SQLCommenter metadata from the context
108-
statement = sqlcommenter.AppendComment(ctx, statement)
109-
110-
results, err := s.PostgresPool().Query(ctx, statement, params...)
111-
if err != nil {
112-
return nil, fmt.Errorf("unable to execute query: %w", err)
113-
}
114-
defer results.Close()
115-
116-
fields := results.FieldDescriptions()
117-
var out []any
118-
for results.Next() {
119-
values, err := results.Values()
120-
if err != nil {
121-
return nil, fmt.Errorf("unable to parse row: %w", err)
122-
}
123-
row := orderedmap.Row{}
124-
for i, f := range fields {
125-
row.Add(f.Name, values[i])
126-
}
127-
out = append(out, row)
128-
}
129-
// this will catch actual query execution errors
130-
if err := results.Err(); err != nil {
131-
return nil, fmt.Errorf("unable to execute query: %w", err)
132-
}
133-
return out, nil
103+
return sources.RunSQLWithPgxQueryer(ctx, s.PostgresPool(), statement, params, "pgx")
134104
}
135105

136106
func getConnectionConfig(ctx context.Context, user, pass, dbname string) (string, bool, error) {

internal/sources/cockroachdb/cockroachdb.go

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,7 @@ import (
2828
"github.com/cockroachdb/cockroach-go/v2/crdb/crdbpgxv5"
2929
"github.com/goccy/go-yaml"
3030
"github.com/googleapis/genai-toolbox/internal/sources"
31-
"github.com/googleapis/genai-toolbox/internal/sqlcommenter"
3231
"github.com/googleapis/genai-toolbox/internal/util"
33-
"github.com/googleapis/genai-toolbox/internal/util/orderedmap"
3432
"github.com/jackc/pgx/v5"
3533
"github.com/jackc/pgx/v5/pgxpool"
3634
"go.opentelemetry.io/otel/trace"
@@ -147,34 +145,7 @@ func (s *Source) ExecuteTxWithRetry(ctx context.Context, fn func(pgx.Tx) error)
147145
}
148146

149147
func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) {
150-
// Inject the database driver into the context for SQLCommenter
151-
ctx = sqlcommenter.WithDBDriver(ctx, "pgx")
152-
// Decorate the statement with SQLCommenter metadata from the context
153-
statement = sqlcommenter.AppendComment(ctx, statement)
154-
155-
rows, err := s.Query(ctx, statement, params...)
156-
if err != nil {
157-
return nil, err
158-
}
159-
defer rows.Close()
160-
161-
fields := rows.FieldDescriptions()
162-
var out []any
163-
for rows.Next() {
164-
v, err := rows.Values()
165-
if err != nil {
166-
return nil, fmt.Errorf("unable to parse row: %w", err)
167-
}
168-
row := orderedmap.Row{}
169-
for i, f := range fields {
170-
row.Add(f.Name, v[i])
171-
}
172-
out = append(out, row)
173-
}
174-
if err := rows.Err(); err != nil {
175-
return nil, fmt.Errorf("unable to execute query: %w", err)
176-
}
177-
return out, nil
148+
return sources.RunSQLWithPgxQueryer(ctx, s, statement, params, "pgx")
178149
}
179150

180151
// Query executes a query using the connection pool with MCP security enforcement.

internal/sources/pgx_util.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package sources
16+
17+
import (
18+
"context"
19+
"fmt"
20+
21+
"github.com/googleapis/genai-toolbox/internal/sqlcommenter"
22+
"github.com/googleapis/genai-toolbox/internal/util/orderedmap"
23+
"github.com/jackc/pgx/v5"
24+
)
25+
26+
// PgxQueryer abstracts connection pools and wrapper classes that can execute native Pgx queries.
27+
type PgxQueryer interface {
28+
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
29+
}
30+
31+
// RunSQLWithPgxQueryer executes a standard SQL statement with SQLCommenter telemetry
32+
// across any driver natively supporting the pgx execution interface.
33+
func RunSQLWithPgxQueryer(ctx context.Context, queryer PgxQueryer, statement string, params []any, driver string) (any, error) {
34+
// Inject the database driver into the context for SQLCommenter
35+
ctx = sqlcommenter.WithDBDriver(ctx, driver)
36+
// Decorate the statement with SQLCommenter metadata from the context
37+
statement = sqlcommenter.AppendComment(ctx, statement)
38+
39+
results, err := queryer.Query(ctx, statement, params...)
40+
if err != nil {
41+
return nil, fmt.Errorf("unable to execute query: %w", err)
42+
}
43+
defer results.Close()
44+
45+
fields := results.FieldDescriptions()
46+
var out []any
47+
for results.Next() {
48+
values, err := results.Values()
49+
if err != nil {
50+
return nil, fmt.Errorf("unable to parse row: %w", err)
51+
}
52+
row := orderedmap.Row{}
53+
for i, f := range fields {
54+
row.Add(f.Name, values[i])
55+
}
56+
out = append(out, row)
57+
}
58+
// this will catch actual query execution errors
59+
if err := results.Err(); err != nil {
60+
return nil, fmt.Errorf("unable to execute query: %w", err)
61+
}
62+
return out, nil
63+
}

internal/sources/postgres/postgres.go

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@ import (
2222

2323
"github.com/goccy/go-yaml"
2424
"github.com/googleapis/genai-toolbox/internal/sources"
25-
"github.com/googleapis/genai-toolbox/internal/sqlcommenter"
2625
"github.com/googleapis/genai-toolbox/internal/util"
27-
"github.com/googleapis/genai-toolbox/internal/util/orderedmap"
2826
"github.com/jackc/pgx/v5"
2927
"github.com/jackc/pgx/v5/pgxpool"
3028
"go.opentelemetry.io/otel/trace"
@@ -103,35 +101,7 @@ func (s *Source) PostgresPool() *pgxpool.Pool {
103101
}
104102

105103
func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) {
106-
// Inject the database driver into the context for SQLCommenter
107-
ctx = sqlcommenter.WithDBDriver(ctx, "pgx")
108-
// Decorate the statement with SQLCommenter metadata from the context
109-
statement = sqlcommenter.AppendComment(ctx, statement)
110-
111-
results, err := s.PostgresPool().Query(ctx, statement, params...)
112-
if err != nil {
113-
return nil, fmt.Errorf("unable to execute query: %w", err)
114-
}
115-
defer results.Close()
116-
117-
fields := results.FieldDescriptions()
118-
var out []any
119-
for results.Next() {
120-
values, err := results.Values()
121-
if err != nil {
122-
return nil, fmt.Errorf("unable to parse row: %w", err)
123-
}
124-
row := orderedmap.Row{}
125-
for i, f := range fields {
126-
row.Add(f.Name, values[i])
127-
}
128-
out = append(out, row)
129-
}
130-
// this will catch actual query execution errors
131-
if err := results.Err(); err != nil {
132-
return nil, fmt.Errorf("unable to execute query: %w", err)
133-
}
134-
return out, nil
104+
return sources.RunSQLWithPgxQueryer(ctx, s.PostgresPool(), statement, params, "pgx")
135105
}
136106

137107
func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string, queryParams map[string]string, queryExecMode string) (*pgxpool.Pool, error) {

internal/sources/yugabytedb/yugabytedb.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,6 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
126126
out = append(out, vMap)
127127
}
128128

129-
// this will catch actual query execution errors
130129
if err := results.Err(); err != nil {
131130
return nil, fmt.Errorf("unable to execute query: %w", err)
132131
}

0 commit comments

Comments
 (0)