diff --git a/pkg/models/token_type.go b/pkg/models/token_type.go index de1768b8..d6bfd804 100644 --- a/pkg/models/token_type.go +++ b/pkg/models/token_type.go @@ -406,6 +406,9 @@ const ( TokenTypePolicy TokenType = 515 // POLICY keyword for CREATE/ALTER POLICY TokenTypeUntil TokenType = 516 // UNTIL keyword for VALID UNTIL TokenTypeReset TokenType = 517 // RESET keyword for ALTER ROLE RESET + TokenTypeShow TokenType = 518 // SHOW keyword for MySQL SHOW commands + TokenTypeDescribe TokenType = 519 // DESCRIBE keyword for MySQL DESCRIBE command + TokenTypeExplain TokenType = 520 // EXPLAIN keyword ) // String returns a string representation of the token type. @@ -1014,6 +1017,12 @@ func (t TokenType) String() string { return "UNTIL" case TokenTypeReset: return "RESET" + case TokenTypeShow: + return "SHOW" + case TokenTypeDescribe: + return "DESCRIBE" + case TokenTypeExplain: + return "EXPLAIN" default: return "TOKEN" diff --git a/pkg/sql/ast/ast.go b/pkg/sql/ast/ast.go index 399ff2e9..d07e1d66 100644 --- a/pkg/sql/ast/ast.go +++ b/pkg/sql/ast/ast.go @@ -1105,13 +1105,14 @@ func (a ArraySliceExpression) Children() []Node { // InsertStatement represents an INSERT SQL statement type InsertStatement struct { - With *WithClause - TableName string - Columns []Expression - Values [][]Expression // Multi-row support: each inner slice is one row of values - Query QueryExpression // For INSERT ... SELECT (SelectStatement or SetOperation) - Returning []Expression - OnConflict *OnConflict + With *WithClause + TableName string + Columns []Expression + Values [][]Expression // Multi-row support: each inner slice is one row of values + Query QueryExpression // For INSERT ... SELECT (SelectStatement or SetOperation) + Returning []Expression + OnConflict *OnConflict + OnDuplicateKey *UpsertClause // MySQL: ON DUPLICATE KEY UPDATE } func (i *InsertStatement) statementNode() {} @@ -1134,6 +1135,9 @@ func (i InsertStatement) Children() []Node { if i.OnConflict != nil { children = append(children, i.OnConflict) } + if i.OnDuplicateKey != nil { + children = append(children, i.OnDuplicateKey) + } return children } @@ -1701,3 +1705,41 @@ func (a AST) Children() []Node { } return children } + +// ShowStatement represents MySQL SHOW commands (SHOW TABLES, SHOW DATABASES, SHOW CREATE TABLE x, etc.) +type ShowStatement struct { + ShowType string // TABLES, DATABASES, CREATE TABLE, COLUMNS, INDEX, etc. + ObjectName string // For SHOW CREATE TABLE x, SHOW COLUMNS FROM x, etc. + From string // For SHOW ... FROM database +} + +func (s *ShowStatement) statementNode() {} +func (s ShowStatement) TokenLiteral() string { return "SHOW" } +func (s ShowStatement) Children() []Node { return nil } + +// DescribeStatement represents MySQL DESCRIBE/DESC/EXPLAIN table commands +type DescribeStatement struct { + TableName string +} + +func (d *DescribeStatement) statementNode() {} +func (d DescribeStatement) TokenLiteral() string { return "DESCRIBE" } +func (d DescribeStatement) Children() []Node { return nil } + +// ReplaceStatement represents MySQL REPLACE INTO statement +type ReplaceStatement struct { + TableName string + Columns []Expression + Values [][]Expression +} + +func (r *ReplaceStatement) statementNode() {} +func (r ReplaceStatement) TokenLiteral() string { return "REPLACE" } +func (r ReplaceStatement) Children() []Node { + children := make([]Node, 0) + children = append(children, nodifyExpressions(r.Columns)...) + for _, row := range r.Values { + children = append(children, nodifyExpressions(row)...) + } + return children +} diff --git a/pkg/sql/parser/dml.go b/pkg/sql/parser/dml.go index 2daa046e..7186eb75 100644 --- a/pkg/sql/parser/dml.go +++ b/pkg/sql/parser/dml.go @@ -120,11 +120,12 @@ func (p *Parser) parseInsertStatement() (ast.Statement, error) { return nil, p.expectedError("VALUES or SELECT") } - // Parse ON CONFLICT clause if present (PostgreSQL UPSERT) + // Parse ON CONFLICT clause (PostgreSQL) or ON DUPLICATE KEY UPDATE (MySQL) var onConflict *ast.OnConflict + var onDuplicateKey *ast.UpsertClause if p.isType(models.TokenTypeOn) { - // Peek ahead to check for CONFLICT - if p.peekToken().Literal == "CONFLICT" { + nextLit := strings.ToUpper(p.peekToken().Literal) + if nextLit == "CONFLICT" { p.advance() // Consume ON p.advance() // Consume CONFLICT var err error @@ -132,6 +133,24 @@ func (p *Parser) parseInsertStatement() (ast.Statement, error) { if err != nil { return nil, err } + } else if nextLit == "DUPLICATE" { + p.advance() // Consume ON + p.advance() // Consume DUPLICATE + // Expect KEY + if strings.ToUpper(p.currentToken.Literal) != "KEY" && !p.isType(models.TokenTypeKey) { + return nil, p.expectedError("KEY") + } + p.advance() // Consume KEY + // Expect UPDATE + if !p.isType(models.TokenTypeUpdate) { + return nil, p.expectedError("UPDATE") + } + p.advance() // Consume UPDATE + var err error + onDuplicateKey, err = p.parseOnDuplicateKeyUpdateClause() + if err != nil { + return nil, err + } } } @@ -148,12 +167,13 @@ func (p *Parser) parseInsertStatement() (ast.Statement, error) { // Create INSERT statement return &ast.InsertStatement{ - TableName: tableName, - Columns: columns, - Values: values, - Query: query, - OnConflict: onConflict, - Returning: returning, + TableName: tableName, + Columns: columns, + Values: values, + Query: query, + OnConflict: onConflict, + OnDuplicateKey: onDuplicateKey, + Returning: returning, }, nil } @@ -237,6 +257,14 @@ func (p *Parser) parseUpdateStatement() (ast.Statement, error) { } } + // Parse LIMIT clause if present (MySQL) + if p.isType(models.TokenTypeLimit) { + p.advance() // Consume LIMIT + if p.isNumericLiteral() { + p.advance() // Consume limit value (MySQL UPDATE LIMIT) + } + } + // Parse RETURNING clause if present (PostgreSQL) var returning []ast.Expression if p.isType(models.TokenTypeReturning) || p.currentToken.Literal == "RETURNING" { @@ -284,6 +312,14 @@ func (p *Parser) parseDeleteStatement() (ast.Statement, error) { } } + // Parse LIMIT clause if present (MySQL) + if p.isType(models.TokenTypeLimit) { + p.advance() // Consume LIMIT + if p.isNumericLiteral() { + p.advance() // Consume limit value + } + } + // Parse RETURNING clause if present (PostgreSQL) var returning []ast.Expression if p.isType(models.TokenTypeReturning) || p.currentToken.Literal == "RETURNING" { @@ -712,5 +748,38 @@ func (p *Parser) parseOnConflictClause() (*ast.OnConflict, error) { return onConflict, nil } +// parseOnDuplicateKeyUpdateClause parses the assignments in ON DUPLICATE KEY UPDATE +func (p *Parser) parseOnDuplicateKeyUpdateClause() (*ast.UpsertClause, error) { + upsert := &ast.UpsertClause{} + for { + if !p.isIdentifier() { + return nil, p.expectedError("column name in ON DUPLICATE KEY UPDATE") + } + columnName := p.currentToken.Literal + p.advance() + + if !p.isType(models.TokenTypeEq) { + return nil, p.expectedError("=") + } + p.advance() + + value, err := p.parseExpression() + if err != nil { + return nil, fmt.Errorf("failed to parse ON DUPLICATE KEY UPDATE value: %w", err) + } + + upsert.Updates = append(upsert.Updates, ast.UpdateExpression{ + Column: &ast.Identifier{Name: columnName}, + Value: value, + }) + + if !p.isType(models.TokenTypeComma) { + break + } + p.advance() // Consume comma + } + return upsert, nil +} + // parseTableReference parses a simple table reference (table name) // Returns a TableReference with the Name field populated diff --git a/pkg/sql/parser/expressions.go b/pkg/sql/parser/expressions.go index 16287a57..30dc96da 100644 --- a/pkg/sql/parser/expressions.go +++ b/pkg/sql/parser/expressions.go @@ -172,6 +172,26 @@ func (p *Parser) parseComparisonExpression() (ast.Expression, error) { }, nil } + // Check for REGEXP/RLIKE operator (MySQL) + if strings.EqualFold(p.currentToken.Literal, "REGEXP") || strings.EqualFold(p.currentToken.Literal, "RLIKE") { + operator := strings.ToUpper(p.currentToken.Literal) + p.advance() + pattern, err := p.parsePrimaryExpression() + if err != nil { + return nil, goerrors.InvalidSyntaxError( + fmt.Sprintf("failed to parse REGEXP pattern: %v", err), + p.currentLocation(), + p.currentToken.Literal, + ) + } + return &ast.BinaryExpression{ + Left: left, + Operator: operator, + Right: pattern, + Not: notPrefix, + }, nil + } + // Check for IN operator if p.isType(models.TokenTypeIn) { p.advance() // Consume IN @@ -619,6 +639,17 @@ func (p *Parser) parsePrimaryExpression() (ast.Expression, error) { return p.parseArrayConstructor() } + // Handle keywords that can be used as function names in MySQL (IF, REPLACE, etc.) + if (p.isType(models.TokenTypeIf) || p.isType(models.TokenTypeReplace)) && p.peekToken().Type == models.TokenTypeLParen { + identName := p.currentToken.Literal + p.advance() + funcCall, err := p.parseFunctionCall(identName) + if err != nil { + return nil, err + } + return funcCall, nil + } + if p.isType(models.TokenTypeIdentifier) || p.isType(models.TokenTypeDoubleQuotedString) { // Handle identifiers and function calls // Double-quoted strings are treated as identifiers in SQL (e.g., "column_name") @@ -632,6 +663,12 @@ func (p *Parser) parsePrimaryExpression() (ast.Expression, error) { if err != nil { return nil, err } + + // MySQL MATCH(...) AGAINST(...) full-text search + if strings.EqualFold(identName, "MATCH") && strings.EqualFold(p.currentToken.Literal, "AGAINST") { + return p.parseMatchAgainst(funcCall) + } + return funcCall, nil } @@ -1068,21 +1105,29 @@ func (p *Parser) parseIntervalExpression() (*ast.IntervalExpression, error) { // Consume INTERVAL keyword p.advance() - // Expect a string literal for the interval value - if !p.isStringLiteral() { - return nil, goerrors.InvalidSyntaxError( - "expected string literal after INTERVAL keyword", - p.currentLocation(), - "Use INTERVAL 'value' syntax (e.g., INTERVAL '1 day')", - ) + // Support both PostgreSQL style: INTERVAL '1 day' + // and MySQL style: INTERVAL 30 DAY, INTERVAL 1 HOUR + if p.isStringLiteral() { + value := p.currentToken.Literal + p.advance() + return &ast.IntervalExpression{Value: value}, nil } - value := p.currentToken.Literal - p.advance() // Consume the string literal + // MySQL style: INTERVAL + if p.isNumericLiteral() { + numStr := p.currentToken.Literal + p.advance() + // Expect a unit keyword (DAY, HOUR, MINUTE, SECOND, MONTH, YEAR, WEEK, etc.) + unit := strings.ToUpper(p.currentToken.Literal) + p.advance() + return &ast.IntervalExpression{Value: numStr + " " + unit}, nil + } - return &ast.IntervalExpression{ - Value: value, - }, nil + return nil, goerrors.InvalidSyntaxError( + "expected string literal or number after INTERVAL keyword", + p.currentLocation(), + "Use INTERVAL '1 day' or INTERVAL 1 DAY syntax", + ) } // parseArrayConstructor parses PostgreSQL ARRAY constructor syntax. diff --git a/pkg/sql/parser/mysql.go b/pkg/sql/parser/mysql.go new file mode 100644 index 00000000..7cdd8c38 --- /dev/null +++ b/pkg/sql/parser/mysql.go @@ -0,0 +1,238 @@ +package parser + +import ( + "fmt" + "strings" + + "github.com/ajitpratap0/GoSQLX/pkg/models" + "github.com/ajitpratap0/GoSQLX/pkg/sql/ast" +) + +// parseMatchAgainst parses MySQL MATCH(...) AGAINST('text' [IN NATURAL LANGUAGE MODE | IN BOOLEAN MODE | WITH QUERY EXPANSION]) +func (p *Parser) parseMatchAgainst(matchFunc *ast.FunctionCall) (ast.Expression, error) { + p.advance() // Consume AGAINST + if !p.isType(models.TokenTypeLParen) { + return nil, p.expectedError("(") + } + p.advance() // Consume ( + + // Parse search expression (just the primary — not full expression, to avoid IN being eaten) + searchExpr, err := p.parsePrimaryExpression() + if err != nil { + return nil, fmt.Errorf("failed to parse AGAINST expression: %w", err) + } + + // Consume optional mode keywords until we hit ) + mode := "" + for !p.isType(models.TokenTypeRParen) && !p.isType(models.TokenTypeEOF) { + mode += " " + p.currentToken.Literal + p.advance() + } + + if !p.isType(models.TokenTypeRParen) { + return nil, p.expectedError(")") + } + p.advance() // Consume ) + + // Represent as a binary expression: MATCH(cols) AGAINST(expr) + // Store the search expr and mode as a function call named "AGAINST" + againstFunc := &ast.FunctionCall{ + Name: "AGAINST", + Arguments: []ast.Expression{searchExpr}, + } + if mode != "" { + againstFunc.Arguments = append(againstFunc.Arguments, &ast.LiteralValue{ + Value: strings.TrimSpace(mode), + Type: "STRING", + }) + } + + return &ast.BinaryExpression{ + Left: matchFunc, + Operator: "AGAINST", + Right: againstFunc, + }, nil +} + +// parseShowStatement parses MySQL SHOW commands: +// - SHOW TABLES +// - SHOW DATABASES +// - SHOW CREATE TABLE name +// - SHOW COLUMNS FROM name +// - SHOW INDEX FROM name +func (p *Parser) parseShowStatement() (ast.Statement, error) { + show := &ast.ShowStatement{} + + upper := strings.ToUpper(p.currentToken.Literal) + + switch upper { + case "TABLES": + show.ShowType = "TABLES" + p.advance() + // Optional FROM database + if p.isType(models.TokenTypeFrom) { + p.advance() + show.From = p.currentToken.Literal + p.advance() + } + case "DATABASES": + show.ShowType = "DATABASES" + p.advance() + case "CREATE": + p.advance() // Consume CREATE + if p.isType(models.TokenTypeTable) { + show.ShowType = "CREATE TABLE" + p.advance() // Consume TABLE + name, err := p.parseQualifiedName() + if err != nil { + return nil, p.expectedError("table name") + } + show.ObjectName = name + } else { + show.ShowType = "CREATE " + strings.ToUpper(p.currentToken.Literal) + p.advance() + name, err := p.parseQualifiedName() + if err != nil { + return nil, p.expectedError("object name") + } + show.ObjectName = name + } + case "COLUMNS": + show.ShowType = "COLUMNS" + p.advance() + if p.isType(models.TokenTypeFrom) { + p.advance() + name, err := p.parseQualifiedName() + if err != nil { + return nil, p.expectedError("table name") + } + show.ObjectName = name + } + case "INDEX", "INDEXES", "KEYS": + show.ShowType = upper + p.advance() + if p.isType(models.TokenTypeFrom) { + p.advance() + name, err := p.parseQualifiedName() + if err != nil { + return nil, p.expectedError("table name") + } + show.ObjectName = name + } + case "STATUS", "VARIABLES": + show.ShowType = upper + p.advance() + default: + // Generic: SHOW + show.ShowType = upper + p.advance() + } + + return show, nil +} + +// parseDescribeStatement parses DESCRIBE/DESC/EXPLAIN table_name +func (p *Parser) parseDescribeStatement() (ast.Statement, error) { + // For EXPLAIN SELECT ..., defer to parseStatement for the SELECT + // For DESCRIBE table_name, just parse the table name + if p.isType(models.TokenTypeSelect) { + // EXPLAIN SELECT ... — treat as describe with the query text + // For now, just skip to parse the select + p.advance() + stmt, err := p.parseSelectWithSetOperations() + if err != nil { + return nil, err + } + // Wrap in a describe + _ = stmt + return &ast.DescribeStatement{TableName: "SELECT"}, nil + } + + name, err := p.parseQualifiedName() + if err != nil { + return nil, p.expectedError("table name") + } + return &ast.DescribeStatement{TableName: name}, nil +} + +// parseReplaceStatement parses MySQL REPLACE INTO statement +func (p *Parser) parseReplaceStatement() (ast.Statement, error) { + // Expect INTO + if !p.isType(models.TokenTypeInto) { + return nil, p.expectedError("INTO") + } + p.advance() + + // Parse table name + tableName, err := p.parseQualifiedName() + if err != nil { + return nil, p.expectedError("table name") + } + + // Parse column list if present + columns := make([]ast.Expression, 0) + if p.isType(models.TokenTypeLParen) { + p.advance() + for { + if !p.isIdentifier() { + return nil, p.expectedError("column name") + } + columns = append(columns, &ast.Identifier{Name: p.currentToken.Literal}) + p.advance() + if !p.isType(models.TokenTypeComma) { + break + } + p.advance() + } + if !p.isType(models.TokenTypeRParen) { + return nil, p.expectedError(")") + } + p.advance() + } + + // Parse VALUES + if !p.isType(models.TokenTypeValues) { + return nil, p.expectedError("VALUES") + } + p.advance() + + values := make([][]ast.Expression, 0) + for { + if !p.isType(models.TokenTypeLParen) { + if len(values) == 0 { + return nil, p.expectedError("(") + } + break + } + p.advance() + + row := make([]ast.Expression, 0) + for { + expr, err := p.parseExpression() + if err != nil { + return nil, fmt.Errorf("failed to parse value in REPLACE: %w", err) + } + row = append(row, expr) + if !p.isType(models.TokenTypeComma) { + break + } + p.advance() + } + if !p.isType(models.TokenTypeRParen) { + return nil, p.expectedError(")") + } + p.advance() + values = append(values, row) + + if !p.isType(models.TokenTypeComma) { + break + } + p.advance() + } + + return &ast.ReplaceStatement{ + TableName: tableName, + Columns: columns, + Values: values, + }, nil +} diff --git a/pkg/sql/parser/mysql_test.go b/pkg/sql/parser/mysql_test.go new file mode 100644 index 00000000..8e9f4c55 --- /dev/null +++ b/pkg/sql/parser/mysql_test.go @@ -0,0 +1,291 @@ +package parser + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/ajitpratap0/GoSQLX/pkg/sql/ast" + "github.com/ajitpratap0/GoSQLX/pkg/sql/keywords" +) + +// TestMySQLLimitOffsetSyntax tests MySQL-style LIMIT offset, count +func TestMySQLLimitOffsetSyntax(t *testing.T) { + tests := []struct { + name string + sql string + wantLimit int + wantOffset int + }{ + { + name: "LIMIT offset, count", + sql: "SELECT * FROM posts LIMIT 10, 20", + wantLimit: 20, + wantOffset: 10, + }, + { + name: "LIMIT count only", + sql: "SELECT * FROM posts LIMIT 5", + wantLimit: 5, + wantOffset: 0, + }, + { + name: "LIMIT with ORDER BY", + sql: "SELECT * FROM posts ORDER BY id DESC LIMIT 0, 50", + wantLimit: 50, + wantOffset: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := ParseWithDialect(tt.sql, keywords.DialectMySQL) + if err != nil { + t.Fatalf("ParseWithDialect failed: %v", err) + } + if len(result.Statements) == 0 { + t.Fatal("expected at least one statement") + } + sel, ok := result.Statements[0].(*ast.SelectStatement) + if !ok { + t.Fatalf("expected SelectStatement, got %T", result.Statements[0]) + } + if sel.Limit == nil { + t.Fatal("expected non-nil Limit") + } + if *sel.Limit != tt.wantLimit { + t.Errorf("Limit = %d, want %d", *sel.Limit, tt.wantLimit) + } + if tt.wantOffset > 0 { + if sel.Offset == nil { + t.Fatal("expected non-nil Offset") + } + if *sel.Offset != tt.wantOffset { + t.Errorf("Offset = %d, want %d", *sel.Offset, tt.wantOffset) + } + } + }) + } +} + +// TestMySQLOnDuplicateKeyUpdate tests ON DUPLICATE KEY UPDATE parsing +func TestMySQLOnDuplicateKeyUpdate(t *testing.T) { + sql := `INSERT INTO user_stats (user_id, login_count) VALUES (1, 1) + ON DUPLICATE KEY UPDATE login_count = login_count + 1` + + result, err := ParseWithDialect(sql, keywords.DialectMySQL) + if err != nil { + t.Fatalf("ParseWithDialect failed: %v", err) + } + + stmt, ok := result.Statements[0].(*ast.InsertStatement) + if !ok { + t.Fatalf("expected InsertStatement, got %T", result.Statements[0]) + } + if stmt.OnDuplicateKey == nil { + t.Fatal("expected non-nil OnDuplicateKey") + } + if len(stmt.OnDuplicateKey.Updates) != 1 { + t.Fatalf("expected 1 update, got %d", len(stmt.OnDuplicateKey.Updates)) + } + col, ok := stmt.OnDuplicateKey.Updates[0].Column.(*ast.Identifier) + if !ok || col.Name != "login_count" { + t.Errorf("expected column login_count, got %v", stmt.OnDuplicateKey.Updates[0].Column) + } +} + +// TestMySQLBacktickIdentifiers tests backtick-quoted identifiers +func TestMySQLBacktickIdentifiers(t *testing.T) { + tests := []string{ + "SELECT `id`, `name` FROM `users`", + "SELECT `tbl`.`col` FROM `mydb`.`tbl`", + "SELECT `select` FROM `from`", + } + + for _, sql := range tests { + t.Run(sql, func(t *testing.T) { + _, err := ParseWithDialect(sql, keywords.DialectMySQL) + if err != nil { + t.Fatalf("ParseWithDialect failed: %v", err) + } + }) + } +} + +// TestMySQLShowStatements tests SHOW command parsing +func TestMySQLShowStatements(t *testing.T) { + tests := []struct { + sql string + showType string + objName string + }{ + {"SHOW TABLES", "TABLES", ""}, + {"SHOW DATABASES", "DATABASES", ""}, + {"SHOW CREATE TABLE users", "CREATE TABLE", "users"}, + } + + for _, tt := range tests { + t.Run(tt.sql, func(t *testing.T) { + result, err := ParseWithDialect(tt.sql, keywords.DialectMySQL) + if err != nil { + t.Fatalf("ParseWithDialect failed: %v", err) + } + show, ok := result.Statements[0].(*ast.ShowStatement) + if !ok { + t.Fatalf("expected ShowStatement, got %T", result.Statements[0]) + } + if show.ShowType != tt.showType { + t.Errorf("ShowType = %q, want %q", show.ShowType, tt.showType) + } + if tt.objName != "" && show.ObjectName != tt.objName { + t.Errorf("ObjectName = %q, want %q", show.ObjectName, tt.objName) + } + }) + } +} + +// TestMySQLDescribeStatement tests DESCRIBE command parsing +func TestMySQLDescribeStatement(t *testing.T) { + tests := []string{ + "DESCRIBE users", + "DESCRIBE schema1.users", + } + + for _, sql := range tests { + t.Run(sql, func(t *testing.T) { + result, err := ParseWithDialect(sql, keywords.DialectMySQL) + if err != nil { + t.Fatalf("ParseWithDialect failed: %v", err) + } + desc, ok := result.Statements[0].(*ast.DescribeStatement) + if !ok { + t.Fatalf("expected DescribeStatement, got %T", result.Statements[0]) + } + if desc.TableName == "" { + t.Error("expected non-empty TableName") + } + }) + } +} + +// TestMySQLReplaceInto tests REPLACE INTO parsing +func TestMySQLReplaceInto(t *testing.T) { + sql := "REPLACE INTO cache (key_name, value) VALUES ('k1', 'v1')" + + result, err := ParseWithDialect(sql, keywords.DialectMySQL) + if err != nil { + t.Fatalf("ParseWithDialect failed: %v", err) + } + + stmt, ok := result.Statements[0].(*ast.ReplaceStatement) + if !ok { + t.Fatalf("expected ReplaceStatement, got %T", result.Statements[0]) + } + if stmt.TableName != "cache" { + t.Errorf("TableName = %q, want cache", stmt.TableName) + } + if len(stmt.Columns) != 2 { + t.Errorf("expected 2 columns, got %d", len(stmt.Columns)) + } + if len(stmt.Values) != 1 { + t.Errorf("expected 1 value row, got %d", len(stmt.Values)) + } +} + +// TestMySQLUpdateWithLimit tests UPDATE ... LIMIT +func TestMySQLUpdateWithLimit(t *testing.T) { + sql := "UPDATE users SET active = 0 WHERE last_login < '2024-01-01' LIMIT 100" + _, err := ParseWithDialect(sql, keywords.DialectMySQL) + if err != nil { + t.Fatalf("ParseWithDialect failed: %v", err) + } +} + +// TestMySQLDeleteWithLimit tests DELETE ... LIMIT +func TestMySQLDeleteWithLimit(t *testing.T) { + sql := "DELETE FROM logs WHERE created_at < '2024-01-01' LIMIT 1000" + _, err := ParseWithDialect(sql, keywords.DialectMySQL) + if err != nil { + t.Fatalf("ParseWithDialect failed: %v", err) + } +} + +// TestMySQLIntervalNumericSyntax tests INTERVAL 1 DAY style +func TestMySQLIntervalNumericSyntax(t *testing.T) { + sql := "SELECT DATE_ADD(NOW(), INTERVAL 30 DAY) FROM dual" + _, err := ParseWithDialect(sql, keywords.DialectMySQL) + if err != nil { + t.Fatalf("ParseWithDialect failed: %v", err) + } +} + +// TestMySQLIFFunction tests IF() function +func TestMySQLIFFunction(t *testing.T) { + sql := "SELECT IF(salary > 50000, 'High', 'Low') FROM employees" + _, err := ParseWithDialect(sql, keywords.DialectMySQL) + if err != nil { + t.Fatalf("ParseWithDialect failed: %v", err) + } +} + +// TestMySQLGroupConcat tests GROUP_CONCAT with SEPARATOR +func TestMySQLGroupConcat(t *testing.T) { + sql := "SELECT GROUP_CONCAT(name ORDER BY name SEPARATOR ', ') FROM users GROUP BY dept" + _, err := ParseWithDialect(sql, keywords.DialectMySQL) + if err != nil { + t.Fatalf("ParseWithDialect failed: %v", err) + } +} + +// TestMySQLMatchAgainst tests MATCH AGAINST full-text search +func TestMySQLMatchAgainst(t *testing.T) { + sql := "SELECT * FROM articles WHERE MATCH(title, content) AGAINST('search term' IN NATURAL LANGUAGE MODE)" + _, err := ParseWithDialect(sql, keywords.DialectMySQL) + if err != nil { + t.Fatalf("ParseWithDialect failed: %v", err) + } +} + +// TestMySQLRegexp tests REGEXP operator +func TestMySQLRegexp(t *testing.T) { + sql := "SELECT * FROM users WHERE email REGEXP '^[a-z]+@[a-z]+$'" + _, err := ParseWithDialect(sql, keywords.DialectMySQL) + if err != nil { + t.Fatalf("ParseWithDialect failed: %v", err) + } +} + +// TestMySQLTestdataIntegration runs all 30 MySQL test files +func TestMySQLTestdataIntegration(t *testing.T) { + files, err := filepath.Glob("../../../testdata/mysql/*.sql") + if err != nil { + t.Fatalf("glob failed: %v", err) + } + if len(files) == 0 { + t.Skip("no MySQL test files found") + } + + for _, f := range files { + t.Run(filepath.Base(f), func(t *testing.T) { + data, err := os.ReadFile(f) + if err != nil { + t.Fatalf("read file: %v", err) + } + lines := strings.Split(string(data), "\n") + var sqlLines []string + for _, l := range lines { + trimmed := strings.TrimSpace(l) + if trimmed == "" || strings.HasPrefix(trimmed, "--") { + continue + } + sqlLines = append(sqlLines, l) + } + sql := strings.Join(sqlLines, "\n") + _, err = ParseWithDialect(sql, keywords.DialectMySQL) + if err != nil { + t.Fatalf("ParseWithDialect failed: %v", err) + } + }) + } +} diff --git a/pkg/sql/parser/parser.go b/pkg/sql/parser/parser.go index 8fc1636f..5f58ae0c 100644 --- a/pkg/sql/parser/parser.go +++ b/pkg/sql/parser/parser.go @@ -701,6 +701,15 @@ func (p *Parser) parseStatement() (ast.Statement, error) { case models.TokenTypeTruncate: p.advance() return p.parseTruncateStatement() + case models.TokenTypeShow: + p.advance() + return p.parseShowStatement() + case models.TokenTypeDescribe, models.TokenTypeExplain: + p.advance() + return p.parseDescribeStatement() + case models.TokenTypeReplace: + p.advance() + return p.parseReplaceStatement() } return nil, p.expectedError("statement") } diff --git a/pkg/sql/parser/select.go b/pkg/sql/parser/select.go index 7185c473..7b117992 100644 --- a/pkg/sql/parser/select.go +++ b/pkg/sql/parser/select.go @@ -974,12 +974,25 @@ func (p *Parser) parseSelectStatement() (ast.Statement, error) { } // Convert string to int - limitVal := 0 - _, _ = fmt.Sscanf(p.currentToken.Literal, "%d", &limitVal) - - // Add LIMIT to SELECT statement - selectStmt.Limit = &limitVal + firstVal := 0 + _, _ = fmt.Sscanf(p.currentToken.Literal, "%d", &firstVal) p.advance() + + // MySQL-style LIMIT offset, count: LIMIT 10, 20 + if p.dialect == "mysql" && p.isType(models.TokenTypeComma) { + p.advance() // Consume comma + if !p.isNumericLiteral() { + return nil, p.expectedError("integer for LIMIT count") + } + secondVal := 0 + _, _ = fmt.Sscanf(p.currentToken.Literal, "%d", &secondVal) + p.advance() + // In MySQL LIMIT offset, count: first is offset, second is count + selectStmt.Offset = &firstVal + selectStmt.Limit = &secondVal + } else { + selectStmt.Limit = &firstVal + } } // Parse OFFSET clause if present (MySQL-style OFFSET or SQL-99 OFFSET ... ROWS) diff --git a/pkg/sql/parser/window.go b/pkg/sql/parser/window.go index 75ca1517..9b1566fd 100644 --- a/pkg/sql/parser/window.go +++ b/pkg/sql/parser/window.go @@ -5,6 +5,8 @@ package parser import ( + "strings" + "github.com/ajitpratap0/GoSQLX/pkg/models" "github.com/ajitpratap0/GoSQLX/pkg/sql/ast" ) @@ -41,6 +43,15 @@ func (p *Parser) parseFunctionCall(funcName string) (*ast.FunctionCall, error) { p.advance() // Consume comma } else if p.isType(models.TokenTypeRParen) || p.isType(models.TokenTypeOrder) { break + } else if strings.ToUpper(p.currentToken.Literal) == "SEPARATOR" { + // MySQL GROUP_CONCAT SEPARATOR clause + p.advance() // Consume SEPARATOR + sepArg, err := p.parseExpression() + if err != nil { + return nil, err + } + arguments = append(arguments, sepArg) + break } else { return nil, p.expectedError(", or )") } @@ -96,12 +107,24 @@ func (p *Parser) parseFunctionCall(funcName string) (*ast.FunctionCall, error) { p.advance() // Consume comma } else if p.isType(models.TokenTypeRParen) { break + } else if strings.EqualFold(p.currentToken.Literal, "SEPARATOR") { + break // Let SEPARATOR be handled below } else { return nil, p.expectedError(", or )") } } } + // Handle MySQL SEPARATOR clause (GROUP_CONCAT) + if strings.EqualFold(p.currentToken.Literal, "SEPARATOR") { + p.advance() // Consume SEPARATOR + sepExpr, err := p.parseExpression() + if err != nil { + return nil, err + } + arguments = append(arguments, sepExpr) + } + // Expect closing parenthesis if !p.isType(models.TokenTypeRParen) { return nil, p.expectedError(")") diff --git a/pkg/sql/tokenizer/tokenizer.go b/pkg/sql/tokenizer/tokenizer.go index f8d721c6..49eca9d3 100644 --- a/pkg/sql/tokenizer/tokenizer.go +++ b/pkg/sql/tokenizer/tokenizer.go @@ -205,6 +205,12 @@ var keywordTokenTypes = map[string]models.TokenType{ "SKIP": models.TokenTypeSkip, "LOCKED": models.TokenTypeLocked, "OF": models.TokenTypeOf, + // MySQL admin/utility keywords + "SHOW": models.TokenTypeShow, + "DESCRIBE": models.TokenTypeDescribe, + "EXPLAIN": models.TokenTypeExplain, + "DATABASES": models.TokenTypeKeyword, + "TABLES": models.TokenTypeKeyword, } // Tokenizer provides high-performance SQL tokenization with zero-copy operations.