Skip to content

Commit 17a823a

Browse files
authored
Merge branch 'main' into xrfxlp/device-count-labels
2 parents abbbaff + bf0b5e4 commit 17a823a

2 files changed

Lines changed: 73 additions & 16 deletions

File tree

store-client/pkg/client/postgresql_client.go

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ import (
2020
"encoding/json"
2121
"fmt"
2222
"log/slog"
23+
"regexp"
2324
"sort"
25+
"strconv"
2426
"strings"
2527

2628
"github.com/XSAM/otelsql"
@@ -392,26 +394,26 @@ func (c *PostgreSQLClient) UpdateDocumentStatusFields(
392394
func (c *PostgreSQLClient) UpdateDocument(
393395
ctx context.Context, filter interface{}, update interface{},
394396
) (*UpdateResult, error) {
395-
// Build WHERE clause from filter
396-
whereClause, args, err := c.buildWhereClause(filter)
397+
// Build SET clause first so its placeholders own the initial parameter range.
398+
setClause, updateArgs, err := c.buildUpdateClause(update)
397399
if err != nil {
398400
return nil, err
399401
}
400402

401-
// Build SET clause from update
402-
setClause, updateArgs, err := c.buildUpdateClause(update)
403+
whereClause, filterArgs, err := c.buildWhereClause(filter)
403404
if err != nil {
404405
return nil, err
405406
}
406407

407-
// Combine args (WHERE args + SET args)
408-
args = append(args, updateArgs...)
408+
adjustedWhereClause := c.adjustParameterNumbers(whereClause, len(updateArgs))
409+
args := updateArgs
410+
args = append(args, filterArgs...)
409411

410412
// Build final query
411413
//nolint:gosec // G201: table name from config, clauses built with parameterized queries
412414
query := fmt.Sprintf(
413415
"UPDATE %s SET %s, updated_at = NOW() WHERE %s",
414-
c.table, setClause, whereClause,
416+
c.table, setClause, adjustedWhereClause,
415417
)
416418

417419
result, err := c.db.ExecContext(ctx, query, args...)
@@ -2622,17 +2624,16 @@ func (c *PostgreSQLClient) adjustParameterNumbers(clause string, offset int) str
26222624
return clause
26232625
}
26242626

2625-
// Replace parameter placeholders: $1 → $N, $2 → $N+1, etc.
2626-
// This is a simple implementation; for production, use a more robust parser
2627-
result := clause
2627+
paramRe := regexp.MustCompile(`\$(\d+)`)
26282628

2629-
for i := 20; i >= 1; i-- { // Process in reverse to avoid double replacement
2630-
oldParam := fmt.Sprintf("$%d", i)
2631-
newParam := fmt.Sprintf("$%d", i+offset)
2632-
result = strings.ReplaceAll(result, oldParam, newParam)
2633-
}
2629+
return paramRe.ReplaceAllStringFunc(clause, func(match string) string {
2630+
n, err := strconv.Atoi(match[1:])
2631+
if err != nil {
2632+
return match
2633+
}
26342634

2635-
return result
2635+
return fmt.Sprintf("$%d", n+offset)
2636+
})
26362637
}
26372638

26382639
// buildUpdateClause converts MongoDB-style update operators to PostgreSQL SET clause

store-client/pkg/client/postgresql_client_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@ package client
1616

1717
import (
1818
"context"
19+
"regexp"
1920
"strings"
2021
"testing"
22+
23+
"github.com/DATA-DOG/go-sqlmock"
2124
)
2225

2326
// TestPostgreSQLClient_BasicOperations tests basic CRUD operations
@@ -207,6 +210,59 @@ func TestBuildUpdateClause(t *testing.T) {
207210
}
208211
}
209212

213+
func TestUpdateDocumentOffsetsWherePlaceholdersAfterUpdateArgs(t *testing.T) {
214+
db, mock, err := sqlmock.New()
215+
if err != nil {
216+
t.Fatalf("failed to create sql mock: %v", err)
217+
}
218+
defer db.Close()
219+
220+
client := &PostgreSQLClient{
221+
db: db,
222+
table: "health_events",
223+
}
224+
225+
filter := map[string]interface{}{
226+
"status": "old",
227+
}
228+
update := map[string]interface{}{
229+
"$set": map[string]interface{}{
230+
"status": "new",
231+
},
232+
}
233+
234+
expectedSQL := "UPDATE health_events SET document = jsonb_set(document, '{status}', $1), updated_at = NOW() " +
235+
"WHERE document->>'status' = $2"
236+
mock.ExpectExec(regexp.QuoteMeta(expectedSQL)).
237+
WithArgs(`"new"`, "old").
238+
WillReturnResult(sqlmock.NewResult(0, 1))
239+
240+
result, err := client.UpdateDocument(context.Background(), filter, update)
241+
if err != nil {
242+
t.Fatalf("UpdateDocument returned error: %v", err)
243+
}
244+
245+
if result.MatchedCount != 1 || result.ModifiedCount != 1 {
246+
t.Fatalf("expected one matched and modified document, got matched=%d modified=%d",
247+
result.MatchedCount, result.ModifiedCount)
248+
}
249+
250+
if err := mock.ExpectationsWereMet(); err != nil {
251+
t.Fatalf("unmet sql expectations: %v", err)
252+
}
253+
}
254+
255+
func TestAdjustParameterNumbersHandlesMultiDigitPlaceholders(t *testing.T) {
256+
client := &PostgreSQLClient{}
257+
258+
clause := "document->>'field1' = $1 AND document->>'field10' = $10"
259+
expected := "document->>'field1' = $3 AND document->>'field10' = $12"
260+
261+
if got := client.adjustParameterNumbers(clause, 2); got != expected {
262+
t.Fatalf("expected %q, got %q", expected, got)
263+
}
264+
}
265+
210266
// TestAggregationPipelineConversion tests aggregation pipeline parsing
211267
func TestAggregationPipelineConversion(t *testing.T) {
212268
client := &PostgreSQLClient{table: "health_events"}

0 commit comments

Comments
 (0)