diff --git a/README.md b/README.md index d2ff29f..641cb0a 100644 --- a/README.md +++ b/README.md @@ -100,6 +100,8 @@ port="5432" user="postgres" dbname="postgres" password="xxxxx" +sslmode="disable" +pingCheck=true maxIdleConn = 2 maxOpenConn = 2 diff --git a/kshieldconfig_example.toml b/kshieldconfig_example.toml index beab5e9..9a60492 100644 --- a/kshieldconfig_example.toml +++ b/kshieldconfig_example.toml @@ -5,6 +5,8 @@ # port = "5432" # user = "postgres" # password = "password123" +# sslmode = "disable" +# pingCheck = true # dbname = "mydb" # maxIdleConn = 10 # maxOpenConn = 100 diff --git a/pkg/config/config.go b/pkg/config/config.go index 9cf3182..74a9543 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -215,6 +215,7 @@ type MySQL struct { Port string `toml:"port"` User string `toml:"user"` Password string `toml:"password"` + PingCheck bool `toml:"pingCheck"` // DBName string `toml:"dbname"` // SSLmode string `toml:"sslmode"` MaxIdleConn int `toml:"maxIdleConn"` diff --git a/pkg/mysqldb/connect.go b/pkg/mysqldb/connect.go index 5aea3e2..1375007 100644 --- a/pkg/mysqldb/connect.go +++ b/pkg/mysqldb/connect.go @@ -28,10 +28,13 @@ func Open(conf config.MySQL) (*sql.DB, string, error) { Msg("Failed to connect to database") return nil, "", err } - err = db.Ping() - if err != nil { - // fmt.Printf("Failed to connect to database. Error: %s", err.Error()) - return nil, "", err + + if conf.PingCheck { + err = db.Ping() + if err != nil { + // fmt.Printf("Failed to connect to database. Error: %s", err.Error()) + return nil, "", err + } } if conf.MaxIdleConn > 0 { db.SetMaxIdleConns(conf.MaxIdleConn) diff --git a/pkg/postgresdb/connect.go b/pkg/postgresdb/connect.go index 71f0a34..e402bec 100644 --- a/pkg/postgresdb/connect.go +++ b/pkg/postgresdb/connect.go @@ -4,6 +4,7 @@ import ( "database/sql" "fmt" "regexp" + "strings" _ "github.com/lib/pq" "github.com/rs/zerolog/log" @@ -15,7 +16,8 @@ type Postgres struct { User string `toml:"user"` Password string `toml:"password"` DBName string `toml:"dbname"` - // SSLmode string `toml:"sslmode"` + SSLmode string `toml:"sslmode"` + PingCheck bool `toml:"pingCheck"` MaxIdleConn int `toml:"maxIdleConn"` MaxOpenConn int `toml:"maxOpenConn"` } @@ -35,7 +37,7 @@ var re = regexp.MustCompile(`(?m)(?:host=)([^\s]+)`) func Open(conf Postgres) (*sql.DB, string, error) { // "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable" - url := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable", conf.Host, conf.Port, conf.User, conf.Password, conf.DBName) + url := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s pingcheck=%t", conf.Host, conf.Port, conf.User, conf.Password, conf.DBName, conf.SSLmode, conf.PingCheck) db, err := ConnectDatabaseUsingConnectionString(url) if err != nil { @@ -68,7 +70,11 @@ func Open(conf Postgres) (*sql.DB, string, error) { // ConnectDatabaseUsingConnectionString connects to a PostgreSQL database using the provided connection string. // It returns a database connection, the connection string, and an error if any. + func ConnectDatabaseUsingConnectionString(url string) (*sql.DB, error) { + parts := strings.Split(url, "pingcheck=") + pingcheck := parts[1] + url = parts[0] db, err := sql.Open("postgres", url) if err != nil { log.Error(). @@ -78,15 +84,18 @@ func ConnectDatabaseUsingConnectionString(url string) (*sql.DB, error) { return nil, err } - err = db.Ping() - if err != nil { - log.Error(). - Err(err). - Str("conn", url). - Msg("Failed to ping database") - db.Close() - return nil, err + if len(pingcheck) > 0 && pingcheck == "true" { + err = db.Ping() + if err != nil { + log.Error(). + Err(err). + Str("conn", url). + Msg("Failed to ping database") + db.Close() + return nil, err + } } + return db, nil }