diff --git a/internal/db/db.go b/internal/db/db.go index 9decef592..b5dfbf2c3 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -2,34 +2,48 @@ package db import ( "fmt" - "gorm.io/gorm/logger" "log" + "sync" "github.com/glebarez/sqlite" "gorm.io/driver/postgres" "gorm.io/gorm" + "gorm.io/gorm/logger" ) -// TODO: Turn this into a singleton class. -// Only one database connection should be created and used throughout the application. +const DB_SQLITE_FALLBACK_WARNING = "[db] DATABASE_URL not set – falling back to embedded SQLite ./mcp.db" +const DB_SQLITE_CONN_STRING = "mcp.db?_busy_timeout=5000&_journal_mode=WAL" + +var db *gorm.DB +var err error +var once sync.Once // NewDBConnection creates a new database connection based on the provided DSN. // If the DSN is empty, it falls back to an embedded SQLite database at "./mcp.db". -func NewDBConnection(dsn string) (*gorm.DB, error) { - var dialector gorm.Dialector - if dsn == "" { - log.Println("[db] DATABASE_URL not set – falling back to embedded SQLite ./mcp.db") - dialector = sqlite.Open("mcp.db?_busy_timeout=5000&_journal_mode=WAL") - } else { - dialector = postgres.Open(dsn) - } +func NewDBConnection(dsnString string) (*gorm.DB, error) { + var dbErr error + once.Do(func() { + var dialector gorm.Dialector - c := &gorm.Config{ - Logger: logger.Default.LogMode(logger.Silent), - } - db, err := gorm.Open(dialector, c) - if err != nil { - return nil, fmt.Errorf("failed to connect to database: %w", err) + if dsnString == "" { + log.Println(DB_SQLITE_FALLBACK_WARNING) + dialector = sqlite.Open(DB_SQLITE_CONN_STRING) + } else { + dialector = postgres.Open(dsnString) + } + + c := &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + } + + db, err = gorm.Open(dialector, c) + + if err != nil { + dbErr = fmt.Errorf("failed to connect to database: %w", err) + } + }) + if dbErr != nil { + return nil, dbErr } - return db, nil + return db, err } diff --git a/internal/db/db_test.go b/internal/db/db_test.go new file mode 100644 index 000000000..2a7a675de --- /dev/null +++ b/internal/db/db_test.go @@ -0,0 +1,29 @@ +package db + +import ( + "os" + "testing" +) + +func TestNewDBConnection_SQLiteFallback(t *testing.T) { + // Unset DATABASE_URL to force fallback + os.Unsetenv("DATABASE_URL") + db, err := NewDBConnection("") + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + if db == nil { + t.Fatal("expected db instance, got nil") + } +} + +func TestNewDBConnection_PostgresDSN(t *testing.T) { + // Use a fake DSN, expect connection error + db, err := NewDBConnection("postgres://invalid:invalid@localhost:5432/invalid?sslmode=disable") + if err == nil { + t.Error("expected error for invalid DSN, got nil") + } + if db != nil { + t.Error("expected nil db for invalid DSN") + } +}