Skip to content

Commit d998bcd

Browse files
Merge pull request #52 from abhijitWakchaure/master
add support for base64 encoded cert contents
2 parents 413d160 + 4ff5efc commit d998bcd

File tree

1 file changed

+59
-36
lines changed

1 file changed

+59
-36
lines changed

flow-state/store/postgres/connection.go

+59-36
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"database/sql"
55
"database/sql/driver"
66
"encoding/base64"
7+
b64 "encoding/base64"
78
"encoding/json"
89
"errors"
910
"fmt"
@@ -60,6 +61,7 @@ func decodeTLSParam(tlsparm string) string {
6061
}
6162
}
6263

64+
// NewDB opens a DB connection and pings it to verify connection
6365
func NewDB(settings map[string]interface{}) (*sql.DB, error) {
6466
var err error
6567

@@ -153,45 +155,28 @@ func NewDB(settings map[string]interface{}) (*sql.DB, error) {
153155
logCache.Errorf("could not get working dir due to %s", err.Error())
154156
return nil, fmt.Errorf("could not get working dir due to %s", err.Error())
155157
}
158+
156159
if s.CACert != "" {
157-
// check if input is already a filepath
158-
if strings.HasPrefix(s.CACert, "-----") {
159-
// input is not a file path
160-
pathCACert := filepath.Join(pwd, "caCert.pem")
161-
err = os.WriteFile(pathCACert, []byte(s.CACert), 0600)
162-
if err != nil {
163-
logCache.Errorf("could not create CA cert file due to %s", err.Error())
164-
return nil, fmt.Errorf("could not create CA cert file due to %s", err.Error())
165-
}
166-
s.CACert = pathCACert
160+
pathCACert := filepath.Join(pwd, "caCert.pem")
161+
s.CACert, err = getCertPath(pathCACert, s.CACert)
162+
if err != nil {
163+
return nil, err
167164
}
168165
conninfo = conninfo + fmt.Sprintf("sslrootcert=%s ", s.CACert)
169166
}
170167
if s.ClientCert != "" {
171-
// check if input is already a filepath
172-
if strings.HasPrefix(s.ClientCert, "-----") {
173-
// input is not a file path
174-
pathClientCert := filepath.Join(pwd, "clientCert.pem")
175-
err = os.WriteFile(pathClientCert, []byte(s.ClientCert), 0600)
176-
if err != nil {
177-
logCache.Errorf("could not create client cert file due to %s", err.Error())
178-
return nil, fmt.Errorf("could not create client cert file due to %s", err.Error())
179-
}
180-
s.ClientCert = pathClientCert
168+
pathClientCert := filepath.Join(pwd, "clientCert.pem")
169+
s.ClientCert, err = getCertPath(pathClientCert, s.ClientCert)
170+
if err != nil {
171+
return nil, err
181172
}
182173
conninfo = conninfo + fmt.Sprintf("sslcert=%s ", s.ClientCert)
183174
}
184175
if s.ClientKey != "" {
185-
// check if input is already a filepath
186-
if strings.HasPrefix(s.ClientKey, "-----") {
187-
// input is not a file path
188-
pathClientKey := filepath.Join(pwd, "cacert.pem")
189-
err = os.WriteFile(pathClientKey, []byte(s.ClientKey), 0600)
190-
if err != nil {
191-
logCache.Errorf("could not create client key file due to %s", err.Error())
192-
return nil, fmt.Errorf("could not create client key file due to %s", err.Error())
193-
}
194-
s.ClientKey = pathClientKey
176+
pathClientKey := filepath.Join(pwd, "cacert.pem")
177+
s.ClientKey, err = getCertPath(pathClientKey, s.ClientKey)
178+
if err != nil {
179+
return nil, err
195180
}
196181
conninfo = conninfo + fmt.Sprintf("sslkey=%s ", s.ClientKey)
197182
}
@@ -208,13 +193,12 @@ func NewDB(settings map[string]interface{}) (*sql.DB, error) {
208193
db, err = sql.Open("postgres", conninfo)
209194
if err != nil {
210195
return nil, fmt.Errorf("Could not open connection to database %s, %s", cDbName, err.Error())
211-
} else {
212-
err = db.Ping()
213-
if err != nil {
214-
return nil, fmt.Errorf("Could not open connection to database %s, %s", cDbName, err.Error())
215-
}
216-
dbConnected = 1
217196
}
197+
err = db.Ping()
198+
if err != nil {
199+
return nil, fmt.Errorf("Could not open connection to database %s, %s", cDbName, err.Error())
200+
}
201+
dbConnected = 1
218202
} else if cMaxConnRetryAttempts > 0 {
219203
logCache.Debugf("Maximum connection retry attempts allowed - %d", cMaxConnRetryAttempts)
220204
logCache.Debugf("Connection retry delay - %d", cConnRetryDelay)
@@ -288,6 +272,45 @@ func NewDB(settings map[string]interface{}) (*sql.DB, error) {
288272
return db, nil
289273
}
290274

275+
func getCertPath(certFilePath string, content string) (string, error) {
276+
certPrefix := "-----"
277+
certPrefixB64 := "LS0tLS"
278+
certFileBase := filepath.Base(certFilePath)
279+
// check if content is valid SSL cert
280+
if strings.HasPrefix(content, certPrefix) {
281+
// input is cert content
282+
logCache.Debugf("found actual cert contents for %s", certFileBase)
283+
err := os.WriteFile(certFilePath, []byte(content), 0600)
284+
if err != nil {
285+
logCache.Errorf("could not create %s due to %s", certFileBase, err.Error())
286+
return "", fmt.Errorf("could not create %s due to %s", certFileBase, err.Error())
287+
}
288+
return certFilePath, nil
289+
}
290+
if strings.HasPrefix(content, certPrefixB64) {
291+
// input is cert content but base64 encoded, decode first
292+
logCache.Debugf("found base64 encoded contents for %s", certFileBase)
293+
certBytes, err := b64.StdEncoding.DecodeString(content)
294+
if err != nil {
295+
logCache.Errorf("could not b64 decode %s due to %s", certFileBase, err.Error())
296+
return "", fmt.Errorf("could not b64 decode %s due to %s", certFileBase, err.Error())
297+
}
298+
err = os.WriteFile(certFilePath, certBytes, 0600)
299+
if err != nil {
300+
logCache.Errorf("could not create %s due to %s", certFileBase, err.Error())
301+
return "", fmt.Errorf("could not create %s due to %s", certFileBase, err.Error())
302+
}
303+
return certFilePath, nil
304+
}
305+
// assume input is path to cert file
306+
_, err := os.Stat(content)
307+
if err != nil {
308+
logCache.Errorf("could not read cert file at %s due to %s", content, err.Error())
309+
return "", fmt.Errorf("could not read cert file at %s due to %s", content, err.Error())
310+
}
311+
return content, nil
312+
}
313+
291314
// copyCertToTempFile creates temp mssql.pem file for running app in container
292315
// and sqlserver needs filepath for ssl cert so can not pass byte array which we get from connection tile
293316
func copyCertToTempFile(certdata []byte, name string) (string, error) {

0 commit comments

Comments
 (0)