Skip to content

Commit b60eaa3

Browse files
Ajit Pratap SinghAjit Pratap Singh
authored andcommitted
fix: resolve merge conflicts with main, fix gofmt, handle OUTPUT parse error
- Resolved merge conflicts in ast.go, dml.go, expressions.go - Merged OnDuplicateKey (MySQL) and Output (SQL Server) fields in InsertStatement - Fixed silent error discard in OUTPUT clause parsing - Fixed gofmt formatting in tsql_test.go - Preserved MySQL IF/REPLACE function parsing from main - Preserved isNonReservedKeyword() for T-SQL identifier handling
2 parents b71e20a + e60d0bd commit b60eaa3

File tree

11 files changed

+794
-42
lines changed

11 files changed

+794
-42
lines changed

pkg/models/token_type.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,9 @@ const (
406406
TokenTypePolicy TokenType = 515 // POLICY keyword for CREATE/ALTER POLICY
407407
TokenTypeUntil TokenType = 516 // UNTIL keyword for VALID UNTIL
408408
TokenTypeReset TokenType = 517 // RESET keyword for ALTER ROLE RESET
409+
TokenTypeShow TokenType = 518 // SHOW keyword for MySQL SHOW commands
410+
TokenTypeDescribe TokenType = 519 // DESCRIBE keyword for MySQL DESCRIBE command
411+
TokenTypeExplain TokenType = 520 // EXPLAIN keyword
409412
)
410413

