Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 0 additions & 165 deletions flytestdlib/database/db.go

This file was deleted.

101 changes: 0 additions & 101 deletions flytestdlib/database/postgres.go
Original file line number Diff line number Diff line change
@@ -1,115 +1,14 @@
package database

import (
"context"
"errors"
"fmt"
"io/ioutil"
"os"
"strings"

"github.com/jackc/pgx/v5/pgconn"
"gorm.io/driver/postgres"
"gorm.io/gorm"

"github.com/flyteorg/flyte/flytestdlib/logger"
)

const PqInvalidDBCode = "3D000"
const PqDbAlreadyExistsCode = "42P04"
const PgDuplicatedForeignKey = "23503"
const PgDuplicatedKey = "23505"
const defaultDB = "postgres"

// Resolves a password value from either a user-provided inline value or a filepath whose contents contain a password.
func resolvePassword(ctx context.Context, passwordVal, passwordPath string) string {
password := passwordVal
if len(passwordPath) > 0 {
if _, err := os.Stat(passwordPath); os.IsNotExist(err) {
logger.Fatalf(ctx,
"missing database password at specified path [%s]", passwordPath)
}
passwordVal, err := ioutil.ReadFile(passwordPath)
if err != nil {
logger.Fatalf(ctx, "failed to read database password from path [%s] with err: %v",
passwordPath, err)
}
// Passwords can contain special characters as long as they are percent encoded
// https://www.postgresql.org/docs/current/libpq-connect.html
password = strings.TrimSpace(string(passwordVal))
}
return password
}

// Produces the DSN (data source name) for opening a postgres db connection.
func getPostgresDsn(ctx context.Context, pgConfig PostgresConfig) string {
password := resolvePassword(ctx, pgConfig.Password, pgConfig.PasswordPath)
if len(password) == 0 {
// The password-less case is included for development environments.
return fmt.Sprintf("host=%s port=%d dbname=%s user=%s sslmode=disable",
pgConfig.Host, pgConfig.Port, pgConfig.DbName, pgConfig.User)
}
return fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s %s",
pgConfig.Host, pgConfig.Port, pgConfig.DbName, pgConfig.User, password, pgConfig.ExtraOptions)
}

// Produces the DSN (data source name) for the read replica for opening a postgres db connection.
func getPostgresReadDsn(ctx context.Context, pgConfig PostgresConfig) string {
password := resolvePassword(ctx, pgConfig.Password, pgConfig.PasswordPath)
if len(password) == 0 {
// The password-less case is included for development environments.
return fmt.Sprintf("host=%s port=%d dbname=%s user=%s sslmode=disable",
pgConfig.ReadReplicaHost, pgConfig.Port, pgConfig.DbName, pgConfig.User)
}
return fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s %s",
pgConfig.ReadReplicaHost, pgConfig.Port, pgConfig.DbName, pgConfig.User, password, pgConfig.ExtraOptions)
}

// CreatePostgresDbIfNotExists creates DB if it doesn't exist for the passed in config
func CreatePostgresDbIfNotExists(ctx context.Context, gormConfig *gorm.Config, pgConfig PostgresConfig) (*gorm.DB, error) {
dialector := postgres.Open(getPostgresDsn(ctx, pgConfig))
gormDb, err := gorm.Open(dialector, gormConfig)
if err == nil {
return gormDb, nil
}
if !IsPgErrorWithCode(err, PqInvalidDBCode) {
logger.Errorf(ctx, "Unhandled error connecting to postgres, pg [%v], gorm [%v]: %v", pgConfig, gormConfig, err)
return nil, err
}
logger.Warningf(ctx, "Database [%v] does not exist", pgConfig.DbName)

// Every postgres installation includes a 'postgres' database by default. We connect to that now in order to
// initialize the user-specified database.
defaultDbPgConfig := pgConfig
defaultDbPgConfig.DbName = defaultDB
defaultDBDialector := postgres.Open(getPostgresDsn(ctx, defaultDbPgConfig))
gormDb, err = gorm.Open(defaultDBDialector, gormConfig)
if err != nil {
return nil, err
}

// Because we asserted earlier that the db does not exist, we create it now.
logger.Infof(ctx, "Creating database %v", pgConfig.DbName)

// NOTE: golang sql drivers do not support parameter injection for CREATE calls
createDBStatement := fmt.Sprintf("CREATE DATABASE %s", pgConfig.DbName)
result := gormDb.Exec(createDBStatement)

if result.Error != nil {
if !IsPgErrorWithCode(result.Error, PqDbAlreadyExistsCode) {
return nil, result.Error
}
logger.Warningf(ctx, "Got DB already exists error for [%s], skipping...", pgConfig.DbName)
}
// Now try connecting to the db again
return gorm.Open(dialector, gormConfig)
}

