diff --git a/go.mod b/go.mod index e38dcdb..978630b 100644 --- a/go.mod +++ b/go.mod @@ -70,5 +70,5 @@ require ( github.com/olekukonko/tablewriter v0.0.5 github.com/rogpeppe/go-internal v1.9.0 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 + github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 // indirect ) diff --git a/go.sum b/go.sum index bd58ba0..e27eaab 100644 --- a/go.sum +++ b/go.sum @@ -82,8 +82,6 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/leodido/go-urn v1.2.2 h1:7z68G0FCGvDk646jz1AelTYNYWrTNm0bEcFAo147wt4= github.com/leodido/go-urn v1.2.2/go.mod h1:kUaIbLZWttglzwNuG0pgsh5vuV6u2YcGBYz1hIPjtOQ= -github.com/libsql/libsql-client-go v0.0.0-20230417135653-4b3a4f626bc0 h1:k7m1LHwZfTC3jV8hwC9GhgEUFWx2LsTtODN8DNvpROQ= -github.com/libsql/libsql-client-go v0.0.0-20230417135653-4b3a4f626bc0/go.mod h1:w1KCoxf6c2eACi0Rpape7cIooejVqeqYiVr1E/tGcLk= github.com/libsql/libsql-client-go v0.0.0-20230425122822-72eff623c460 h1:dG1TyWCzFX2KiL6MSyqfzt0XjJq5BP5x4ZD8sY2xahE= github.com/libsql/libsql-client-go v0.0.0-20230425122822-72eff623c460/go.mod h1:w1KCoxf6c2eACi0Rpape7cIooejVqeqYiVr1E/tGcLk= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= diff --git a/internal/db/db.go b/internal/db/db.go index f9c52d9..49983d6 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -8,7 +8,6 @@ import ( _ "github.com/libsql/libsql-client-go/libsql" _ "github.com/mattn/go-sqlite3" - "github.com/xwb1989/sqlparser" "github.com/libsql/libsql-shell-go/pkg/shell/enums" "github.com/libsql/libsql-shell-go/pkg/shell/shellerrors" @@ -61,9 +60,40 @@ func (db *Db) Close() { db.sqlDb.Close() } +func splitStatementToPieces(statementsString string) (pieces []string, err error) { + pieces = make([]string, 0, 16) + embeddedChar := ' ' + var stmt string + stmtBegin := 0 + for i, char := range statementsString { + if char == embeddedChar && char != ' ' { + embeddedChar = ' ' + continue + } + if (char == '\'' || char == '"') && embeddedChar == ' ' { + embeddedChar = char + continue + } + if embeddedChar != ' ' || char != ';' { + continue + } + stmt = strings.TrimSpace(statementsString[stmtBegin : i+1]) + if len(stmt) < 1 || strings.HasPrefix(stmt, ";") { + stmtBegin = i + 1 + continue + } + pieces = append(pieces, stmt) + stmtBegin = i + 1 + } + if stmtBegin < len(statementsString) { + pieces = append(pieces, statementsString[stmtBegin:]) + } + return pieces, nil +} + func (db *Db) ExecuteStatements(statementsString string) (StatementsResult, error) { - statements, err := sqlparser.SplitStatementToPieces(statementsString) + statements, err := splitStatementToPieces(statementsString) if err != nil { return StatementsResult{}, err }