Skip to content

Commit 068cd59

Browse files
authored
fix: solve some compatibility issues with pgcli (#355)
* fix: replace 'xxx=ANY(yyy)' with 'my_list_contains(yyy, xxx) (#354) * fix: support pg_catalog.pg_get_expr with 3 params (#354) * fix: wrap columns 'proallargtypes' and 'proargtypes' to split string into string array (#354)
1 parent 34b0213 commit 068cd59

File tree

3 files changed

+106
-16
lines changed

3 files changed

+106
-16
lines changed

catalog/internal_macro.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@ type MacroDefinition struct {
77
DDL string
88
}
99

10+
var (
11+
SchemaNameSYS string = "__sys__"
12+
MacroNameMyListContains string = "my_list_contains"
13+
14+
MacroNameMySplitListStr string = "my_split_list_str"
15+
)
16+
1017
type InternalMacro struct {
1118
Schema string
1219
Name string
@@ -55,4 +62,48 @@ var InternalMacros = []InternalMacro{
5562
},
5663
},
5764
},
65+
{
66+
Schema: "pg_catalog",
67+
Name: "pg_get_expr",
68+
IsTableMacro: false,
69+
Definitions: []MacroDefinition{
70+
{
71+
Params: []string{"pg_node_tree", "relation_oid"},
72+
// Do nothing currently
73+
DDL: `pg_catalog.pg_get_expr(pg_node_tree, relation_oid)`,
74+
},
75+
{
76+
Params: []string{"pg_node_tree", "relation_oid", "pretty_bool"},
77+
// Do nothing currently
78+
DDL: `pg_catalog.pg_get_expr(pg_node_tree, relation_oid)`,
79+
},
80+
},
81+
},
82+
{
83+
Schema: SchemaNameSYS,
84+
Name: MacroNameMyListContains,
85+
IsTableMacro: false,
86+
Definitions: []MacroDefinition{
87+
{
88+
Params: []string{"l", "v"},
89+
DDL: `CASE
90+
WHEN typeof(l) = 'VARCHAR' THEN
91+
list_contains(regexp_split_to_array(l::VARCHAR, '[{},\s]+'), v)
92+
ELSE
93+
list_contains(l::text[], v)
94+
END`,
95+
},
96+
},
97+
},
98+
{
99+
Schema: SchemaNameSYS,
100+
Name: MacroNameMySplitListStr,
101+
IsTableMacro: false,
102+
Definitions: []MacroDefinition{
103+
{
104+
Params: []string{"l"},
105+
DDL: `regexp_split_to_array(l::VARCHAR, '[{},\s]+')`,
106+
},
107+
},
108+
},
58109
}

