diff --git a/cmd/litestream-vfs/main.go b/cmd/litestream-vfs/main.go index 6483214eb..462b441d4 100644 --- a/cmd/litestream-vfs/main.go +++ b/cmd/litestream-vfs/main.go @@ -17,6 +17,7 @@ import ( "io" "log/slog" "os" + "strconv" "strings" "time" "unsafe" @@ -41,21 +42,18 @@ func main() {} //export LitestreamVFSRegister func LitestreamVFSRegister() *C.char { var client litestream.ReplicaClient - var err error replicaURL := os.Getenv("LITESTREAM_REPLICA_URL") - if replicaURL == "" { - return C.CString("LITESTREAM_REPLICA_URL environment variable required") - } - - client, err = litestream.NewReplicaClientFromURL(replicaURL) - if err != nil { - return C.CString(fmt.Sprintf("failed to create replica client: %s", err)) - } + if replicaURL != "" { + var err error + client, err = litestream.NewReplicaClientFromURL(replicaURL) + if err != nil { + return C.CString(fmt.Sprintf("failed to create replica client: %s", err)) + } - // Initialize the client. - if err := client.Init(context.Background()); err != nil { - return C.CString(fmt.Sprintf("failed to initialize replica client: %s", err)) + if err := client.Init(context.Background()); err != nil { + return C.CString(fmt.Sprintf("failed to initialize replica client: %s", err)) + } } var level slog.Level @@ -111,6 +109,58 @@ func LitestreamVFSRegister() *C.char { return nil } +//export GoLitestreamConfigure +func GoLitestreamConfigure(dbName *C.char, key *C.char, value *C.char) *C.char { + name := C.GoString(dbName) + k := C.GoString(key) + v := C.GoString(value) + + cfg := litestream.GetVFSConfig(name) + if cfg == nil { + cfg = &litestream.VFSConfig{} + } + + switch k { + case "replica_url": + cfg.ReplicaURL = v + case "write_enabled": + b := strings.ToLower(v) == "true" || v == "1" + cfg.WriteEnabled = &b + case "sync_interval": + d, err := time.ParseDuration(v) + if err != nil { + return C.CString(fmt.Sprintf("invalid sync_interval: %s", err)) + } + cfg.SyncInterval = &d + case "buffer_path": + cfg.BufferPath = v + case "hydration_enabled": + b := strings.ToLower(v) == "true" || v == "1" + cfg.HydrationEnabled = &b + case "hydration_path": + cfg.HydrationPath = v + case "poll_interval": + d, err := time.ParseDuration(v) + if err != nil { + return C.CString(fmt.Sprintf("invalid poll_interval: %s", err)) + } + cfg.PollInterval = &d + case "cache_size": + n, err := strconv.Atoi(v) + if err != nil { + return C.CString(fmt.Sprintf("invalid cache_size: %s", err)) + } + cfg.CacheSize = &n + case "log_level": + cfg.LogLevel = v + default: + return C.CString(fmt.Sprintf("unknown config key: %s", k)) + } + + litestream.SetVFSConfig(name, cfg) + return nil +} + //export GoLitestreamRegisterConnection func GoLitestreamRegisterConnection(dbPtr unsafe.Pointer, fileID C.sqlite3_uint64) *C.char { if err := litestream.RegisterVFSConnection(uintptr(dbPtr), uint64(fileID)); err != nil { diff --git a/vfs.go b/vfs.go index bf2a96d65..a2c46a599 100644 --- a/vfs.go +++ b/vfs.go @@ -138,26 +138,81 @@ func (vfs *VFS) Open(name string, flags sqlite3vfs.OpenFlag) (sqlite3vfs.File, s } func (vfs *VFS) openMainDB(name string, flags sqlite3vfs.OpenFlag) (sqlite3vfs.File, sqlite3vfs.OpenFlag, error) { - f := NewVFSFile(vfs.client, name, vfs.logger.With("name", name)) + cfg := GetVFSConfig(name) + + client := vfs.client + var perConnClient bool + if cfg != nil && cfg.ReplicaURL != "" { + var err error + client, err = NewReplicaClientFromURL(cfg.ReplicaURL) + if err != nil { + return nil, 0, fmt.Errorf("create per-connection replica client: %w", err) + } + if err := client.Init(context.Background()); err != nil { + return nil, 0, fmt.Errorf("init per-connection replica client: %w", err) + } + perConnClient = true + } + + if client == nil { + return nil, 0, fmt.Errorf("no replica client configured: set LITESTREAM_REPLICA_URL or use SetVFSConfig") + } + + f := NewVFSFile(client, name, vfs.logger.With("name", name)) f.PollInterval = vfs.PollInterval f.CacheSize = vfs.CacheSize - f.vfs = vfs // Store reference to parent VFS for config access + f.vfs = vfs + f.perConnClient = perConnClient + + if cfg != nil { + if cfg.PollInterval != nil { + f.PollInterval = *cfg.PollInterval + } + if cfg.CacheSize != nil { + f.CacheSize = *cfg.CacheSize + } + } + + writeEnabled := vfs.WriteEnabled + if cfg != nil && cfg.WriteEnabled != nil { + writeEnabled = *cfg.WriteEnabled + } + + syncInterval := vfs.WriteSyncInterval + if cfg != nil && cfg.SyncInterval != nil { + syncInterval = *cfg.SyncInterval + } + + bufferPath := vfs.WriteBufferPath + if cfg != nil && cfg.BufferPath != "" { + bufferPath = cfg.BufferPath + } + + hydrationEnabled := vfs.HydrationEnabled + if cfg != nil && cfg.HydrationEnabled != nil { + hydrationEnabled = *cfg.HydrationEnabled + } + + hydrationPath := vfs.HydrationPath + if cfg != nil && cfg.HydrationPath != "" { + hydrationPath = cfg.HydrationPath + } // Initialize write support if enabled - if vfs.WriteEnabled { + if writeEnabled { f.writeEnabled = true f.dirty = make(map[uint32]int64) - f.syncInterval = vfs.WriteSyncInterval + f.syncInterval = syncInterval if f.syncInterval == 0 { f.syncInterval = DefaultSyncInterval } writeSeq := atomic.AddUint64(&vfs.writeSeq, 1) - if vfs.WriteBufferPath != "" { + if bufferPath != "" { if writeSeq == 1 { - f.bufferPath = vfs.WriteBufferPath + f.bufferPath = bufferPath } else { - f.bufferPath = vfs.WriteBufferPath + "." + strconv.FormatUint(writeSeq, 10) + f.bufferPath = bufferPath + "." + strconv.FormatUint(writeSeq, 10) } } else { dir, err := vfs.ensureTempDir() @@ -169,18 +224,16 @@ func (vfs *VFS) openMainDB(name string, flags sqlite3vfs.OpenFlag) (sqlite3vfs.F // Initialize compaction if enabled if vfs.CompactionEnabled { - f.compactor = NewCompactor(vfs.client, f.logger) - // VFS has no local files, so leave LocalFileOpener/LocalFileDeleter nil + f.compactor = NewCompactor(client, f.logger) } } // Initialize hydration support if enabled - if vfs.HydrationEnabled { - if vfs.HydrationPath != "" { - f.hydrationPath = vfs.HydrationPath + if hydrationEnabled { + if hydrationPath != "" { + f.hydrationPath = hydrationPath f.hydrationPersistent = true } else { - // Use a temp file if no path specified dir, err := vfs.ensureTempDir() if err != nil { return nil, 0, fmt.Errorf("create temp dir for hydration: %w", err) @@ -190,10 +243,15 @@ func (vfs *VFS) openMainDB(name string, flags sqlite3vfs.OpenFlag) (sqlite3vfs.F } if err := f.Open(); err != nil { + if perConnClient { + if closer, ok := client.(io.Closer); ok { + closer.Close() + } + } return nil, 0, err } - if vfs.WriteEnabled { + if writeEnabled { vfs.writeMu.Lock() if f.expectedTXID > vfs.lastSyncedTXID { vfs.lastSyncedTXID = f.expectedTXID @@ -543,6 +601,8 @@ type VFSFile struct { disabling bool // True when write disable is in progress cond *sync.Cond // Signals transaction state changes + perConnClient bool // True when client was created from config registry (close on file close) + hydrator *Hydrator // Background hydration (nil if disabled) hydrationPath string // Path for hydration file (set during Open) hydrationPersistent bool // True when using user-specified persistent path @@ -1416,6 +1476,14 @@ func (f *VFSFile) Close() error { f.vfs.writeMu.Unlock() } + if f.perConnClient { + if closer, ok := f.client.(io.Closer); ok { + if err := closer.Close(); err != nil { + f.logger.Warn("failed to close per-connection client", "error", err) + } + } + } + return nil } diff --git a/vfs_config.go b/vfs_config.go new file mode 100644 index 000000000..f0d910194 --- /dev/null +++ b/vfs_config.go @@ -0,0 +1,50 @@ +//go:build vfs +// +build vfs + +package litestream + +import ( + "sync" + "time" +) + +type VFSConfig struct { + ReplicaURL string + WriteEnabled *bool + SyncInterval *time.Duration + BufferPath string + HydrationEnabled *bool + HydrationPath string + PollInterval *time.Duration + CacheSize *int + LogLevel string +} + +var ( + vfsConfigs = make(map[string]*VFSConfig) + vfsConfigsMu sync.RWMutex +) + +func SetVFSConfig(dbName string, cfg *VFSConfig) { + vfsConfigsMu.Lock() + defer vfsConfigsMu.Unlock() + copied := *cfg + vfsConfigs[dbName] = &copied +} + +func GetVFSConfig(dbName string) *VFSConfig { + vfsConfigsMu.RLock() + defer vfsConfigsMu.RUnlock() + orig := vfsConfigs[dbName] + if orig == nil { + return nil + } + copied := *orig + return &copied +} + +func DeleteVFSConfig(dbName string) { + vfsConfigsMu.Lock() + defer vfsConfigsMu.Unlock() + delete(vfsConfigs, dbName) +} diff --git a/vfs_config_test.go b/vfs_config_test.go new file mode 100644 index 000000000..f365a719c --- /dev/null +++ b/vfs_config_test.go @@ -0,0 +1,167 @@ +//go:build vfs + +package litestream + +import ( + "fmt" + "log/slog" + "sync" + "testing" + "time" +) + +func TestVFSConfig_SetGet(t *testing.T) { + defer clearVFSConfigs() + + cfg := &VFSConfig{ReplicaURL: "s3://bucket/path"} + SetVFSConfig("test.db", cfg) + + got := GetVFSConfig("test.db") + if got == nil { + t.Fatal("expected config, got nil") + } + if got.ReplicaURL != "s3://bucket/path" { + t.Fatalf("expected replica url %q, got %q", "s3://bucket/path", got.ReplicaURL) + } + + if got := GetVFSConfig("nonexistent.db"); got != nil { + t.Fatalf("expected nil for nonexistent db, got %+v", got) + } + + DeleteVFSConfig("test.db") + if got := GetVFSConfig("test.db"); got != nil { + t.Fatalf("expected nil after delete, got %+v", got) + } +} + +func TestVFSConfig_ConcurrentAccess(t *testing.T) { + defer clearVFSConfigs() + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + name := fmt.Sprintf("db%d", i) + cfg := &VFSConfig{ReplicaURL: "s3://bucket/" + name} + SetVFSConfig(name, cfg) + _ = GetVFSConfig(name) + DeleteVFSConfig(name) + }(i) + } + wg.Wait() +} + +func TestVFSConfig_OverridesDefaults(t *testing.T) { + defer clearVFSConfigs() + + poll := 5 * time.Second + cacheSize := 20 * 1024 * 1024 + writeEnabled := true + + cfg := &VFSConfig{ + PollInterval: &poll, + CacheSize: &cacheSize, + WriteEnabled: &writeEnabled, + } + SetVFSConfig("override.db", cfg) + + got := GetVFSConfig("override.db") + if got == nil { + t.Fatal("expected config, got nil") + } + if got.PollInterval == nil || *got.PollInterval != 5*time.Second { + t.Fatalf("expected poll interval 5s, got %v", got.PollInterval) + } + if got.CacheSize == nil || *got.CacheSize != 20*1024*1024 { + t.Fatalf("expected cache size 20MB, got %v", got.CacheSize) + } + if got.WriteEnabled == nil || !*got.WriteEnabled { + t.Fatalf("expected write enabled true, got %v", got.WriteEnabled) + } +} + +func TestVFSConfig_NilOptionalFields(t *testing.T) { + defer clearVFSConfigs() + + cfg := &VFSConfig{ReplicaURL: "s3://bucket/path"} + SetVFSConfig("sparse.db", cfg) + + got := GetVFSConfig("sparse.db") + if got.WriteEnabled != nil { + t.Fatalf("expected nil WriteEnabled, got %v", got.WriteEnabled) + } + if got.PollInterval != nil { + t.Fatalf("expected nil PollInterval, got %v", got.PollInterval) + } + if got.CacheSize != nil { + t.Fatalf("expected nil CacheSize, got %v", got.CacheSize) + } +} + +func TestVFSConfig_CopyOnSetAndGet(t *testing.T) { + defer clearVFSConfigs() + + cfg := &VFSConfig{ReplicaURL: "s3://bucket/original"} + SetVFSConfig("copy.db", cfg) + + cfg.ReplicaURL = "s3://bucket/mutated" + + got := GetVFSConfig("copy.db") + if got.ReplicaURL != "s3://bucket/original" { + t.Fatalf("expected original url, got %q (SetVFSConfig did not copy)", got.ReplicaURL) + } + + got.ReplicaURL = "s3://bucket/mutated-via-get" + got2 := GetVFSConfig("copy.db") + if got2.ReplicaURL != "s3://bucket/original" { + t.Fatalf("expected original url, got %q (GetVFSConfig did not copy)", got2.ReplicaURL) + } +} + +func TestVFSConfig_PerConnectionOverrides(t *testing.T) { + defer clearVFSConfigs() + + client := newMockReplicaClient() + client.addFixture(t, buildLTXFixture(t, 1, 'a')) + + poll := 3 * time.Second + cacheSize := 5 * 1024 * 1024 + SetVFSConfig("config-override.db", &VFSConfig{ + PollInterval: &poll, + CacheSize: &cacheSize, + }) + + vfs := NewVFS(client, slog.Default()) + + f, _, err := vfs.openMainDB("config-override.db", 0x00000100) + if err != nil { + t.Fatalf("open main db: %v", err) + } + defer f.Close() + + vfsFile := f.(*VFSFile) + if vfsFile.PollInterval != 3*time.Second { + t.Fatalf("expected poll interval 3s, got %v", vfsFile.PollInterval) + } + if vfsFile.CacheSize != 5*1024*1024 { + t.Fatalf("expected cache size 5MB, got %v", vfsFile.CacheSize) + } +} + +func TestVFS_NilClientReturnsError(t *testing.T) { + defer clearVFSConfigs() + + vfs := NewVFS(nil, slog.Default()) + + _, _, err := vfs.openMainDB("no-client.db", 0x00000100) + if err == nil { + t.Fatal("expected error when no client configured, got nil") + } +} + +func clearVFSConfigs() { + vfsConfigsMu.Lock() + defer vfsConfigsMu.Unlock() + vfsConfigs = make(map[string]*VFSConfig) +}