// CreatePostgresReadOnlyDbConnection creates readonly DB connection and returns the gorm.DB object and error
func CreatePostgresReadOnlyDbConnection(ctx context.Context, gormConfig *gorm.Config, pgConfig PostgresConfig) (*gorm.DB, error) {
dialector := postgres.Open(getPostgresReadDsn(ctx, pgConfig))
return gorm.Open(dialector, gormConfig)
}

func IsPgErrorWithCode(err error, code string) bool {
pgErr := &pgconn.PgError{}
Expand Down
98 changes: 0 additions & 98 deletions flytestdlib/database/postgres_test.go
Original file line number Diff line number Diff line change
@@ -1,112 +1,14 @@
package database

import (
"context"
"errors"
"io/ioutil"
"net"
"os"
"testing"

"github.com/jackc/pgx/v5/pgconn"
"github.com/stretchr/testify/assert"
)

func TestResolvePassword(t *testing.T) {
password := "123abc"
tmpFile, err := os.CreateTemp("", "prefix")
if err != nil {
t.Errorf("Couldn't open temp file: %v", err)
}
defer tmpFile.Close()
if _, err = tmpFile.WriteString(password); err != nil {
t.Errorf("Couldn't write to temp file: %v", err)
}
resolvedPassword := resolvePassword(context.TODO(), "", tmpFile.Name())
assert.Equal(t, resolvedPassword, password)
}

func TestGetPostgresDsn(t *testing.T) {
pgConfig := PostgresConfig{
Host: "localhost",
Port: 5432,
DbName: "postgres",
User: "postgres",
ExtraOptions: "sslmode=disable",
}
t.Run("no password", func(t *testing.T) {
dsn := getPostgresDsn(context.TODO(), pgConfig)
assert.Equal(t, "host=localhost port=5432 dbname=postgres user=postgres sslmode=disable", dsn)
})
t.Run("with password", func(t *testing.T) {
pgConfig.Password = "pass"
dsn := getPostgresDsn(context.TODO(), pgConfig)
assert.Equal(t, "host=localhost port=5432 dbname=postgres user=postgres password=pass sslmode=disable", dsn)

})
t.Run("with password, no extra", func(t *testing.T) {
pgConfig.Password = "pass"
pgConfig.ExtraOptions = ""
dsn := getPostgresDsn(context.TODO(), pgConfig)
assert.Equal(t, "host=localhost port=5432 dbname=postgres user=postgres password=pass ", dsn)
})
t.Run("with password path", func(t *testing.T) {
password := "123abc"
tmpFile, err := ioutil.TempFile("", "prefix")
if err != nil {
t.Errorf("Couldn't open temp file: %v", err)
}
defer tmpFile.Close()
if _, err = tmpFile.WriteString(password); err != nil {
t.Errorf("Couldn't write to temp file: %v", err)
}
pgConfig.PasswordPath = tmpFile.Name()
dsn := getPostgresDsn(context.TODO(), pgConfig)
assert.Equal(t, "host=localhost port=5432 dbname=postgres user=postgres password=123abc ", dsn)
})
}

func TestGetPostgresReadDsn(t *testing.T) {
pgConfig := PostgresConfig{
Host: "localhost",
ReadReplicaHost: "readReplicaHost",
Port: 5432,
DbName: "postgres",
User: "postgres",
ExtraOptions: "sslmode=disable",
}
t.Run("no password", func(t *testing.T) {
dsn := getPostgresReadDsn(context.TODO(), pgConfig)
assert.Equal(t, "host=readReplicaHost port=5432 dbname=postgres user=postgres sslmode=disable", dsn)
})
t.Run("with password", func(t *testing.T) {
pgConfig.Password = "passw"
dsn := getPostgresReadDsn(context.TODO(), pgConfig)
assert.Equal(t, "host=readReplicaHost port=5432 dbname=postgres user=postgres password=passw sslmode=disable", dsn)

})
t.Run("with password, no extra", func(t *testing.T) {
pgConfig.Password = "passwo"
pgConfig.ExtraOptions = ""
dsn := getPostgresReadDsn(context.TODO(), pgConfig)
assert.Equal(t, "host=readReplicaHost port=5432 dbname=postgres user=postgres password=passwo ", dsn)
})
t.Run("with password path", func(t *testing.T) {
password := "1234abc"
tmpFile, err := ioutil.TempFile("", "prefix")
if err != nil {
t.Errorf("Couldn't open temp file: %v", err)
}
defer tmpFile.Close()
if _, err = tmpFile.WriteString(password); err != nil {
t.Errorf("Couldn't write to temp file: %v", err)
}
pgConfig.PasswordPath = tmpFile.Name()
dsn := getPostgresReadDsn(context.TODO(), pgConfig)
assert.Equal(t, "host=readReplicaHost port=5432 dbname=postgres user=postgres password=1234abc ", dsn)
})
}

type wrappedError struct {
err error
}
Expand Down
Loading