Skip to content

Commit 33b402c

Browse files
Merge pull request #87 from kaleido-io/like-escape
Fix escaping in LIKE expressions
2 parents f18f23b + 9aabb54 commit 33b402c

File tree

3 files changed

+111
-24
lines changed

3 files changed

+111
-24
lines changed

.golangci.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,6 @@ linters:
5454
- unconvert
5555
- unparam
5656
- unused
57+
issues:
58+
exclude:
59+
- "method ToSql should be ToSQL"

pkg/dbsql/filter_sql.go

Lines changed: 74 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,70 @@ import (
2626
"github.com/hyperledger/firefly-common/pkg/i18n"
2727
)
2828

29+
const escapeChar = "["
30+
31+
type LikeEscape sq.Like
32+
type NotLikeEscape sq.NotLike
33+
type ILikeEscape sq.ILike
34+
type NotILikeEscape sq.NotILike
35+
36+
// Split a map into a list of maps with a single entry each
37+
func splitMap[T ~map[string]interface{}](m T) (exprs []T) {
38+
for key, val := range m {
39+
exprs = append(exprs, T{key: val})
40+
}
41+
return exprs
42+
}
43+
44+
// Convert a list of Sqlizer operations to sq.And
45+
func toAnd[T sq.Sqlizer](ops []T) (and sq.And) {
46+
for _, op := range ops {
47+
and = append(and, op)
48+
}
49+
return and
50+
}
51+
52+
func (lk LikeEscape) ToSql() (sql string, args []interface{}, err error) {
53+
if len(lk) == 1 {
54+
sql, args, err = sq.Like(lk).ToSql()
55+
return fmt.Sprintf("%s ESCAPE '%s'", sql, escapeChar), args, err
56+
}
57+
return toAnd(splitMap(lk)).ToSql()
58+
}
59+
60+
func (lk NotLikeEscape) ToSql() (sql string, args []interface{}, err error) {
61+
if len(lk) == 1 {
62+
sql, args, err = sq.NotLike(lk).ToSql()
63+
return fmt.Sprintf("%s ESCAPE '%s'", sql, escapeChar), args, err
64+
}
65+
return toAnd(splitMap(lk)).ToSql()
66+
}
67+
68+
func (lk ILikeEscape) ToSql() (sql string, args []interface{}, err error) {
69+
if len(lk) == 1 {
70+
sql, args, err = sq.ILike(lk).ToSql()
71+
return fmt.Sprintf("%s ESCAPE '%s'", sql, escapeChar), args, err
72+
}
73+
return toAnd(splitMap(lk)).ToSql()
74+
}
75+
76+
func (lk NotILikeEscape) ToSql() (sql string, args []interface{}, err error) {
77+
if len(lk) == 1 {
78+
sql, args, err = sq.NotILike(lk).ToSql()
79+
return fmt.Sprintf("%s ESCAPE '%s'", sql, escapeChar), args, err
80+
}
81+
return toAnd(splitMap(lk)).ToSql()
82+
}
83+
84+
func (s *Database) escapeLike(value ffapi.FieldSerialization) string {
85+
v, _ := value.Value()
86+
vs, _ := v.(string)
87+
vs = strings.ReplaceAll(vs, escapeChar, escapeChar+escapeChar)
88+
vs = strings.ReplaceAll(vs, "%", escapeChar+"%")
89+
vs = strings.ReplaceAll(vs, "_", escapeChar+"_")
90+
return vs
91+
}
92+
2993
func (s *Database) FilterSelect(ctx context.Context, tableName string, sel sq.SelectBuilder, filter ffapi.Filter, typeMap map[string]string, defaultSort []interface{}, preconditions ...sq.Sqlizer) (sq.SelectBuilder, sq.Sqlizer, *ffapi.FilterInfo, error) {
3094
fi, err := filter.Finalize()
3195
if err != nil {
@@ -127,15 +191,6 @@ func (s *Database) FilterUpdate(ctx context.Context, update sq.UpdateBuilder, fi
127191
return update.Where(fop), nil
128192
}
129193

130-
func (s *Database) escapeLike(value ffapi.FieldSerialization) string {
131-
v, _ := value.Value()
132-
vs, _ := v.(string)
133-
vs = strings.ReplaceAll(vs, "[", "[[]")
134-
vs = strings.ReplaceAll(vs, "%", "[%]")
135-
vs = strings.ReplaceAll(vs, "_", "[_]")
136-
return vs
137-
}
138-
139194
func (s *Database) mapField(tableName, fieldName string, tm map[string]string) string {
140195
if fieldName == "sequence" {
141196
if tableName == "" {
@@ -158,17 +213,17 @@ func (s *Database) mapField(tableName, fieldName string, tm map[string]string) s
158213
// newILike uses ILIKE if supported by DB, otherwise the "lower" approach
159214
func (s *Database) newILike(field, value string) sq.Sqlizer {
160215
if s.features.UseILIKE {
161-
return sq.ILike{field: value}
216+
return ILikeEscape{field: value}
162217
}
163-
return sq.Like{fmt.Sprintf("lower(%s)", field): strings.ToLower(value)}
218+
return LikeEscape{fmt.Sprintf("lower(%s)", field): strings.ToLower(value)}
164219
}
165220

166221
// newNotILike uses ILIKE if supported by DB, otherwise the "lower" approach
167222
func (s *Database) newNotILike(field, value string) sq.Sqlizer {
168223
if s.features.UseILIKE {
169-
return sq.NotILike{field: value}
224+
return NotILikeEscape{field: value}
170225
}
171-
return sq.NotLike{fmt.Sprintf("lower(%s)", field): strings.ToLower(value)}
226+
return NotLikeEscape{fmt.Sprintf("lower(%s)", field): strings.ToLower(value)}
172227
}
173228

174229
func (s *Database) filterOp(ctx context.Context, tableName string, op *ffapi.FilterInfo, tm map[string]string) (sq.Sqlizer, error) {
@@ -190,25 +245,25 @@ func (s *Database) filterOp(ctx context.Context, tableName string, op *ffapi.Fil
190245
case ffapi.FilterOpNotIn:
191246
return sq.NotEq{s.mapField(tableName, op.Field, tm): op.Values}, nil
192247
case ffapi.FilterOpCont:
193-
return sq.Like{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))}, nil
248+
return LikeEscape{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))}, nil
194249
case ffapi.FilterOpNotCont:
195-
return sq.NotLike{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))}, nil
250+
return NotLikeEscape{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))}, nil
196251
case ffapi.FilterOpICont:
197252
return s.newILike(s.mapField(tableName, op.Field, tm), fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))), nil
198253
case ffapi.FilterOpNotICont:
199254
return s.newNotILike(s.mapField(tableName, op.Field, tm), fmt.Sprintf("%s%%", s.escapeLike(op.Value))), nil
200255
case ffapi.FilterOpStartsWith:
201-
return sq.Like{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%s%%", s.escapeLike(op.Value))}, nil
256+
return LikeEscape{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%s%%", s.escapeLike(op.Value))}, nil
202257
case ffapi.FilterOpNotStartsWith:
203-
return sq.NotLike{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%s%%", s.escapeLike(op.Value))}, nil
258+
return NotLikeEscape{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%s%%", s.escapeLike(op.Value))}, nil
204259
case ffapi.FilterOpIStartsWith:
205260
return s.newILike(s.mapField(tableName, op.Field, tm), fmt.Sprintf("%s%%", s.escapeLike(op.Value))), nil
206261
case ffapi.FilterOpNotIStartsWith:
207262
return s.newNotILike(s.mapField(tableName, op.Field, tm), fmt.Sprintf("%s%%", s.escapeLike(op.Value))), nil
208263
case ffapi.FilterOpEndsWith:
209-
return sq.Like{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s", s.escapeLike(op.Value))}, nil
264+
return LikeEscape{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s", s.escapeLike(op.Value))}, nil
210265
case ffapi.FilterOpNotEndsWith:
211-
return sq.NotLike{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s", s.escapeLike(op.Value))}, nil
266+
return NotLikeEscape{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s", s.escapeLike(op.Value))}, nil
212267
case ffapi.FilterOpIEndsWith:
213268
return s.newILike(s.mapField(tableName, op.Field, tm), fmt.Sprintf("%%%s", s.escapeLike(op.Value))), nil
214269
case ffapi.FilterOpNotIEndsWith:

pkg/dbsql/filter_sql_test.go

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package dbsql
1919
import (
2020
"context"
2121
"database/sql/driver"
22+
"sort"
2223
"testing"
2324

2425
"github.com/Masterminds/squirrel"
@@ -128,7 +129,7 @@ func TestSQLQueryFactoryExtraOps(t *testing.T) {
128129

129130
sqlFilter, _, err := sel.ToSql()
130131
assert.NoError(t, err)
131-
assert.Equal(t, "SELECT * FROM mytable AS mt WHERE (mt.created IN (?,?,?) AND mt.created NOT IN (?,?,?) AND mt.id = ? AND mt.id IN (?) AND mt.id IS NOT NULL AND mt.created < ? AND mt.created <= ? AND mt.created >= ? AND mt.created <> ? AND mt.seq > ? AND mt.topics LIKE ? AND mt.topics NOT LIKE ? AND mt.topics ILIKE ? AND mt.topics NOT ILIKE ?) ORDER BY mt.seq DESC", sqlFilter)
132+
assert.Equal(t, "SELECT * FROM mytable AS mt WHERE (mt.created IN (?,?,?) AND mt.created NOT IN (?,?,?) AND mt.id = ? AND mt.id IN (?) AND mt.id IS NOT NULL AND mt.created < ? AND mt.created <= ? AND mt.created >= ? AND mt.created <> ? AND mt.seq > ? AND mt.topics LIKE ? ESCAPE '[' AND mt.topics NOT LIKE ? ESCAPE '[' AND mt.topics ILIKE ? ESCAPE '[' AND mt.topics NOT ILIKE ? ESCAPE '[') ORDER BY mt.seq DESC", sqlFilter)
132133
}
133134

134135
func TestSQLQueryFactoryEvenMoreOps(t *testing.T) {
@@ -139,8 +140,8 @@ func TestSQLQueryFactoryEvenMoreOps(t *testing.T) {
139140
f := fb.And(
140141
fb.IEq("id", u),
141142
fb.NIeq("id", nil),
142-
fb.StartsWith("topics", "abc"),
143-
fb.NotStartsWith("topics", "def"),
143+
fb.StartsWith("topics", "abc_"),
144+
fb.NotStartsWith("topics", "def%"),
144145
fb.IStartsWith("topics", "ghi"),
145146
fb.NotIStartsWith("topics", "jkl"),
146147
fb.EndsWith("topics", "mno"),
@@ -154,9 +155,37 @@ func TestSQLQueryFactoryEvenMoreOps(t *testing.T) {
154155
sel, _, _, err := s.FilterSelect(context.Background(), "mt", sel, f, nil, []interface{}{"sequence"})
155156
assert.NoError(t, err)
156157

157-
sqlFilter, _, err := sel.ToSql()
158+
sqlFilter, args, err := sel.ToSql()
159+
assert.NoError(t, err)
160+
assert.Equal(t, "SELECT * FROM mytable AS mt WHERE (mt.id ILIKE ? ESCAPE '[' AND mt.id NOT ILIKE ? ESCAPE '[' AND mt.topics LIKE ? ESCAPE '[' AND mt.topics NOT LIKE ? ESCAPE '[' AND mt.topics ILIKE ? ESCAPE '[' AND mt.topics NOT ILIKE ? ESCAPE '[' AND mt.topics LIKE ? ESCAPE '[' AND mt.topics NOT LIKE ? ESCAPE '[' AND mt.topics ILIKE ? ESCAPE '[' AND mt.topics NOT ILIKE ? ESCAPE '[') ORDER BY mt.seq DESC", sqlFilter)
161+
assert.Equal(t, []interface{}{
162+
"4066abdc-8bbd-4472-9d29-1a55b467f9b9",
163+
"",
164+
"abc[_%",
165+
"def[%%",
166+
"ghi%",
167+
"jkl%",
168+
"%mno",
169+
"%pqr",
170+
"%sty",
171+
"%vwx",
172+
}, args)
173+
}
174+
175+
func TestSQLQueryFactoryEscapeLike(t *testing.T) {
176+
177+
sel := squirrel.Select("*").From("mytable AS mt").
178+
Where(LikeEscape{"a": 1, "b": 2}).
179+
Where(NotLikeEscape{"a": 1, "b": 2}).
180+
Where(ILikeEscape{"a": 1, "b": 2}).
181+
Where(NotILikeEscape{"a": 1, "b": 2})
182+
183+
sql, args, err := sel.ToSql()
158184
assert.NoError(t, err)
159-
assert.Equal(t, "SELECT * FROM mytable AS mt WHERE (mt.id ILIKE ? AND mt.id NOT ILIKE ? AND mt.topics LIKE ? AND mt.topics NOT LIKE ? AND mt.topics ILIKE ? AND mt.topics NOT ILIKE ? AND mt.topics LIKE ? AND mt.topics NOT LIKE ? AND mt.topics ILIKE ? AND mt.topics NOT ILIKE ?) ORDER BY mt.seq DESC", sqlFilter)
185+
assert.Regexp(t, `SELECT \* FROM mytable AS mt WHERE \([ab] LIKE \? ESCAPE '\[' AND [ab] LIKE \? ESCAPE '\['\) AND \([ab] NOT LIKE \? ESCAPE '\[' AND [ab] NOT LIKE \? ESCAPE '\['\) AND \([ab] ILIKE \? ESCAPE '\[' AND [ab] ILIKE \? ESCAPE '\['\) AND \([ab] NOT ILIKE \? ESCAPE '\[' AND [ab] NOT ILIKE \? ESCAPE '\['\)`, sql)
186+
assert.Len(t, args, 8)
187+
sort.Slice(args, func(i, j int) bool { return args[i].(int) < args[j].(int) })
188+
assert.Equal(t, []interface{}{1, 1, 1, 1, 2, 2, 2, 2}, args)
160189
}
161190

162191
func TestSQLQueryFactoryFinalizeFail(t *testing.T) {

0 commit comments

Comments
 (0)