Skip to content

Commit 2631948

Browse files
authored
fix: incorrect reference to catalog (#332)
1 parent 1749dfe commit 2631948

File tree

7 files changed

+65
-38
lines changed

7 files changed

+65
-38
lines changed

adapter/adapter.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ type ConnectionHolder interface {
1414
GetCatalogConn(ctx context.Context) (*stdsql.Conn, error)
1515
GetCatalogTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error)
1616
TryGetTxn() *stdsql.Tx
17+
GetCurrentCatalog() string
18+
GetCurrentSchema() string
1719
CloseTxn()
1820
CloseConn()
1921
}
@@ -42,6 +44,14 @@ func TryGetTxn(ctx *sql.Context) *stdsql.Tx {
4244
return ctx.Session.(ConnectionHolder).TryGetTxn()
4345
}
4446

47+
func GetCurrentCatalog(ctx *sql.Context) string {
48+
return ctx.Session.(ConnectionHolder).GetCurrentCatalog()
49+
}
50+
51+
func GetCurrentSchema(ctx *sql.Context) string {
52+
return ctx.Session.(ConnectionHolder).GetCurrentSchema()
53+
}
54+
4555
func CloseTxn(ctx *sql.Context) {
4656
ctx.Session.(ConnectionHolder).CloseTxn()
4757
}

backend/executor.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
stdsql "database/sql"
1818
"fmt"
1919

20+
"github.com/apecloud/myduckserver/adapter"
2021
"github.com/apecloud/myduckserver/catalog"
2122
"github.com/apecloud/myduckserver/transpiler"
2223
"github.com/dolthub/go-mysql-server/sql"
@@ -124,7 +125,7 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row
124125

