Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion cmd/gosqlx/cmd/sql_formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,15 @@ func (f *SQLFormatter) formatInsert(stmt *ast.InsertStatement) error {

if stmt.Query != nil {
f.writeNewline()
return f.formatSelect(stmt.Query)
if sel, ok := stmt.Query.(*ast.SelectStatement); ok {
return f.formatSelect(sel)
}
// For SetOperation or other statement types, use Format if available
if fmtable, ok := stmt.Query.(interface {
Format(ast.FormatOptions) string
}); ok {
f.builder.WriteString(fmtable.Format(ast.FormatOptions{}))
}
}

return nil
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -1095,8 +1095,8 @@ type InsertStatement struct {
With *WithClause
TableName string
Columns []Expression
Values [][]Expression // Multi-row support: each inner slice is one row of values
Query *SelectStatement // For INSERT ... SELECT
Values [][]Expression // Multi-row support: each inner slice is one row of values
Query Statement // For INSERT ... SELECT (SelectStatement or SetOperation)
Returning []Expression
OnConflict *OnConflict
}
Expand Down
6 changes: 5 additions & 1 deletion pkg/sql/ast/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,11 @@ func (i *InsertStatement) Format(opts FormatOptions) string {

if i.Query != nil {
sb.WriteString(f.clauseSep())
sb.WriteString(i.Query.Format(opts))
if fq, ok := i.Query.(interface{ Format(FormatOptions) string }); ok {
sb.WriteString(fq.Format(opts))
} else {
sb.WriteString(stmtSQL(i.Query))
}
} else if len(i.Values) > 0 {
sb.WriteString(f.clauseSep())
sb.WriteString(f.kw("VALUES"))
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/ast/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ func (i *InsertStatement) SQL() string {

if i.Query != nil {
sb.WriteString(" ")
sb.WriteString(i.Query.SQL())
sb.WriteString(stmtSQL(i.Query))
} else if len(i.Values) > 0 {
sb.WriteString(" VALUES ")
rows := make([]string, len(i.Values))
Expand Down
87 changes: 50 additions & 37 deletions pkg/sql/parser/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,53 +54,65 @@ func (p *Parser) parseInsertStatement() (ast.Statement, error) {
p.advance() // Consume )
}

// Parse VALUES
if !p.isType(models.TokenTypeValues) {
return nil, p.expectedError("VALUES")
}
p.advance() // Consume VALUES

// Parse value rows - supports multi-row INSERT: VALUES (a, b), (c, d), (e, f)
values := make([][]ast.Expression, 0)
for {
if !p.isType(models.TokenTypeLParen) {
if len(values) == 0 {
return nil, p.expectedError("(")
}
break
// Parse VALUES or SELECT
var values [][]ast.Expression
var query ast.Statement

if p.isType(models.TokenTypeSelect) {
// INSERT ... SELECT syntax
p.advance() // Consume SELECT
stmt, err := p.parseSelectWithSetOperations()
if err != nil {
return nil, err
}
p.advance() // Consume (
query = stmt
} else if p.isType(models.TokenTypeValues) {
p.advance() // Consume VALUES

// Parse one row of values
row := make([]ast.Expression, 0)
// Parse value rows - supports multi-row INSERT: VALUES (a, b), (c, d), (e, f)
values = make([][]ast.Expression, 0)
for {
// Parse value using parseExpression to support all expression types
// including function calls like NOW(), UUID(), etc.
expr, err := p.parseExpression()
if err != nil {
return nil, fmt.Errorf("failed to parse value at position %d in VALUES row %d: %w", len(row)+1, len(values)+1, err)
if !p.isType(models.TokenTypeLParen) {
if len(values) == 0 {
return nil, p.expectedError("(")
}
break
}
row = append(row, expr)
p.advance() // Consume (

// Check if there are more values in this row
if !p.isType(models.TokenTypeComma) {
break
// Parse one row of values
row := make([]ast.Expression, 0)
for {
// Parse value using parseExpression to support all expression types
// including function calls like NOW(), UUID(), etc.
expr, err := p.parseExpression()
if err != nil {
return nil, fmt.Errorf("failed to parse value at position %d in VALUES row %d: %w", len(row)+1, len(values)+1, err)
}
row = append(row, expr)

// Check if there are more values in this row
if !p.isType(models.TokenTypeComma) {
break
}
p.advance() // Consume comma
}
p.advance() // Consume comma
}

if !p.isType(models.TokenTypeRParen) {
return nil, p.expectedError(")")
}
p.advance() // Consume )
if !p.isType(models.TokenTypeRParen) {
return nil, p.expectedError(")")
}
p.advance() // Consume )

values = append(values, row)
values = append(values, row)

// Check if there are more rows (comma after closing paren)
if !p.isType(models.TokenTypeComma) {
break
// Check if there are more rows (comma after closing paren)
if !p.isType(models.TokenTypeComma) {
break
}
p.advance() // Consume comma between rows
}
p.advance() // Consume comma between rows
} else {
return nil, p.expectedError("VALUES or SELECT")
}

// Parse ON CONFLICT clause if present (PostgreSQL UPSERT)
Expand Down Expand Up @@ -134,6 +146,7 @@ func (p *Parser) parseInsertStatement() (ast.Statement, error) {
TableName: tableName,
Columns: columns,
Values: values,
Query: query,
OnConflict: onConflict,
Returning: returning,
}, nil
Expand Down
58 changes: 58 additions & 0 deletions pkg/sql/parser/insert_select_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package parser

import (
"testing"

"github.com/ajitpratap0/GoSQLX/pkg/sql/ast"
)

func TestInsertSelect(t *testing.T) {
tests := []struct {
name string
input string
wantQuery bool
wantCols int
}{
{"with columns", "INSERT INTO t1 (a) SELECT a FROM t2", true, 1},
{"without columns", "INSERT INTO t1 SELECT * FROM t2", true, 0},
{"multiple columns and WHERE", "INSERT INTO t1 (a, b) SELECT a, b FROM t2 WHERE x > 1", true, 2},
{"with UNION", "INSERT INTO t1 SELECT a FROM t2 UNION SELECT a FROM t3", true, 0},
{"VALUES still works", "INSERT INTO t1 VALUES (1)", false, 0},
{"VALUES with columns", "INSERT INTO t1 (a, b) VALUES (1, 2)", false, 2},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tokens := tokenizeSQL(t, tt.input)
p := NewParser()
result, err := p.Parse(tokens)
if err != nil {
t.Fatalf("Parse(%q) error: %v", tt.input, err)
}
if len(result.Statements) < 1 {
t.Fatalf("expected at least 1 statement, got %d", len(result.Statements))
}

// For UNION case, the top-level might be SetOperation
if tt.name == "with UNION" {
// Just verify it parsed without error
return
}

insert, ok := result.Statements[0].(*ast.InsertStatement)
if !ok {
t.Fatalf("expected InsertStatement, got %T", result.Statements[0])
}

if (insert.Query != nil) != tt.wantQuery {
t.Errorf("Query present = %v, want %v", insert.Query != nil, tt.wantQuery)
}
if len(insert.Columns) != tt.wantCols {
t.Errorf("columns = %d, want %d", len(insert.Columns), tt.wantCols)
}

// Verify SQL() roundtrip doesn't panic
_ = insert.SQL()
})
}
}
2 changes: 1 addition & 1 deletion pkg/sql/parser/parser_coverage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ func TestParser_CTEEdgeCases(t *testing.T) {
{Type: models.TokenTypeFrom, Literal: "FROM"},
{Type: models.TokenTypeIdentifier, Literal: "new_users"},
},
wantErr: true, // INSERT SELECT with CTE not yet fully supported
wantErr: false, // INSERT ... SELECT is now supported
},
{
name: "CTE with UPDATE statement",
Expand Down
Loading