pgserver/in_place_handler.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,11 +237,24 @@ var selectionConversions = []SelectionConversion{
237237
needConvert: func(query *ConvertedStatement) bool {
238238
sqlStr := RemoveComments(query.String)
239239
// TODO(sean): Evaluate the conditions by iterating over the AST.
240-
return getTypeCastRegex().MatchString(sqlStr)
240+
return getSimpleStringMatchingRegex().MatchString(sqlStr)
241241
},
242242
doConvert: func(h *ConnectionHandler, query *ConvertedStatement) error {
243243
sqlStr := RemoveComments(query.String)
244-
sqlStr = ConvertTypeCast(sqlStr)
244+
sqlStr = SimpleStrReplacement(sqlStr)
245+
query.String = sqlStr
246+
return nil
247+
},
248+
},
249+
{
250+
needConvert: func(query *ConvertedStatement) bool {
251+
sqlStr := RemoveComments(query.String)
252+
// TODO: Evaluate the conditions by iterating over the AST.
253+
return getPgAnyOpRegex().MatchString(sqlStr)
254+
},
255+
doConvert: func(h *ConnectionHandler, query *ConvertedStatement) error {
256+
sqlStr := RemoveComments(query.String)
257+
sqlStr = ConvertAnyOp(sqlStr)
245258
query.String = sqlStr
246259
return nil
247260
},

pgserver/stmt.go

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -295,31 +295,57 @@ func ConvertToSys(sql string) string {
295295
}
296296

297297
var (
298-
typeCastRegex *regexp.Regexp
299-
initTypeCastRegex sync.Once
298+
pgAnyOpRegex *regexp.Regexp
299+
initPgAnyOpRegex sync.Once
300300
)
301301

302+
// get the regex to match the operator 'ANY'
303+
func getPgAnyOpRegex() *regexp.Regexp {
304+
initPgAnyOpRegex.Do(func() {
305+
pgAnyOpRegex = regexp.MustCompile(`(?i)([^\s(]+)\s*=\s*any\s*\(\s*([^)]*)\s*\)`)
306+
})
307+
return pgAnyOpRegex
308+
}
309+
310+
// Replace the operator 'ANY' with a function call.
311+
func ConvertAnyOp(sql string) string {
312+
re := getPgAnyOpRegex()
313+
return re.ReplaceAllString(sql, catalog.SchemaNameSYS+"."+catalog.MacroNameMyListContains+"($2, $1)")
314+
}
315+
316+
var (
317+
simpleStrMatchingRegex *regexp.Regexp
318+
initSimpleStrMatchingRegex sync.Once
319+
)
320+
321+
// TODO(sean): This is a temporary solution. We need to find a better way to handle type cast conversion and column conversion. e.g. Iterating the AST with a visitor pattern.
302322
// The Key must be in lowercase. Because the key used for value retrieval is in lowercase.
303-
var typeCastConversion = map[string]string{
323+
var simpleStringsConversion = map[string]string{
324+
// type cast conversion
304325
"::regclass": "::varchar",
326+
"::regtype": "::varchar",
327+
328+
// column conversion
329+
"proallargtypes": catalog.SchemaNameSYS + "." + catalog.MacroNameMySplitListStr + "(proallargtypes)",
330+
"proargtypes": catalog.SchemaNameSYS + "." + catalog.MacroNameMySplitListStr + "(proargtypes)",
305331
}
306332

307333
// This function will return a regex that matches all type casts in the query.
308-
func getTypeCastRegex() *regexp.Regexp {
309-
initTypeCastRegex.Do(func() {
310-
var typeCasts []string
311-
for typeCast := range typeCastConversion {
312-
typeCasts = append(typeCasts, regexp.QuoteMeta(typeCast))
334+
func getSimpleStringMatchingRegex() *regexp.Regexp {
335+
initSimpleStrMatchingRegex.Do(func() {
336+
var simpleStrings []string
337+
for simpleString := range simpleStringsConversion {
338+
simpleStrings = append(simpleStrings, regexp.QuoteMeta(simpleString))
313339
}
314-
typeCastRegex = regexp.MustCompile(`(?i)(` + strings.Join(typeCasts, "|") + `)`)
340+
simpleStrMatchingRegex = regexp.MustCompile(`(?i)(` + strings.Join(simpleStrings, "|") + `)`)
315341
})
316-
return typeCastRegex
342+
return simpleStrMatchingRegex
317343
}
318344

319-
// This function will replace all type casts in the query with the corresponding type cast in the typeCastConversion map.
320-
func ConvertTypeCast(sql string) string {
321-
return getTypeCastRegex().ReplaceAllStringFunc(sql, func(m string) string {
322-
return typeCastConversion[strings.ToLower(m)]
345+
// This function will replace all type casts in the query with the corresponding type cast in the simpleStringsConversion map.
346+
func SimpleStrReplacement(sql string) string {
347+
return getSimpleStringMatchingRegex().ReplaceAllStringFunc(sql, func(m string) string {
348+
return simpleStringsConversion[strings.ToLower(m)]
323349
})
324350
}
325351

0 commit comments

Comments
 (0)