Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 32 additions & 18 deletions internal/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,48 @@ package db

import (
"fmt"
"gorm.io/gorm/logger"
"log"
"sync"

"github.com/glebarez/sqlite"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)

// TODO: Turn this into a singleton class.
// Only one database connection should be created and used throughout the application.
const DB_SQLITE_FALLBACK_WARNING = "[db] DATABASE_URL not set – falling back to embedded SQLite ./mcp.db"
const DB_SQLITE_CONN_STRING = "mcp.db?_busy_timeout=5000&_journal_mode=WAL"

var db *gorm.DB
var err error
var once sync.Once

// NewDBConnection creates a new database connection based on the provided DSN.
// If the DSN is empty, it falls back to an embedded SQLite database at "./mcp.db".
func NewDBConnection(dsn string) (*gorm.DB, error) {
var dialector gorm.Dialector
if dsn == "" {
log.Println("[db] DATABASE_URL not set – falling back to embedded SQLite ./mcp.db")
dialector = sqlite.Open("mcp.db?_busy_timeout=5000&_journal_mode=WAL")
} else {
dialector = postgres.Open(dsn)
}
func NewDBConnection(dsnString string) (*gorm.DB, error) {
var dbErr error
once.Do(func() {
var dialector gorm.Dialector

c := &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
}
db, err := gorm.Open(dialector, c)
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
if dsnString == "" {
log.Println(DB_SQLITE_FALLBACK_WARNING)
dialector = sqlite.Open(DB_SQLITE_CONN_STRING)
} else {
dialector = postgres.Open(dsnString)
}

c := &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
}

db, err = gorm.Open(dialector, c)

if err != nil {
dbErr = fmt.Errorf("failed to connect to database: %w", err)
}
})
if dbErr != nil {
return nil, dbErr
}
return db, nil
return db, err
}
29 changes: 29 additions & 0 deletions internal/db/db_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package db

import (
"os"
"testing"
)

func TestNewDBConnection_SQLiteFallback(t *testing.T) {
// Unset DATABASE_URL to force fallback
os.Unsetenv("DATABASE_URL")
db, err := NewDBConnection("")
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
if db == nil {
t.Fatal("expected db instance, got nil")
}
}

func TestNewDBConnection_PostgresDSN(t *testing.T) {
// Use a fake DSN, expect connection error
db, err := NewDBConnection("postgres://invalid:invalid@localhost:5432/invalid?sslmode=disable")
if err == nil {
t.Error("expected error for invalid DSN, got nil")
}
if db != nil {
t.Error("expected nil db for invalid DSN")
}
}