diff --git a/internal/db/db.go b/internal/db/db.go index d7e5d54..7327307 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -7,6 +7,8 @@ import ( "io" "reflect" "strings" + "sync" + "time" "github.com/libsql/sqlite-antlr4-parser/sqliteparserutils" _ "github.com/mattn/go-sqlite3" @@ -28,6 +30,8 @@ type Db struct { AuthToken string sqlDb *sql.DB + conn *sql.Conn + connMutex *sync.Mutex driver driver urlScheme string @@ -74,7 +78,7 @@ func newRowResultWithError(err error) *rowResult { func NewDb(dbUri, authToken, proxy string) (*Db, error) { var err error - var db = Db{Uri: dbUri, AuthToken: authToken} + var db = Db{Uri: dbUri, AuthToken: authToken, connMutex: &sync.Mutex{}} if IsUrl(dbUri) { var validSqldUrl bool @@ -92,6 +96,10 @@ func NewDb(dbUri, authToken, proxy string) (*Db, error) { return nil, err } db.sqlDb = sql.OpenDB(connector) + db.conn, err = db.sqlDb.Conn(context.Background()) + if err != nil { + return nil, err + } } else { return nil, &shellerrors.ProtocolError{} } @@ -102,6 +110,19 @@ func NewDb(dbUri, authToken, proxy string) (*Db, error) { if err != nil { return nil, err } + // every second acquire mutex and run a no op + go func() { + tick := time.NewTicker(time.Second) + for range tick.C { + db.connMutex.Lock() + _, err := db.conn.QueryContext(context.Background(), "SELECT 1;", nil) + if err != nil { + fmt.Printf("connection closed. restart your shell") + return + } + db.connMutex.Unlock() + } + }() return &db, nil } @@ -160,7 +181,9 @@ func (db *Db) executeQuery(query string, statementResultCh chan StatementResult) ctx, cancel := context.WithCancel(context.Background()) db.cancelRunningQuery = cancel - rows, err := db.sqlDb.QueryContext(ctx, query) + db.connMutex.Lock() + rows, err := db.conn.QueryContext(ctx, query) + db.connMutex.Unlock() if err != nil { statementResultCh <- *newStatementResultWithError(err)