125126
switch node := n.(type) {
126127
case *plan.Use:
127-
useStmt := "USE " + catalog.FullSchemaName(b.provider.CatalogName(), node.Database().Name())
128+
useStmt := "USE " + catalog.FullSchemaName(adapter.GetCurrentCatalog(ctx), node.Database().Name())
128129
if _, err := conn.ExecContext(ctx.Context, useStmt); err != nil {
129130
if catalog.IsDuckDBSetSchemaNotFoundError(err) {
130131
return nil, sql.ErrDatabaseNotFound.New(node.Database().Name())

backend/session.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,16 @@ func (sess *Session) TryGetTxn() *stdsql.Tx {
225225
return sess.db.Pool().TryGetTxn(sess.ID())
226226
}
227227

228+
// GetCurrentCatalog implements adapter.ConnectionHolder.
229+
func (sess *Session) GetCurrentCatalog() string {
230+
return sess.db.Pool().CurrentCatalog(sess.ID())
231+
}
232+
233+
// GetCurrentSchema implements adapter.ConnectionHolder.
234+
func (sess *Session) GetCurrentSchema() string {
235+
return sess.db.Pool().CurrentSchema(sess.ID())
236+
}
237+
228238
// CloseTxn implements adapter.ConnectionHolder.
229239
func (sess *Session) CloseTxn() {
230240
sess.db.Pool().CloseTxn(sess.ID())

catalog/connpool.go

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,14 @@ import (
3030
type ConnectionPool struct {
3131
*stdsql.DB
3232
connector *duckdb.Connector
33-
catalog string
3433
conns sync.Map // concurrent-safe map[uint32]*stdsql.Conn
3534
txns sync.Map // concurrent-safe map[uint32]*stdsql.Tx
3635
}
3736

38-
func NewConnectionPool(catalog string, connector *duckdb.Connector, db *stdsql.DB) *ConnectionPool {
37+
func NewConnectionPool(connector *duckdb.Connector, db *stdsql.DB) *ConnectionPool {
3938
return &ConnectionPool{
4039
DB: db,
4140
connector: connector,
42-
catalog: catalog,
4341
}
4442
}
4543

@@ -57,13 +55,30 @@ func (p *ConnectionPool) CurrentSchema(id uint32) string {
5755
}
5856
conn := entry.(*stdsql.Conn)
5957
var schema string
60-
if err := conn.QueryRowContext(context.Background(), "SELECT CURRENT_SCHEMA()").Scan(&schema); err != nil {
58+
if err := conn.QueryRowContext(context.Background(), "SELECT CURRENT_SCHEMA").Scan(&schema); err != nil {
6159
logrus.WithError(err).Error("Failed to get current schema")
6260
return ""
6361
}
6462
return schema
6563
}
6664

65+
// CurrentCatalog retrieves the current catalog of the connection.
66+
// Returns an empty string if the connection is not established
67+
// or the catalog cannot be retrieved.
68+
func (p *ConnectionPool) CurrentCatalog(id uint32) string {
69+
entry, ok := p.conns.Load(id)
70+
if !ok {
71+
return ""
72+
}
73+
conn := entry.(*stdsql.Conn)
74+
var catalog string
75+
if err := conn.QueryRowContext(context.Background(), "SELECT CURRENT_CATALOG").Scan(&catalog); err != nil {
76+
logrus.WithError(err).Error("Failed to get current catalog")
77+
return ""
78+
}
79+
return catalog
80+
}
81+
6782
func (p *ConnectionPool) GetConn(ctx context.Context, id uint32) (*stdsql.Conn, error) {
6883
var conn *stdsql.Conn
6984
entry, ok := p.conns.Load(id)
@@ -88,11 +103,11 @@ func (p *ConnectionPool) GetConnForSchema(ctx context.Context, id uint32, schema
88103

89104
if schemaName != "" {
90105
var currentSchema string
91-
if err := conn.QueryRowContext(context.Background(), "SELECT CURRENT_SCHEMA()").Scan(&currentSchema); err != nil {
106+
if err := conn.QueryRowContext(context.Background(), "SELECT CURRENT_SCHEMA").Scan(&currentSchema); err != nil {
92107
logrus.WithError(err).Error("Failed to get current schema")
93108
return nil, err
94109
} else if currentSchema != schemaName {
95-
if _, err := conn.ExecContext(context.Background(), "USE "+FullSchemaName(p.catalog, schemaName)); err != nil {
110+
if _, err := conn.ExecContext(context.Background(), "USE "+FullSchemaName(p.CurrentCatalog(id), schemaName)); err != nil {
96111
if IsDuckDBSetSchemaNotFoundError(err) {
97112
return nil, sql.ErrDatabaseNotFound.New(schemaName)
98113
}
@@ -187,15 +202,14 @@ func (p *ConnectionPool) Close() error {
187202
return errors.Join(lastErr, p.DB.Close())
188203
}
189204

190-
func (p *ConnectionPool) Reset(catalog string, connector *duckdb.Connector, db *stdsql.DB) error {
205+
func (p *ConnectionPool) Reset(connector *duckdb.Connector, db *stdsql.DB) error {
191206
err := p.Close()
192207
if err != nil {
193208
return fmt.Errorf("failed to close connection pool: %w", err)
194209
}
195210

196211
p.conns.Clear()
197212
p.txns.Clear()
198-
p.catalog = catalog
199213
p.DB = db
200214
p.connector = connector
201215

catalog/provider.go

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import (
1414

1515
"github.com/dolthub/go-mysql-server/sql"
1616
"github.com/marcboeker/go-duckdb"
17-
_ "github.com/marcboeker/go-duckdb"
1817

1918
"github.com/apecloud/myduckserver/adapter"
2019
"github.com/apecloud/myduckserver/configuration"
@@ -27,7 +26,7 @@ type DatabaseProvider struct {
2726
connector *duckdb.Connector
2827
storage *stdsql.DB
2928
pool *ConnectionPool
30-
catalogName string // database name in postgres
29+
defaultCatalogName string // default database name in postgres
3130
dataDir string
3231
dbFile string
3332
dsn string
@@ -60,11 +59,11 @@ func NewDBProvider(defaultTimeZone, dataDir, defaultDB string) (prov *DatabasePr
6059

6160
shouldInit := true
6261
if defaultDB == "" || defaultDB == "memory" {
63-
prov.catalogName = "memory"
62+
prov.defaultCatalogName = "memory"
6463
prov.dbFile = ""
6564
prov.dsn = ""
6665
} else {
67-
prov.catalogName = defaultDB
66+
prov.defaultCatalogName = defaultDB
6867
prov.dbFile = defaultDB + ".db"
6968
prov.dsn = filepath.Join(prov.dataDir, prov.dbFile)
7069
_, err = os.Stat(prov.dsn)
@@ -76,7 +75,7 @@ func NewDBProvider(defaultTimeZone, dataDir, defaultDB string) (prov *DatabasePr
7675
return nil, err
7776
}
7877
prov.storage = stdsql.OpenDB(prov.connector)
79-
prov.pool = NewConnectionPool(prov.catalogName, prov.connector, prov.storage)
78+
prov.pool = NewConnectionPool(prov.connector, prov.storage)
8079

8180
bootQueries := []string{
8281
"INSTALL arrow",
@@ -353,8 +352,8 @@ func (prov *DatabaseProvider) Pool() *ConnectionPool {
353352
return prov.pool
354353
}
355354

356-
func (prov *DatabaseProvider) CatalogName() string {
357-
return prov.catalogName
355+
func (prov *DatabaseProvider) DefaultCatalogName() string {
356+
return prov.defaultCatalogName
358357
}
359358

360359
func (prov *DatabaseProvider) DataDir() string {
@@ -380,7 +379,8 @@ func (prov *DatabaseProvider) AllDatabases(ctx *sql.Context) []sql.Database {
380379
prov.mu.RLock()
381380
defer prov.mu.RUnlock()
382381

383-
rows, err := adapter.QueryCatalog(ctx, "SELECT DISTINCT schema_name FROM information_schema.schemata WHERE catalog_name = ?", prov.catalogName)
382+
catalogName := adapter.GetCurrentCatalog(ctx)
383+
rows, err := adapter.QueryCatalog(ctx, "SELECT DISTINCT schema_name FROM information_schema.schemata WHERE catalog_name = ?", catalogName)
384384
if err != nil {
385385
panic(ErrDuckDB.New(err))
386386
}
@@ -398,7 +398,7 @@ func (prov *DatabaseProvider) AllDatabases(ctx *sql.Context) []sql.Database {
398398
continue
399399
}
400400

401-
all = append(all, NewDatabase(schemaName, prov.catalogName))
401+
all = append(all, NewDatabase(schemaName, catalogName))
402402
}
403403

404404
sort.Slice(all, func(i, j int) bool {
@@ -413,13 +413,14 @@ func (prov *DatabaseProvider) Database(ctx *sql.Context, name string) (sql.Datab
413413
prov.mu.RLock()
414414
defer prov.mu.RUnlock()
415415

416-
ok, err := hasDatabase(ctx, prov.catalogName, name)
416+
catalogName := adapter.GetCurrentCatalog(ctx)
417+
ok, err := hasDatabase(ctx, catalogName, name)
417418
if err != nil {
418419
return nil, err
419420
}
420421

421422
if ok {
422-
return NewDatabase(name, prov.catalogName), nil
423+
return NewDatabase(name, catalogName), nil
423424
}
424425
return nil, sql.ErrDatabaseNotFound.New(name)
425426
}
@@ -429,7 +430,7 @@ func (prov *DatabaseProvider) HasDatabase(ctx *sql.Context, name string) bool {
429430
prov.mu.RLock()
430431
defer prov.mu.RUnlock()
431432

432-
ok, err := hasDatabase(ctx, prov.catalogName, name)
433+
ok, err := hasDatabase(ctx, adapter.GetCurrentCatalog(ctx), name)
433434
if err != nil {
434435
panic(err)
435436
}
@@ -451,7 +452,8 @@ func (prov *DatabaseProvider) CreateDatabase(ctx *sql.Context, name string) erro
451452
prov.mu.Lock()
452453
defer prov.mu.Unlock()
453454

454-
_, err := adapter.ExecCatalog(ctx, fmt.Sprintf(`CREATE SCHEMA %s`, FullSchemaName(prov.catalogName, name)))
455+
_, err := adapter.ExecCatalog(ctx, fmt.Sprintf(`CREATE SCHEMA %s`,
456+
FullSchemaName(adapter.GetCurrentCatalog(ctx), name)))
455457
if err != nil {
456458
return ErrDuckDB.New(err)
457459
}
@@ -464,7 +466,8 @@ func (prov *DatabaseProvider) DropDatabase(ctx *sql.Context, name string) error
464466
prov.mu.Lock()
465467
defer prov.mu.Unlock()
466468

467-
_, err := adapter.Exec(ctx, fmt.Sprintf(`DROP SCHEMA %s CASCADE`, FullSchemaName(prov.catalogName, name)))
469+
_, err := adapter.Exec(ctx, fmt.Sprintf(`DROP SCHEMA %s CASCADE`,
470+
FullSchemaName(adapter.GetCurrentCatalog(ctx), name)))
468471
if err != nil {
469472
return ErrDuckDB.New(err)
470473
}
@@ -494,5 +497,5 @@ func (prov *DatabaseProvider) Restart(readOnly bool) error {
494497
prov.connector = connector
495498
prov.storage = storage
496499

497-
return nil
500+
return prov.pool.Reset(connector, storage)
498501
}

pgserver/backup_handler.go

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,6 @@ func parseBackupSQL(sql string) (*BackupConfig, error) {
8989
}
9090

9191
func (h *ConnectionHandler) executeBackup(backupConfig *BackupConfig) (string, error) {
92-
// TODO(neo.zty): Add support for backing up multiple databases once MyDuck Server supports multi-database functionality.
93-
if backupConfig.DbName != h.server.Provider.CatalogName() {
94-
return "", fmt.Errorf("backup database name %s does not match server database name %s",
95-
backupConfig.DbName, h.server.Provider.CatalogName())
96-
}
97-
9892
sqlCtx, err := h.duckHandler.sm.NewContextWithQuery(context.Background(), h.mysqlConn, "")
9993
if err != nil {
10094
return "", fmt.Errorf("failed to create context for query: %w", err)
@@ -114,7 +108,7 @@ func (h *ConnectionHandler) executeBackup(backupConfig *BackupConfig) (string, e
114108
}
115109

116110
msg, err := backupConfig.StorageConfig.UploadFile(
117-
h.server.Provider.DataDir(), h.server.Provider.DbFile(), backupConfig.RemotePath)
111+
h.server.Provider.DataDir(), backupConfig.DbName+".db", backupConfig.RemotePath)
118112
if err != nil {
119113
return "", err
120114
}
@@ -133,12 +127,7 @@ func (h *ConnectionHandler) executeBackup(backupConfig *BackupConfig) (string, e
133127

134128
func (h *ConnectionHandler) restartServer(readOnly bool) error {
135129
provider := h.server.Provider
136-
err := provider.Restart(readOnly)
137-
if err != nil {
138-
return err
139-
}
140-
141-
return h.server.Provider.Pool().Reset(provider.CatalogName(), provider.Connector(), provider.Storage())
130+
return provider.Restart(readOnly)
142131
}
143132

144133
func doCheckpoint(sqlCtx *sql.Context) error {

pgserver/connection_handler.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ func (h *ConnectionHandler) chooseInitialDatabase(startupMessage *pgproto3.Start
313313
}
314314
if db == "postgres" || db == "mysql" {
315315
if provider := h.duckHandler.GetCatalogProvider(); provider != nil {
316-
db = provider.CatalogName()
316+
db = provider.DefaultCatalogName()
317317
}
318318
}
319319

0 commit comments

Comments
 (0)