88 "os/exec"
99 "strings"
1010
11- "cloud.google.com/go/spanner/spansql"
11+ "github.com/cloudspannerecosystem/memefish"
12+ "github.com/cloudspannerecosystem/memefish/ast"
13+ "github.com/cloudspannerecosystem/memefish/token"
1214)
1315
1416const (
@@ -56,6 +58,10 @@ func (cli *cli) run(args []string) int {
5658 log .Print (err )
5759 return exitCodeError
5860 }
61+ log .Printf ("Parsed %d tables" , len (tables ))
62+ for i , t := range tables {
63+ log .Printf ("Table %d: %s" , i , getTableName (t ))
64+ }
5965 graph , err := NewGraph ()
6066 if err != nil {
6167 log .Print (err )
@@ -67,11 +73,19 @@ func (cli *cli) run(args []string) int {
6773 return exitCodeError
6874 }
6975 s := graph .String ()
76+ log .Printf ("Graph DOT content: %s" , s )
7077 r := strings .NewReader (s )
7178 c := exec .Command ("dot" , fmt .Sprintf ("-T%s" , t ), "-o" , output )
7279 c .Stdin = r
73- c .Start ()
74- c .Wait ()
80+ var stderr strings.Builder
81+ c .Stderr = & stderr
82+ err = c .Run ()
83+ if err != nil {
84+ log .Printf ("Error running dot command: %v" , err )
85+ log .Printf ("Stderr: %s" , stderr .String ())
86+ return exitCodeError
87+ }
88+ log .Printf ("Output file created: %s" , output )
7589
7690 return exitCodeOK
7791}
@@ -86,19 +100,55 @@ func (cli *cli) read(file string) (string, error) {
86100
87101}
88102
89- func parse (sqls string ) ([]* spansql.CreateTable , error ) {
90- // spansql not allow backquote
91- sqls = strings .Replace (sqls , "`" , "" , - 1 )
92- d , err := spansql .ParseDDL ("" , sqls )
93- if err != nil {
94- return nil , err
95- }
96- tables := []* spansql.CreateTable {}
97- for _ , e := range d .List {
98- switch v := e .(type ) {
99- case * spansql.CreateTable :
100- tables = append (tables , v )
103+ func parse (sqls string ) ([]* ast.CreateTable , error ) {
104+ // Log the original SQL for debugging
105+ log .Printf ("Original SQL: %s" , sqls )
106+
107+ // Split the SQL by semicolons to get individual statements
108+ statements := strings .Split (sqls , ";" )
109+
110+ var tables []* ast.CreateTable
111+ for _ , stmt := range statements {
112+ // Skip empty statements
113+ stmt = strings .TrimSpace (stmt )
114+ if stmt == "" {
115+ continue
116+ }
117+
118+ // Log each statement for debugging
119+ log .Printf ("Parsing statement: %s" , stmt )
120+
121+ // Create a new Parser instance for each statement
122+ file := & token.File {
123+ Buffer : stmt ,
124+ }
125+ p := & memefish.Parser {
126+ Lexer : & memefish.Lexer {File : file },
127+ }
128+
129+ // Parse the statement
130+ parsedStmt , err := p .ParseStatement ()
131+ if err != nil {
132+ log .Printf ("Error parsing statement: %v" , err )
133+ continue
134+ }
135+
136+ // If it's a CREATE TABLE statement, add it to our list
137+ if createTable , ok := parsedStmt .(* ast.CreateTable ); ok {
138+ log .Printf ("Found CREATE TABLE: %s" , getTableName (createTable ))
139+ tables = append (tables , createTable )
140+ } else {
141+ log .Printf ("Statement is not a CREATE TABLE: %T" , parsedStmt )
101142 }
102143 }
144+
103145 return tables , nil
104146}
147+
148+ // Helper function to get the name from a CreateTable
149+ func getTableName (t * ast.CreateTable ) string {
150+ if t .Name != nil && len (t .Name .Idents ) > 0 {
151+ return t .Name .Idents [len (t .Name .Idents )- 1 ].Name
152+ }
153+ return ""
154+ }
0 commit comments