Skip to content

Commit ebab2bf

Browse files
authored
fix naïve replacement for stmt args (#30)
1 parent c564157 commit ebab2bf

File tree

2 files changed

+45
-7
lines changed

2 files changed

+45
-7
lines changed

statement.go

+6-7
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"database/sql/driver"
66
"fmt"
7+
"regexp"
78
"strings"
89
"time"
910

@@ -92,19 +93,17 @@ func template(query string) string {
9293
}
9394

9495
func statement(tmpl string, args []driver.NamedValue) string {
95-
var replacements []string
96+
stmt := tmpl
9697
for _, arg := range args {
97-
var placeholder string
98+
var re *regexp.Regexp
9899
if arg.Name != "" {
99-
placeholder = fmt.Sprintf("@%s", arg.Name)
100+
re = regexp.MustCompile(fmt.Sprintf("@%s%s", arg.Name, `\b`))
100101
} else {
101-
placeholder = fmt.Sprintf("@p%d", arg.Ordinal)
102+
re = regexp.MustCompile(fmt.Sprintf("@p%d%s", arg.Ordinal, `\b`))
102103
}
103104
val := fmt.Sprintf("%v", arg.Value)
104-
replacements = append(replacements, placeholder, val)
105+
stmt = re.ReplaceAllString(stmt, val)
105106
}
106-
r := strings.NewReplacer(replacements...)
107-
stmt := r.Replace(tmpl)
108107
return stmt
109108
}
110109

statement_test.go

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package impala
2+
3+
import (
4+
"database/sql/driver"
5+
"testing"
6+
)
7+
8+
func TestStatement(t *testing.T) {
9+
tests := []struct {
10+
stmt string
11+
args []driver.NamedValue
12+
target string
13+
}{
14+
{
15+
stmt: "@p1 p1",
16+
args: []driver.NamedValue{
17+
driver.NamedValue{Ordinal: 1, Value: "val_1"},
18+
},
19+
target: "val_1 p1",
20+
},
21+
{
22+
stmt: "@p1 @p10 @p11 @named @named1 @p1",
23+
args: []driver.NamedValue{
24+
driver.NamedValue{Ordinal: 1, Value: "val_1"},
25+
driver.NamedValue{Ordinal: 10, Name: "named", Value: "val_named"},
26+
driver.NamedValue{Ordinal: 11, Value: "val_11"},
27+
},
28+
target: "val_1 @p10 val_11 val_named @named1 val_1",
29+
},
30+
}
31+
32+
for _, tt := range tests {
33+
result := statement(tt.stmt, tt.args)
34+
35+
if result != tt.target {
36+
t.Fatalf("mismatch for statement: %q\n\ttarget: %q\n\tresult: %q", tt.stmt, tt.target, result)
37+
}
38+
}
39+
}

0 commit comments

Comments
 (0)