Skip to content

Use a singleton for database connections #36

@andrew-sledge

Description

@andrew-sledge

As noted in the comments in db.go:

// TODO: Turn this into a singleton class.
// Only one database connection should be created and used throughout the application.

A singleton implementation might look like this:

import (
    "fmt"
    "log"
    "sync"

    "github.com/glebarez/sqlite"
    "gorm.io/driver/postgres"
    "gorm.io/gorm"
    "gorm.io/gorm/logger"
)

// DB_SQLITE_FALLBACK_WARNING is logged when NewDBConnection receives an empty
// DSN and falls back to the embedded SQLite database.
const DB_SQLITE_FALLBACK_WARNING = "[db] DATABASE_URL not set – falling back to embedded SQLite ./mcp.db"

// DB_SQLITE_CONN_STRING is the SQLite DSN used for the fallback database: the
// file mcp.db, resolved against the process working directory, with a requested
// busy timeout of 5000 and WAL journaling.
//
// NOTE(review): the glebarez/sqlite driver documents pragma parameters in the
// form "?_pragma=busy_timeout(5000)&_pragma=journal_mode(WAL)". Confirm that the
// mattn-style "_busy_timeout"/"_journal_mode" keys used here are actually
// honoured, otherwise they may be silently ignored.
const DB_SQLITE_CONN_STRING = "mcp.db?_busy_timeout=5000&_journal_mode=WAL"

// Singleton state. These are written only inside once.Do in NewDBConnection.
var (
	db    *gorm.DB  // shared handle; nil until a successful initialisation
	dbErr error     // wrapped initialisation error, returned on every call
	once  sync.Once // guards the one-time initialisation
)

// NewDBConnection returns the process-wide database handle, creating it on the
// first call.
//
// On the first call, dsnString selects the backend. An empty string logs
// DB_SQLITE_FALLBACK_WARNING and opens the embedded SQLite database described
// by DB_SQLITE_CONN_STRING. Any other value is passed to the PostgreSQL driver.
// Every later call ignores dsnString and returns the outcome of that first
// attempt.
//
// If the first attempt fails, this call and every later call return a nil
// *gorm.DB and the same wrapped error; the connection is not retried.
// Concurrent callers are safe: sync.Once runs the initialisation exactly once
// and blocks other callers until it has finished.
func NewDBConnection(dsnString string) (*gorm.DB, error) {
	once.Do(func() {
		var dialector gorm.Dialector
		if dsnString == "" {
			log.Println(DB_SQLITE_FALLBACK_WARNING)
			dialector = sqlite.Open(DB_SQLITE_CONN_STRING)
		} else {
			dialector = postgres.Open(dsnString)
		}

		cfg := &gorm.Config{
			// Silence GORM's own SQL logging.
			Logger: logger.Default.LogMode(logger.Silent),
		}

		conn, openErr := gorm.Open(dialector, cfg)
		if openErr != nil {
			// gorm.Open may return a non-nil *gorm.DB together with an error;
			// never publish that half-initialised handle. Cache the wrapped
			// error so every caller sees the same failure, not just the first.
			db, dbErr = nil, fmt.Errorf("failed to connect to database: %w", openErr)
			return
		}
		db, dbErr = conn, nil
	})
	return db, dbErr
}

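Because the signature does not change, no existing call sites need to be refactored. However, the change should be covered by tests. Note that, because the connection is a singleton, each test must reset `once` before calling NewDBConnection; otherwise the first test to run fixes the backend (and any error) for every later test: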

package db

import (
	"os"
	"sync"
	"testing"

	"gorm.io/gorm"
)

// invalidPostgresDSN points at a PostgreSQL server that should not accept these
// credentials, so connecting with it is expected to fail.
const invalidPostgresDSN = "postgres://invalid:invalid@localhost:5432/invalid?sslmode=disable"

// resetSingleton clears the sync.Once that guards NewDBConnection, both before
// the test and again when it finishes. Without this, the first test to call
// NewDBConnection would fix the backend (and any error) for every later test.
func resetSingleton(t *testing.T) {
	t.Helper()
	once = sync.Once{}
	t.Cleanup(func() { once = sync.Once{} })
}

// chdirTemp runs the rest of the test inside a fresh temporary directory. The
// SQLite fallback then creates mcp.db there instead of in the package source
// tree. The original working directory is restored on cleanup.
func chdirTemp(t *testing.T) {
	t.Helper()
	orig, err := os.Getwd()
	if err != nil {
		t.Fatalf("getwd: %v", err)
	}
	if err := os.Chdir(t.TempDir()); err != nil {
		t.Fatalf("chdir: %v", err)
	}
	t.Cleanup(func() {
		if err := os.Chdir(orig); err != nil {
			t.Errorf("restore working directory: %v", err)
		}
	})
}

// closeOnCleanup closes conn's underlying *sql.DB when the test finishes.
// Cleanups run in reverse registration order, so calling this after chdirTemp
// releases the SQLite file before its temporary directory is removed.
func closeOnCleanup(t *testing.T, conn *gorm.DB) {
	t.Helper()
	t.Cleanup(func() {
		sqlDB, err := conn.DB()
		if err != nil {
			t.Errorf("get *sql.DB: %v", err)
			return
		}
		if err := sqlDB.Close(); err != nil {
			t.Errorf("close database: %v", err)
		}
	})
}

func TestNewDBConnection_SQLiteFallback(t *testing.T) {
	resetSingleton(t)
	chdirTemp(t)

	// An empty DSN is what selects the embedded SQLite fallback.
	conn, err := NewDBConnection("")
	if err != nil {
		t.Fatalf("expected no error, got: %v", err)
	}
	if conn == nil {
		t.Fatal("expected db instance, got nil")
	}
	closeOnCleanup(t, conn)
}

func TestNewDBConnection_ReturnsSingleton(t *testing.T) {
	resetSingleton(t)
	chdirTemp(t)

	first, err := NewDBConnection("")
	if err != nil {
		t.Fatalf("first call: expected no error, got: %v", err)
	}
	if first == nil {
		t.Fatal("first call: expected db instance, got nil")
	}
	closeOnCleanup(t, first)

	// Once initialised, the DSN argument is ignored and the same handle is
	// returned.
	second, err := NewDBConnection(invalidPostgresDSN)
	if err != nil {
		t.Fatalf("second call: expected no error, got: %v", err)
	}
	if second != first {
		t.Error("expected the same *gorm.DB on every call")
	}
}

func TestNewDBConnection_PostgresDSN(t *testing.T) {
	resetSingleton(t)

	// Use a DSN with bad credentials and expect a connection error.
	conn, err := NewDBConnection(invalidPostgresDSN)
	if err == nil {
		t.Error("expected error for invalid DSN, got nil")
	}
	if conn != nil {
		t.Error("expected nil db for invalid DSN")
	}
}

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions