Skip to content

Commit 1a4d95e

Browse files
ajitpratap0 and Ajit Pratap Singh authored
feat(#275): comment preservation in AST and formatter (#311)
Co-authored-by: Ajit Pratap Singh <[email protected]>
1 parent 22dbbcd commit 1a4d95e

File tree

8 files changed

+223
-2
lines changed

8 files changed

+223
-2
lines changed
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
package formatter
2+
3+
import (
4+
"strings"
5+
"testing"
6+
)
7+
8+
func TestFormat_LineCommentPreservation(t *testing.T) {
9+
f := New(Options{Compact: true})
10+
11+
t.Run("leading line comment", func(t *testing.T) {
12+
input := "-- header comment\nSELECT col FROM t"
13+
result, err := f.Format(input)
14+
if err != nil {
15+
t.Fatalf("Format() error = %v", err)
16+
}
17+
if !strings.Contains(result, "-- header comment") {
18+
t.Errorf("Expected leading comment preserved, got: %s", result)
19+
}
20+
if !strings.Contains(result, "SELECT") || !strings.Contains(result, "col") {
21+
t.Errorf("Expected SQL preserved, got: %s", result)
22+
}
23+
})
24+
25+
t.Run("trailing line comment", func(t *testing.T) {
26+
input := "SELECT col FROM t -- trailing"
27+
result, err := f.Format(input)
28+
if err != nil {
29+
t.Fatalf("Format() error = %v", err)
30+
}
31+
if !strings.Contains(result, "-- trailing") {
32+
t.Errorf("Expected trailing comment preserved, got: %s", result)
33+
}
34+
})
35+
}
36+
37+
func TestFormat_BlockCommentPreservation(t *testing.T) {
38+
f := New(Options{Compact: true})
39+
40+
t.Run("leading block comment", func(t *testing.T) {
41+
input := "/* header */\nSELECT col FROM t"
42+
result, err := f.Format(input)
43+
if err != nil {
44+
t.Fatalf("Format() error = %v", err)
45+
}
46+
if !strings.Contains(result, "/* header */") {
47+
t.Errorf("Expected block comment preserved, got: %s", result)
48+
}
49+
})
50+
51+
t.Run("inline block comment", func(t *testing.T) {
52+
input := "SELECT /* inline */ col FROM t"
53+
result, err := f.Format(input)
54+
if err != nil {
55+
t.Fatalf("Format() error = %v", err)
56+
}
57+
if !strings.Contains(result, "/* inline */") {
58+
t.Errorf("Expected inline block comment preserved, got: %s", result)
59+
}
60+
})
61+
}
62+
63+
func TestFormat_CommentRoundTrip(t *testing.T) {
64+
f := New(Options{Compact: true})
65+
66+
inputs := []string{
67+
"-- header\nSELECT col FROM t",
68+
"SELECT col FROM t -- trailing",
69+
"/* block */ SELECT col FROM t",
70+
}
71+
72+
for _, input := range inputs {
73+
t.Run(input, func(t *testing.T) {
74+
first, err := f.Format(input)
75+
if err != nil {
76+
t.Fatalf("First format error: %v", err)
77+
}
78+
second, err := f.Format(first)
79+
if err != nil {
80+
t.Fatalf("Second format error: %v", err)
81+
}
82+
if first != second {
83+
t.Errorf("Round-trip mismatch:\n first: %q\n second: %q", first, second)
84+
}
85+
})
86+
}
87+
}
88+
89+
func TestFormat_MultipleComments(t *testing.T) {
90+
f := New(Options{Compact: true})
91+
92+
input := "-- first comment\n-- second comment\nSELECT col FROM t"
93+
result, err := f.Format(input)
94+
if err != nil {
95+
t.Fatalf("Format() error = %v", err)
96+
}
97+
if !strings.Contains(result, "-- first comment") {
98+
t.Errorf("Expected first comment, got: %s", result)
99+
}
100+
if !strings.Contains(result, "-- second comment") {
101+
t.Errorf("Expected second comment, got: %s", result)
102+
}
103+
}

pkg/formatter/formatter.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"fmt"
1111
"strings"
1212

13+
"github.com/ajitpratap0/GoSQLX/pkg/models"
1314
"github.com/ajitpratap0/GoSQLX/pkg/sql/ast"
1415
"github.com/ajitpratap0/GoSQLX/pkg/sql/parser"
1516
"github.com/ajitpratap0/GoSQLX/pkg/sql/tokenizer"
@@ -53,13 +54,22 @@ func (f *Formatter) Format(sql string) (string, error) {
5354
return "", nil
5455
}
5556

57+
// Capture comments from tokenizer before parsing
58+
comments := tkz.Comments
59+
5660
p := parser.NewParser()
5761
parsedAST, err := p.ParseFromModelTokens(tokens)
5862
if err != nil {
5963
return "", fmt.Errorf("parsing failed: %w", err)
6064
}
6165
defer ast.ReleaseAST(parsedAST)
6266

67+
// Attach captured comments to AST
68+
if len(comments) > 0 {
69+
parsedAST.Comments = make([]models.Comment, len(comments))
70+
copy(parsedAST.Comments, comments)
71+
}
72+
6373
// Use AST's built-in Format method
6474
style := ast.ReadableStyle()
6575
if f.opts.Compact {

pkg/models/comment.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// Package models provides the Comment type for SQL comment preservation.
2+
package models
3+
4+
// CommentStyle indicates the type of SQL comment.
type CommentStyle int

const (
	// LineComment represents a -- single-line comment.
	LineComment CommentStyle = iota
	// BlockComment represents a /* multi-line */ comment.
	BlockComment
)

// String returns the name of the style constant, so that CommentStyle values
// print readably (for example with %v or in test failure output) instead of
// as bare integers. Values that match no declared constant yield
// "UnknownCommentStyle".
func (s CommentStyle) String() string {
	switch s {
	case LineComment:
		return "LineComment"
	case BlockComment:
		return "BlockComment"
	default:
		return "UnknownCommentStyle"
	}
}
13+
14+
// Comment represents a SQL comment captured during tokenization.
15+
type Comment struct {
16+
Text string // The comment text including delimiters (e.g., "-- foo" or "/* bar */")
17+
Style CommentStyle // Line or block comment
18+
Start Location // Start position in source
19+
End Location // End position in source
20+
Inline bool // True if the comment is on the same line as code (trailing)
21+
}

pkg/sql/ast/ast.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,11 @@
3232
// FILTER clause, RETURNING clause, JSON/JSONB operators, and FETCH FIRST/NEXT.
3333
package ast
3434

35-
import "fmt"
35+
import (
36+
"fmt"
37+
38+
"github.com/ajitpratap0/GoSQLX/pkg/models"
39+
)
3640

3741
// Node represents any node in the Abstract Syntax Tree.
3842
//
@@ -1684,6 +1688,7 @@ func (p PartitionDefinition) Children() []Node {
16841688
// AST represents the root of the Abstract Syntax Tree
16851689
type AST struct {
16861690
Statements []Statement
1691+
Comments []models.Comment // Comments captured during tokenization, preserved during formatting
16871692
}
16881693

16891694
// TokenLiteral always returns the empty string for the root AST value.
func (a AST) TokenLiteral() string {
	return ""
}

pkg/sql/ast/format.go

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,34 @@ func (a AST) Format(opts FormatOptions) string {
138138
// Each statement already gets semicolons from its own Format
139139
sep = "\n"
140140
}
141-
return strings.Join(parts, sep)
141+
result := strings.Join(parts, sep)
142+
143+
// Emit preserved comments around the formatted SQL
144+
if len(a.Comments) > 0 {
145+
var leading, trailing []string
146+
for _, c := range a.Comments {
147+
if c.Inline {
148+
// Inline comments (on same line as code) → trailing
149+
trailing = append(trailing, c.Text)
150+
} else {
151+
// Comments on their own line → leading
152+
leading = append(leading, c.Text)
153+
}
154+
}
155+
var sb strings.Builder
156+
for _, lc := range leading {
157+
sb.WriteString(lc)
158+
sb.WriteString("\n")
159+
}
160+
sb.WriteString(result)
161+
for _, tc := range trailing {
162+
sb.WriteString(" ")
163+
sb.WriteString(tc)
164+
}
165+
result = sb.String()
166+
}
167+
168+
return result
142169
}
143170

144171
// Format returns formatted SQL for a SelectStatement.

pkg/sql/ast/pool.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,11 @@ func ReleaseAST(ast *AST) {
353353
// Reset slice but keep capacity
354354
ast.Statements = ast.Statements[:0]
355355

356+
// Reset comments but keep capacity
357+
if cap(ast.Comments) > 0 {
358+
ast.Comments = ast.Comments[:0]
359+
}
360+
356361
// Return to pool
357362
astPool.Put(ast)
358363
}

pkg/sql/tokenizer/pool.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,4 +161,9 @@ func (t *Tokenizer) Reset() {
161161

162162
// Don't reset keywords as they're constant
163163
t.logger = nil
164+
165+
// Preserve Comments slice capacity but reset length
166+
if cap(t.Comments) > 0 {
167+
t.Comments = t.Comments[:0]
168+
}
164169
}

pkg/sql/tokenizer/tokenizer.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ type Tokenizer struct {
245245
line int // Current line number (1-based)
246246
keywords *keywords.Keywords // Keyword classifier for token type determination
247247
logger *slog.Logger // Optional structured logger for verbose tracing
248+
Comments []models.Comment // Comments captured during tokenization
248249
}
249250

250251
// New creates a new Tokenizer with default configuration and keyword support.
@@ -1233,6 +1234,8 @@ func (t *Tokenizer) readPunctuation() (models.Token, error) {
12331234
}
12341235
// Check for line comment: --
12351236
if nxtR == '-' {
1237+
commentStartIdx := t.pos.Index - size // back to first '-'
1238+
commentStartPos := t.toSQLPosition(Position{Index: commentStartIdx})
12361239
t.pos.AdvanceRune(nxtR, nxtSize)
12371240
// Skip until end of line or EOF
12381241
for t.pos.Index < len(t.input) {
@@ -1243,6 +1246,19 @@ func (t *Tokenizer) readPunctuation() (models.Token, error) {
12431246
}
12441247
t.pos.AdvanceRune(cr, csize)
12451248
}
1249+
commentEndIdx := t.pos.Index
1250+
// Trim trailing newline from comment text
1251+
textEnd := commentEndIdx
1252+
if textEnd > 0 && t.input[textEnd-1] == '\n' {
1253+
textEnd--
1254+
}
1255+
t.Comments = append(t.Comments, models.Comment{
1256+
Text: string(t.input[commentStartIdx:textEnd]),
1257+
Style: models.LineComment,
1258+
Start: commentStartPos,
1259+
End: t.toSQLPosition(t.pos),
1260+
Inline: t.hasCodeBeforeOnLine(commentStartIdx),
1261+
})
12461262
// Return the next token (skip the comment)
12471263
t.skipWhitespace()
12481264
return t.nextToken()
@@ -1258,6 +1274,8 @@ func (t *Tokenizer) readPunctuation() (models.Token, error) {
12581274
nxtR, nxtSize := utf8.DecodeRune(t.input[t.pos.Index:])
12591275
// Check for block comment: /*
12601276
if nxtR == '*' {
1277+
commentStartIdx := t.pos.Index - size // back to '/'
1278+
commentStartPos := t.toSQLPosition(Position{Index: commentStartIdx})
12611279
t.pos.AdvanceRune(nxtR, nxtSize)
12621280
// Skip until */ or EOF
12631281
for t.pos.Index < len(t.input) {
@@ -1275,6 +1293,13 @@ func (t *Tokenizer) readPunctuation() (models.Token, error) {
12751293
t.pos.AdvanceRune(cr, csize)
12761294
}
12771295
}
1296+
t.Comments = append(t.Comments, models.Comment{
1297+
Text: string(t.input[commentStartIdx:t.pos.Index]),
1298+
Style: models.BlockComment,
1299+
Start: commentStartPos,
1300+
End: t.toSQLPosition(t.pos),
1301+
Inline: t.hasCodeBeforeOnLine(commentStartIdx),
1302+
})
12781303
// Return the next token (skip the comment)
12791304
t.skipWhitespace()
12801305
return t.nextToken()
@@ -1636,3 +1661,23 @@ func (t *Tokenizer) getLocation(pos int) models.Location {
16361661
func isIdentifierChar(r rune) bool {
16371662
return isUnicodeIdentifierPart(r)
16381663
}
1664+
1665+
// hasCodeBeforeOnLine checks if there are non-whitespace characters on the same
1666+
// line before the given byte index. Used to determine if a comment is inline.
1667+
func (t *Tokenizer) hasCodeBeforeOnLine(idx int) bool {
1668+
// Find the start of the line containing idx
1669+
lineStart := 0
1670+
for i := len(t.lineStarts) - 1; i >= 0; i-- {
1671+
if t.lineStarts[i] <= idx {
1672+
lineStart = t.lineStarts[i]
1673+
break
1674+
}
1675+
}
1676+
// Check for non-whitespace between lineStart and idx
1677+
for i := lineStart; i < idx && i < len(t.input); i++ {
1678+
if t.input[i] != ' ' && t.input[i] != '\t' && t.input[i] != '\r' {
1679+
return true
1680+
}
1681+
}
1682+
return false
1683+
}

0 commit comments

Comments
 (0)