Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions pkg/models/token_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,9 @@ const (
TokenTypePolicy TokenType = 515 // POLICY keyword for CREATE/ALTER POLICY
TokenTypeUntil TokenType = 516 // UNTIL keyword for VALID UNTIL
TokenTypeReset TokenType = 517 // RESET keyword for ALTER ROLE RESET
TokenTypeShow TokenType = 518 // SHOW keyword for MySQL SHOW commands
TokenTypeDescribe TokenType = 519 // DESCRIBE keyword for MySQL DESCRIBE command
TokenTypeExplain TokenType = 520 // EXPLAIN keyword
)

// String returns a string representation of the token type.
Expand Down Expand Up @@ -1014,6 +1017,12 @@ func (t TokenType) String() string {
return "UNTIL"
case TokenTypeReset:
return "RESET"
case TokenTypeShow:
return "SHOW"
case TokenTypeDescribe:
return "DESCRIBE"
case TokenTypeExplain:
return "EXPLAIN"

default:
return "TOKEN"
Expand Down
56 changes: 49 additions & 7 deletions pkg/sql/ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -1105,13 +1105,14 @@ func (a ArraySliceExpression) Children() []Node {

// InsertStatement represents an INSERT SQL statement
type InsertStatement struct {
With *WithClause
TableName string
Columns []Expression
Values [][]Expression // Multi-row support: each inner slice is one row of values
Query QueryExpression // For INSERT ... SELECT (SelectStatement or SetOperation)
Returning []Expression
OnConflict *OnConflict
With *WithClause
TableName string
Columns []Expression
Values [][]Expression // Multi-row support: each inner slice is one row of values
Query QueryExpression // For INSERT ... SELECT (SelectStatement or SetOperation)
Returning []Expression
OnConflict *OnConflict
OnDuplicateKey *UpsertClause // MySQL: ON DUPLICATE KEY UPDATE
}

func (i *InsertStatement) statementNode() {}
Expand All @@ -1134,6 +1135,9 @@ func (i InsertStatement) Children() []Node {
if i.OnConflict != nil {
children = append(children, i.OnConflict)
}
if i.OnDuplicateKey != nil {
children = append(children, i.OnDuplicateKey)
}
return children
}

Expand Down Expand Up @@ -1701,3 +1705,41 @@ func (a AST) Children() []Node {
}
return children
}

// ShowStatement represents MySQL SHOW commands (SHOW TABLES, SHOW DATABASES, SHOW CREATE TABLE x, etc.)
type ShowStatement struct {
ShowType string // TABLES, DATABASES, CREATE TABLE, COLUMNS, INDEX, etc.
ObjectName string // For SHOW CREATE TABLE x, SHOW COLUMNS FROM x, etc.
From string // For SHOW ... FROM database
}

func (s *ShowStatement) statementNode() {}
func (s ShowStatement) TokenLiteral() string { return "SHOW" }
func (s ShowStatement) Children() []Node { return nil }

// DescribeStatement represents MySQL DESCRIBE/DESC/EXPLAIN table commands
type DescribeStatement struct {
TableName string
}

func (d *DescribeStatement) statementNode() {}
func (d DescribeStatement) TokenLiteral() string { return "DESCRIBE" }
func (d DescribeStatement) Children() []Node { return nil }

// ReplaceStatement represents MySQL REPLACE INTO statement
type ReplaceStatement struct {
TableName string
Columns []Expression
Values [][]Expression
}

func (r *ReplaceStatement) statementNode() {}
func (r ReplaceStatement) TokenLiteral() string { return "REPLACE" }
func (r ReplaceStatement) Children() []Node {
children := make([]Node, 0)
children = append(children, nodifyExpressions(r.Columns)...)
for _, row := range r.Values {
children = append(children, nodifyExpressions(row)...)
}
return children
}
87 changes: 78 additions & 9 deletions pkg/sql/parser/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,18 +120,37 @@ func (p *Parser) parseInsertStatement() (ast.Statement, error) {
return nil, p.expectedError("VALUES or SELECT")
}

// Parse ON CONFLICT clause if present (PostgreSQL UPSERT)
// Parse ON CONFLICT clause (PostgreSQL) or ON DUPLICATE KEY UPDATE (MySQL)
var onConflict *ast.OnConflict
var onDuplicateKey *ast.UpsertClause
if p.isType(models.TokenTypeOn) {
// Peek ahead to check for CONFLICT
if p.peekToken().Literal == "CONFLICT" {
nextLit := strings.ToUpper(p.peekToken().Literal)
if nextLit == "CONFLICT" {
p.advance() // Consume ON
p.advance() // Consume CONFLICT
var err error
onConflict, err = p.parseOnConflictClause()
if err != nil {
return nil, err
}
} else if nextLit == "DUPLICATE" {
p.advance() // Consume ON
p.advance() // Consume DUPLICATE
// Expect KEY
if strings.ToUpper(p.currentToken.Literal) != "KEY" && !p.isType(models.TokenTypeKey) {
return nil, p.expectedError("KEY")
}
p.advance() // Consume KEY
// Expect UPDATE
if !p.isType(models.TokenTypeUpdate) {
return nil, p.expectedError("UPDATE")
}
p.advance() // Consume UPDATE
var err error
onDuplicateKey, err = p.parseOnDuplicateKeyUpdateClause()
if err != nil {
return nil, err
}
}
}

Expand All @@ -148,12 +167,13 @@ func (p *Parser) parseInsertStatement() (ast.Statement, error) {

// Create INSERT statement
return &ast.InsertStatement{
TableName: tableName,
Columns: columns,
Values: values,
Query: query,
OnConflict: onConflict,
Returning: returning,
TableName: tableName,
Columns: columns,
Values: values,
Query: query,
OnConflict: onConflict,
OnDuplicateKey: onDuplicateKey,
Returning: returning,
}, nil
}

Expand Down Expand Up @@ -237,6 +257,14 @@ func (p *Parser) parseUpdateStatement() (ast.Statement, error) {
}
}

// Parse LIMIT clause if present (MySQL)
if p.isType(models.TokenTypeLimit) {
p.advance() // Consume LIMIT
if p.isNumericLiteral() {
p.advance() // Consume limit value (MySQL UPDATE LIMIT)
}
}

// Parse RETURNING clause if present (PostgreSQL)
var returning []ast.Expression
if p.isType(models.TokenTypeReturning) || p.currentToken.Literal == "RETURNING" {
Expand Down Expand Up @@ -284,6 +312,14 @@ func (p *Parser) parseDeleteStatement() (ast.Statement, error) {
}
}

// Parse LIMIT clause if present (MySQL)
if p.isType(models.TokenTypeLimit) {
p.advance() // Consume LIMIT
if p.isNumericLiteral() {
p.advance() // Consume limit value
}
}

// Parse RETURNING clause if present (PostgreSQL)
var returning []ast.Expression
if p.isType(models.TokenTypeReturning) || p.currentToken.Literal == "RETURNING" {
Expand Down Expand Up @@ -712,5 +748,38 @@ func (p *Parser) parseOnConflictClause() (*ast.OnConflict, error) {
return onConflict, nil
}

// parseOnDuplicateKeyUpdateClause parses the assignments in ON DUPLICATE KEY UPDATE
func (p *Parser) parseOnDuplicateKeyUpdateClause() (*ast.UpsertClause, error) {
upsert := &ast.UpsertClause{}
for {
if !p.isIdentifier() {
return nil, p.expectedError("column name in ON DUPLICATE KEY UPDATE")
}
columnName := p.currentToken.Literal
p.advance()

if !p.isType(models.TokenTypeEq) {
return nil, p.expectedError("=")
}
p.advance()

value, err := p.parseExpression()
if err != nil {
return nil, fmt.Errorf("failed to parse ON DUPLICATE KEY UPDATE value: %w", err)
}

upsert.Updates = append(upsert.Updates, ast.UpdateExpression{
Column: &ast.Identifier{Name: columnName},
Value: value,
})

if !p.isType(models.TokenTypeComma) {
break
}
p.advance() // Consume comma
}
return upsert, nil
}

// parseTableReference parses a simple table reference (table name)
// Returns a TableReference with the Name field populated
69 changes: 57 additions & 12 deletions pkg/sql/parser/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,26 @@ func (p *Parser) parseComparisonExpression() (ast.Expression, error) {
}, nil
}

// Check for REGEXP/RLIKE operator (MySQL)
if strings.EqualFold(p.currentToken.Literal, "REGEXP") || strings.EqualFold(p.currentToken.Literal, "RLIKE") {
operator := strings.ToUpper(p.currentToken.Literal)
p.advance()
pattern, err := p.parsePrimaryExpression()
if err != nil {
return nil, goerrors.InvalidSyntaxError(
fmt.Sprintf("failed to parse REGEXP pattern: %v", err),
p.currentLocation(),
p.currentToken.Literal,
)
}
return &ast.BinaryExpression{
Left: left,
Operator: operator,
Right: pattern,
Not: notPrefix,
}, nil
}

// Check for IN operator
if p.isType(models.TokenTypeIn) {
p.advance() // Consume IN
Expand Down Expand Up @@ -619,6 +639,17 @@ func (p *Parser) parsePrimaryExpression() (ast.Expression, error) {
return p.parseArrayConstructor()
}

// Handle keywords that can be used as function names in MySQL (IF, REPLACE, etc.)
if (p.isType(models.TokenTypeIf) || p.isType(models.TokenTypeReplace)) && p.peekToken().Type == models.TokenTypeLParen {
identName := p.currentToken.Literal
p.advance()
funcCall, err := p.parseFunctionCall(identName)
if err != nil {
return nil, err
}
return funcCall, nil
}

if p.isType(models.TokenTypeIdentifier) || p.isType(models.TokenTypeDoubleQuotedString) {
// Handle identifiers and function calls
// Double-quoted strings are treated as identifiers in SQL (e.g., "column_name")
Expand All @@ -632,6 +663,12 @@ func (p *Parser) parsePrimaryExpression() (ast.Expression, error) {
if err != nil {
return nil, err
}

// MySQL MATCH(...) AGAINST(...) full-text search
if strings.EqualFold(identName, "MATCH") && strings.EqualFold(p.currentToken.Literal, "AGAINST") {
return p.parseMatchAgainst(funcCall)
}

return funcCall, nil
}

Expand Down Expand Up @@ -1068,21 +1105,29 @@ func (p *Parser) parseIntervalExpression() (*ast.IntervalExpression, error) {
// Consume INTERVAL keyword
p.advance()

// Expect a string literal for the interval value
if !p.isStringLiteral() {
return nil, goerrors.InvalidSyntaxError(
"expected string literal after INTERVAL keyword",
p.currentLocation(),
"Use INTERVAL 'value' syntax (e.g., INTERVAL '1 day')",
)
// Support both PostgreSQL style: INTERVAL '1 day'
// and MySQL style: INTERVAL 30 DAY, INTERVAL 1 HOUR
if p.isStringLiteral() {
value := p.currentToken.Literal
p.advance()
return &ast.IntervalExpression{Value: value}, nil
}

value := p.currentToken.Literal
p.advance() // Consume the string literal
// MySQL style: INTERVAL <number> <unit>
if p.isNumericLiteral() {
numStr := p.currentToken.Literal
p.advance()
// Expect a unit keyword (DAY, HOUR, MINUTE, SECOND, MONTH, YEAR, WEEK, etc.)
unit := strings.ToUpper(p.currentToken.Literal)
p.advance()
return &ast.IntervalExpression{Value: numStr + " " + unit}, nil
}

return &ast.IntervalExpression{
Value: value,
}, nil
return nil, goerrors.InvalidSyntaxError(
"expected string literal or number after INTERVAL keyword",
p.currentLocation(),
"Use INTERVAL '1 day' or INTERVAL 1 DAY syntax",
)
}

// parseArrayConstructor parses PostgreSQL ARRAY constructor syntax.
Expand Down
Loading
Loading