diff --git a/database/dialect.go b/database/dialect.go index 9f138f560..602dd7aba 100644 --- a/database/dialect.go +++ b/database/dialect.go @@ -5,25 +5,28 @@ import ( "database/sql" "errors" "fmt" + "github.com/pressly/goose/v3/internal/dialect" "github.com/pressly/goose/v3/internal/dialect/dialectquery" ) // Dialect is the type of database dialect. -type Dialect string +type Dialect = dialect.Dialect + +var ErrUnknownDialect = dialect.ErrUnknownDialect const ( - DialectClickHouse Dialect = "clickhouse" - DialectMSSQL Dialect = "mssql" - DialectMySQL Dialect = "mysql" - DialectPostgres Dialect = "postgres" - DialectRedshift Dialect = "redshift" - DialectSQLite3 Dialect = "sqlite3" - DialectTiDB Dialect = "tidb" - DialectTurso Dialect = "turso" - DialectVertica Dialect = "vertica" - DialectYdB Dialect = "ydb" - DialectStarrocks Dialect = "starrocks" + DialectClickHouse Dialect = dialect.Clickhouse + DialectMSSQL Dialect = dialect.Mssql + DialectMySQL Dialect = dialect.Mysql + DialectPostgres Dialect = dialect.Postgres + DialectRedshift Dialect = dialect.Redshift + DialectSQLite3 Dialect = dialect.Sqlite3 + DialectTiDB Dialect = dialect.Tidb + DialectTurso Dialect = dialect.Turso + DialectVertica Dialect = dialect.Vertica + DialectYdB Dialect = dialect.Ydb + DialectStarrocks Dialect = dialect.Starrocks ) // NewStore returns a new [Store] implementation for the given dialect. @@ -49,7 +52,7 @@ func NewStore(dialect Dialect, tablename string) (Store, error) { } querier, ok := lookup[dialect] if !ok { - return nil, fmt.Errorf("unknown dialect: %q", dialect) + return nil, fmt.Errorf("%s: %w", dialect, ErrUnknownDialect) } return &store{ tablename: tablename, diff --git a/dialect.go b/dialect.go index ecebd144f..3844ec92c 100644 --- a/dialect.go +++ b/dialect.go @@ -1,28 +1,31 @@ package goose import ( - "fmt" - - "github.com/pressly/goose/v3/database" "github.com/pressly/goose/v3/internal/dialect" ) -// Dialect is the type of database dialect. It is an alias for [database.Dialect]. -type Dialect = database.Dialect +// Dialect is the type of database dialect. It is an alias for [dialect.Dialect]. +type Dialect = dialect.Dialect const ( - DialectClickHouse Dialect = database.DialectClickHouse - DialectMSSQL Dialect = database.DialectMSSQL - DialectMySQL Dialect = database.DialectMySQL - DialectPostgres Dialect = database.DialectPostgres - DialectRedshift Dialect = database.DialectRedshift - DialectSQLite3 Dialect = database.DialectSQLite3 - DialectTiDB Dialect = database.DialectTiDB - DialectVertica Dialect = database.DialectVertica - DialectYdB Dialect = database.DialectYdB - DialectStarrocks Dialect = database.DialectStarrocks + DialectClickHouse Dialect = dialect.Clickhouse + DialectMSSQL Dialect = dialect.Mssql + DialectMySQL Dialect = dialect.Mysql + DialectPostgres Dialect = dialect.Postgres + DialectRedshift Dialect = dialect.Redshift + DialectSQLite3 Dialect = dialect.Sqlite3 + DialectTiDB Dialect = dialect.Tidb + DialectVertica Dialect = dialect.Vertica + DialectYdB Dialect = dialect.Ydb + DialectTurso Dialect = dialect.Turso + DialectStarrocks Dialect = dialect.Starrocks ) +var ErrUnknownDialect = dialect.ErrUnknownDialect + +// GetDialect gets the dialect used in the goose package. +var GetDialect = dialect.GetDialect + func init() { store, _ = dialect.NewStore(dialect.Postgres) } @@ -30,35 +33,22 @@ func init() { var store dialect.Store // SetDialect sets the dialect to use for the goose package. -func SetDialect(s string) error { - var d dialect.Dialect - switch s { - case "postgres", "pgx": - d = dialect.Postgres - case "mysql": - d = dialect.Mysql - case "sqlite3", "sqlite": - d = dialect.Sqlite3 - case "mssql", "azuresql", "sqlserver": - d = dialect.Sqlserver - case "redshift": - d = dialect.Redshift - case "tidb": - d = dialect.Tidb - case "clickhouse": - d = dialect.Clickhouse - case "vertica": - d = dialect.Vertica - case "ydb": - d = dialect.Ydb - case "turso": - d = dialect.Turso - case "starrocks": - d = dialect.Starrocks - default: - return fmt.Errorf("%q: unknown dialect", s) +func SetDialect[D string | Dialect](d D) error { + var ( + v Dialect + err error + ) + + switch t := any(d).(type) { + case string: + v, err = GetDialect(t) + if err != nil { + return err + } + case Dialect: + v = t } - var err error - store, err = dialect.NewStore(d) + + store, err = dialect.NewStore(v) return err } diff --git a/dialect_test.go b/dialect_test.go new file mode 100644 index 000000000..bb521c51b --- /dev/null +++ b/dialect_test.go @@ -0,0 +1,41 @@ +package goose_test + +import ( + "github.com/pressly/goose/v3" + "github.com/stretchr/testify/require" + "testing" +) + +func TestGetDialect(t *testing.T) { + tests := []struct { + name string + want goose.Dialect + }{ + { + name: "postgres", + want: goose.DialectPostgres, + }, + { + name: "mysql", + want: goose.DialectMySQL, + }, + { + name: "MySQL", + want: goose.DialectMySQL, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dialect, err := goose.GetDialect(test.name) + require.NoError(t, err) + require.Equal(t, test.want, dialect) + }) + } +} + +func TestGetDialectFail(t *testing.T) { + dialect, err := goose.GetDialect("fail") + require.Empty(t, dialect) + require.ErrorIs(t, err, goose.ErrUnknownDialect) + require.EqualError(t, err, "fail: unknown dialect") +} diff --git a/internal/dialect/dialects.go b/internal/dialect/dialects.go index acf064258..e938a782c 100644 --- a/internal/dialect/dialects.go +++ b/internal/dialect/dialects.go @@ -1,12 +1,22 @@ package dialect +import ( + "errors" + "fmt" + "strings" +) + // Dialect is the type of database dialect. type Dialect string +var ErrUnknownDialect = errors.New("unknown dialect") + const ( - Postgres Dialect = "postgres" - Mysql Dialect = "mysql" - Sqlite3 Dialect = "sqlite3" + Postgres Dialect = "postgres" + Mysql Dialect = "mysql" + Sqlite3 Dialect = "sqlite3" + Mssql Dialect = "mssql" + // Deprecated: use [Mssql] Sqlserver Dialect = "sqlserver" Redshift Dialect = "redshift" Tidb Dialect = "tidb" @@ -16,3 +26,44 @@ const ( Turso Dialect = "turso" Starrocks Dialect = "starrocks" ) + +// GetDialect gets the dialect used in the goose package. +func GetDialect(s string) (Dialect, error) { + switch strings.ToLower(s) { + case "postgres", "pgx": + return Postgres, nil + case "mysql": + return Mysql, nil + case "sqlite3", "sqlite": + return Sqlite3, nil + case "mssql", "azuresql", "sqlserver": + return Mssql, nil + case "redshift": + return Redshift, nil + case "tidb": + return Tidb, nil + case "clickhouse": + return Clickhouse, nil + case "vertica": + return Vertica, nil + case "ydb": + return Ydb, nil + case "turso": + return Turso, nil + case "starrocks": + return Starrocks, nil + default: + return "", fmt.Errorf("%s: %w", s, ErrUnknownDialect) + } +} + +func (d *Dialect) UnmarshalText(text []byte) error { + dialect, err := GetDialect(string(text)) + if err != nil { + return err + } + + *d = dialect + + return nil +} diff --git a/internal/dialect/dialects_test.go b/internal/dialect/dialects_test.go new file mode 100644 index 000000000..a98c7e5cf --- /dev/null +++ b/internal/dialect/dialects_test.go @@ -0,0 +1,58 @@ +package dialect_test + +import ( + "github.com/pressly/goose/v3/internal/dialect" + "github.com/stretchr/testify/require" + "testing" +) + +var _testUnmarshalData = []struct { + name string + want dialect.Dialect +}{ + { + name: "postgres", + want: dialect.Postgres, + }, + { + name: "mysql", + want: dialect.Mysql, + }, + { + name: "MySQL", + want: dialect.Mysql, + }, +} + +func TestDialect_GetDialect(t *testing.T) { + for _, test := range _testUnmarshalData { + t.Run(test.name, func(t *testing.T) { + d, err := dialect.GetDialect(test.name) + require.NoError(t, err) + require.Equal(t, test.want, d) + }) + } +} + +func TestDialect_GetDialectFail(t *testing.T) { + d, err := dialect.GetDialect("fail") + require.Empty(t, d) + require.ErrorIs(t, err, dialect.ErrUnknownDialect) + require.EqualError(t, err, "fail: unknown dialect") +} + +func TestDialect_UnmarshalText(t *testing.T) { + for _, test := range _testUnmarshalData { + t.Run(test.name, func(t *testing.T) { + var d dialect.Dialect + require.NoError(t, d.UnmarshalText([]byte(test.name))) + }) + } +} + +func TestDialect_UnmarshalTextFail(t *testing.T) { + var d dialect.Dialect + var err = d.UnmarshalText([]byte("fail")) + require.ErrorIs(t, err, dialect.ErrUnknownDialect) + require.EqualError(t, err, "fail: unknown dialect") +} diff --git a/internal/dialect/store.go b/internal/dialect/store.go index e9b768f91..271ede1b0 100644 --- a/internal/dialect/store.go +++ b/internal/dialect/store.go @@ -55,7 +55,7 @@ func NewStore(d Dialect) (Store, error) { querier = &dialectquery.Mysql{} case Sqlite3: querier = &dialectquery.Sqlite3{} - case Sqlserver: + case Mssql, Sqlserver: querier = &dialectquery.Sqlserver{} case Redshift: querier = &dialectquery.Redshift{}