411414
// String returns a string representation of the token type.
@@ -1014,6 +1017,12 @@ func (t TokenType) String() string {
10141017
return "UNTIL"
10151018
case TokenTypeReset:
10161019
return "RESET"
1020+
case TokenTypeShow:
1021+
return "SHOW"
1022+
case TokenTypeDescribe:
1023+
return "DESCRIBE"
1024+
case TokenTypeExplain:
1025+
return "EXPLAIN"
10171026

10181027
default:
10191028
return "TOKEN"

pkg/sql/ast/ast.go

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,14 +1117,15 @@ func (a ArraySliceExpression) Children() []Node {
11171117

11181118
// InsertStatement represents an INSERT SQL statement
11191119
type InsertStatement struct {
1120-
With *WithClause
1121-
TableName string
1122-
Columns []Expression
1123-
Output []Expression // SQL Server OUTPUT clause columns
1124-
Values [][]Expression // Multi-row support: each inner slice is one row of values
1125-
Query QueryExpression // For INSERT ... SELECT (SelectStatement or SetOperation)
1126-
Returning []Expression
1127-
OnConflict *OnConflict
1120+
With *WithClause
1121+
TableName string
1122+
Columns []Expression
1123+
Output []Expression // SQL Server OUTPUT clause columns
1124+
Values [][]Expression // Multi-row support: each inner slice is one row of values
1125+
Query QueryExpression // For INSERT ... SELECT (SelectStatement or SetOperation)
1126+
Returning []Expression
1127+
OnConflict *OnConflict
1128+
OnDuplicateKey *UpsertClause // MySQL: ON DUPLICATE KEY UPDATE
11281129
}
11291130

11301131
func (i *InsertStatement) statementNode() {}
@@ -1147,6 +1148,9 @@ func (i InsertStatement) Children() []Node {
11471148
if i.OnConflict != nil {
11481149
children = append(children, i.OnConflict)
11491150
}
1151+
if i.OnDuplicateKey != nil {
1152+
children = append(children, i.OnDuplicateKey)
1153+
}
11501154
return children
11511155
}
11521156

@@ -1715,3 +1719,41 @@ func (a AST) Children() []Node {
17151719
}
17161720
return children
17171721
}
1722+
1723+
// ShowStatement represents MySQL SHOW commands (SHOW TABLES, SHOW DATABASES, SHOW CREATE TABLE x, etc.)
1724+
type ShowStatement struct {
1725+
ShowType string // TABLES, DATABASES, CREATE TABLE, COLUMNS, INDEX, etc.
1726+
ObjectName string // For SHOW CREATE TABLE x, SHOW COLUMNS FROM x, etc.
1727+
From string // For SHOW ... FROM database
1728+
}
1729+
1730+
func (s *ShowStatement) statementNode() {}
1731+
func (s ShowStatement) TokenLiteral() string { return "SHOW" }
1732+
func (s ShowStatement) Children() []Node { return nil }
1733+
1734+
// DescribeStatement represents MySQL DESCRIBE/DESC/EXPLAIN table commands
1735+
type DescribeStatement struct {
1736+
TableName string
1737+
}
1738+
1739+
func (d *DescribeStatement) statementNode() {}
1740+
func (d DescribeStatement) TokenLiteral() string { return "DESCRIBE" }
1741+
func (d DescribeStatement) Children() []Node { return nil }
1742+
1743+
// ReplaceStatement represents MySQL REPLACE INTO statement
1744+
type ReplaceStatement struct {
1745+
TableName string
1746+
Columns []Expression
1747+
Values [][]Expression
1748+
}
1749+
1750+
func (r *ReplaceStatement) statementNode() {}
1751+
func (r ReplaceStatement) TokenLiteral() string { return "REPLACE" }
1752+
func (r ReplaceStatement) Children() []Node {
1753+
children := make([]Node, 0)
1754+
children = append(children, nodifyExpressions(r.Columns)...)
1755+
for _, row := range r.Values {
1756+
children = append(children, nodifyExpressions(row)...)
1757+
}
1758+
return children
1759+
}

pkg/sql/parser/dml.go

Lines changed: 87 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,11 @@ func (p *Parser) parseInsertStatement() (ast.Statement, error) {
5959
var outputCols []ast.Expression
6060
if p.dialect == string(keywords.DialectSQLServer) && strings.ToUpper(p.currentToken.Literal) == "OUTPUT" {
6161
p.advance() // Consume OUTPUT
62-
outputCols, _ = p.parseOutputColumns()
62+
var err error
63+
outputCols, err = p.parseOutputColumns()
64+
if err != nil {
65+
return nil, err
66+
}
6367
}
6468

6569
// Parse VALUES or SELECT
@@ -128,18 +132,37 @@ func (p *Parser) parseInsertStatement() (ast.Statement, error) {
128132
return nil, p.expectedError("VALUES or SELECT")
129133
}
130134

131-
// Parse ON CONFLICT clause if present (PostgreSQL UPSERT)
135+
// Parse ON CONFLICT clause (PostgreSQL) or ON DUPLICATE KEY UPDATE (MySQL)
132136
var onConflict *ast.OnConflict
137+
var onDuplicateKey *ast.UpsertClause
133138
if p.isType(models.TokenTypeOn) {
134-
// Peek ahead to check for CONFLICT
135-
if p.peekToken().Literal == "CONFLICT" {
139+
nextLit := strings.ToUpper(p.peekToken().Literal)
140+
if nextLit == "CONFLICT" {
136141
p.advance() // Consume ON
137142
p.advance() // Consume CONFLICT
138143
var err error
139144
onConflict, err = p.parseOnConflictClause()
140145
if err != nil {
141146
return nil, err
142147
}
148+
} else if nextLit == "DUPLICATE" {
149+
p.advance() // Consume ON
150+
p.advance() // Consume DUPLICATE
151+
// Expect KEY
152+
if strings.ToUpper(p.currentToken.Literal) != "KEY" && !p.isType(models.TokenTypeKey) {
153+
return nil, p.expectedError("KEY")
154+
}
155+
p.advance() // Consume KEY
156+
// Expect UPDATE
157+
if !p.isType(models.TokenTypeUpdate) {
158+
return nil, p.expectedError("UPDATE")
159+
}
160+
p.advance() // Consume UPDATE
161+
var err error
162+
onDuplicateKey, err = p.parseOnDuplicateKeyUpdateClause()
163+
if err != nil {
164+
return nil, err
165+
}
143166
}
144167
}
145168

@@ -156,13 +179,14 @@ func (p *Parser) parseInsertStatement() (ast.Statement, error) {
156179

157180
// Create INSERT statement
158181
return &ast.InsertStatement{
159-
TableName: tableName,
160-
Columns: columns,
161-
Output: outputCols,
162-
Values: values,
163-
Query: query,
164-
OnConflict: onConflict,
165-
Returning: returning,
182+
TableName: tableName,
183+
Columns: columns,
184+
Output: outputCols,
185+
Values: values,
186+
Query: query,
187+
OnConflict: onConflict,
188+
OnDuplicateKey: onDuplicateKey,
189+
Returning: returning,
166190
}, nil
167191
}
168192

@@ -246,6 +270,14 @@ func (p *Parser) parseUpdateStatement() (ast.Statement, error) {
246270
}
247271
}
248272

273+
// Parse LIMIT clause if present (MySQL)
274+
if p.isType(models.TokenTypeLimit) {
275+
p.advance() // Consume LIMIT
276+
if p.isNumericLiteral() {
277+
p.advance() // Consume limit value (MySQL UPDATE LIMIT)
278+
}
279+
}
280+
249281
// Parse RETURNING clause if present (PostgreSQL)
250282
var returning []ast.Expression
251283
if p.isType(models.TokenTypeReturning) || p.currentToken.Literal == "RETURNING" {
@@ -293,6 +325,14 @@ func (p *Parser) parseDeleteStatement() (ast.Statement, error) {
293325
}
294326
}
295327

328+
// Parse LIMIT clause if present (MySQL)
329+
if p.isType(models.TokenTypeLimit) {
330+
p.advance() // Consume LIMIT
331+
if p.isNumericLiteral() {
332+
p.advance() // Consume limit value
333+
}
334+
}
335+
296336
// Parse RETURNING clause if present (PostgreSQL)
297337
var returning []ast.Expression
298338
if p.isType(models.TokenTypeReturning) || p.currentToken.Literal == "RETURNING" {
@@ -748,3 +788,39 @@ func (p *Parser) parseOutputColumns() ([]ast.Expression, error) {
748788
}
749789
return cols, nil
750790
}
791+
792+
// parseOnDuplicateKeyUpdateClause parses the assignments in ON DUPLICATE KEY UPDATE
793+
func (p *Parser) parseOnDuplicateKeyUpdateClause() (*ast.UpsertClause, error) {
794+
upsert := &ast.UpsertClause{}
795+
for {
796+
if !p.isIdentifier() {
797+
return nil, p.expectedError("column name in ON DUPLICATE KEY UPDATE")
798+
}
799+
columnName := p.currentToken.Literal
800+
p.advance()
801+
802+
if !p.isType(models.TokenTypeEq) {
803+
return nil, p.expectedError("=")
804+
}
805+
p.advance()
806+
807+
value, err := p.parseExpression()
808+
if err != nil {
809+
return nil, fmt.Errorf("failed to parse ON DUPLICATE KEY UPDATE value: %w", err)
810+
}
811+
812+
upsert.Updates = append(upsert.Updates, ast.UpdateExpression{
813+
Column: &ast.Identifier{Name: columnName},
814+
Value: value,
815+
})
816+
817+
if !p.isType(models.TokenTypeComma) {
818+
break
819+
}
820+
p.advance() // Consume comma
821+
}
822+
return upsert, nil
823+
}
824+
825+
// parseTableReference parses a simple table reference (table name)
826+
// Returns a TableReference with the Name field populated

pkg/sql/parser/expressions.go

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,26 @@ func (p *Parser) parseComparisonExpression() (ast.Expression, error) {
172172
}, nil
173173
}
174174

175+
// Check for REGEXP/RLIKE operator (MySQL)
176+
if strings.EqualFold(p.currentToken.Literal, "REGEXP") || strings.EqualFold(p.currentToken.Literal, "RLIKE") {
177+
operator := strings.ToUpper(p.currentToken.Literal)
178+
p.advance()
179+
pattern, err := p.parsePrimaryExpression()
180+
if err != nil {
181+
return nil, goerrors.InvalidSyntaxError(
182+
fmt.Sprintf("failed to parse REGEXP pattern: %v", err),
183+
p.currentLocation(),
184+
p.currentToken.Literal,
185+
)
186+
}
187+
return &ast.BinaryExpression{
188+
Left: left,
189+
Operator: operator,
190+
Right: pattern,
191+
Not: notPrefix,
192+
}, nil
193+
}
194+
175195
// Check for IN operator
176196
if p.isType(models.TokenTypeIn) {
177197
p.advance() // Consume IN
@@ -633,6 +653,17 @@ func (p *Parser) parsePrimaryExpression() (ast.Expression, error) {
633653
return p.parseArrayConstructor()
634654
}
635655

656+
// Handle keywords that can be used as function names in MySQL (IF, REPLACE, etc.)
657+
if (p.isType(models.TokenTypeIf) || p.isType(models.TokenTypeReplace)) && p.peekToken().Type == models.TokenTypeLParen {
658+
identName := p.currentToken.Literal
659+
p.advance()
660+
funcCall, err := p.parseFunctionCall(identName)
661+
if err != nil {
662+
return nil, err
663+
}
664+
return funcCall, nil
665+
}
666+
636667
if p.isType(models.TokenTypeIdentifier) || p.isType(models.TokenTypeDoubleQuotedString) || p.isNonReservedKeyword() {
637668
// Handle identifiers and function calls
638669
// Double-quoted strings are treated as identifiers in SQL (e.g., "column_name")
@@ -647,6 +678,12 @@ func (p *Parser) parsePrimaryExpression() (ast.Expression, error) {
647678
if err != nil {
648679
return nil, err
649680
}
681+
682+
// MySQL MATCH(...) AGAINST(...) full-text search
683+
if strings.EqualFold(identName, "MATCH") && strings.EqualFold(p.currentToken.Literal, "AGAINST") {
684+
return p.parseMatchAgainst(funcCall)
685+
}
686+
650687
return funcCall, nil
651688
}
652689

@@ -1083,21 +1120,29 @@ func (p *Parser) parseIntervalExpression() (*ast.IntervalExpression, error) {
10831120
// Consume INTERVAL keyword
10841121
p.advance()
10851122

1086-
// Expect a string literal for the interval value
1087-
if !p.isStringLiteral() {
1088-
return nil, goerrors.InvalidSyntaxError(
1089-
"expected string literal after INTERVAL keyword",
1090-
p.currentLocation(),
1091-
"Use INTERVAL 'value' syntax (e.g., INTERVAL '1 day')",
1092-
)
1123+
// Support both PostgreSQL style: INTERVAL '1 day'
1124+
// and MySQL style: INTERVAL 30 DAY, INTERVAL 1 HOUR
1125+
if p.isStringLiteral() {
1126+
value := p.currentToken.Literal
1127+
p.advance()
1128+
return &ast.IntervalExpression{Value: value}, nil
10931129
}
10941130

1095-
value := p.currentToken.Literal
1096-
p.advance() // Consume the string literal
1131+
// MySQL style: INTERVAL <number> <unit>
1132+
if p.isNumericLiteral() {
1133+
numStr := p.currentToken.Literal
1134+
p.advance()
1135+
// Expect a unit keyword (DAY, HOUR, MINUTE, SECOND, MONTH, YEAR, WEEK, etc.)
1136+
unit := strings.ToUpper(p.currentToken.Literal)
1137+
p.advance()
1138+
return &ast.IntervalExpression{Value: numStr + " " + unit}, nil
1139+
}
10971140

1098-
return &ast.IntervalExpression{
1099-
Value: value,
1100-
}, nil
1141+
return nil, goerrors.InvalidSyntaxError(
1142+
"expected string literal or number after INTERVAL keyword",
1143+
p.currentLocation(),
1144+
"Use INTERVAL '1 day' or INTERVAL 1 DAY syntax",
1145+
)
11011146
}
11021147

11031148
// parseArrayConstructor parses PostgreSQL ARRAY constructor syntax.

0 commit comments

Comments
 (0)