Skip to content

Commit dabd2d1

Browse files
authored
Add support for Options.TLSConfig.ServerName (#310)
Additionally bump the shared test cases and adjust error handling around branch and database conflicts in connection configuration.
1 parent bdef9e4 commit dabd2d1

File tree

5 files changed

+91
-66
lines changed

5 files changed

+91
-66
lines changed

internal/client/connutils.go

Lines changed: 78 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ type connConfig struct {
6666
waitUntilAvailable time.Duration
6767
tlsCAData []byte
6868
tlsSecurity string
69+
tlsServerName string
6970
serverSettings *snc.ServerSettings
7071
secretKey string
7172
}
@@ -88,6 +89,7 @@ func (c *connConfig) tlsConfig() (*tls.Config, error) {
8889
tlsConfig := &tls.Config{
8990
RootCAs: roots,
9091
NextProtos: []string{"edgedb-binary"},
92+
ServerName: c.tlsServerName,
9193
}
9294

9395
switch c.tlsSecurity {
@@ -128,11 +130,11 @@ type configResolver struct {
128130
host cfgVal // string
129131
port cfgVal // int
130132
database cfgVal // string
131-
branch cfgVal // string
132133
user cfgVal // string
133134
password cfgVal // OptionalStr
134135
tlsCAData cfgVal // []byte
135136
tlsSecurity cfgVal // string
137+
tlsServerName cfgVal // string
136138
waitUntilAvailable cfgVal // time.Duration
137139
serverSettings *snc.ServerSettings
138140
secretKey cfgVal // string
@@ -217,17 +219,6 @@ func (r *configResolver) setDatabase(val, source string) error {
217219
return nil
218220
}
219221

220-
func (r *configResolver) setBranch(val, source string) error {
221-
if r.branch.val != nil {
222-
return nil
223-
}
224-
if val == "" {
225-
return errors.New(`invalid branch name: ""`)
226-
}
227-
r.branch = cfgVal{val: val, source: source}
228-
return nil
229-
}
230-
231222
func (r *configResolver) setUser(val, source string) error {
232223
if r.user.val != nil {
233224
return nil
@@ -279,6 +270,15 @@ func (r *configResolver) setTLSSecurity(val string, source string) error {
279270
return nil
280271
}
281272

273+
func (r *configResolver) setTLSServerName(val string, source string) error {
274+
if r.tlsServerName.val != nil {
275+
return nil
276+
}
277+
278+
r.tlsServerName = cfgVal{val: val, source: source}
279+
return nil
280+
}
281+
282282
func (r *configResolver) setWaitUntilAvailable(
283283
val time.Duration,
284284
source string,
@@ -354,7 +354,7 @@ func (r *configResolver) resolveOptions(
354354
}
355355

356356
if opts.Branch != "" {
357-
if e := r.setBranch(opts.Branch, "Branch options"); e != nil {
357+
if e := r.setDatabase(opts.Branch, "Branch options"); e != nil {
358358
return e
359359
}
360360
}
@@ -424,6 +424,14 @@ func (r *configResolver) resolveOptions(
424424
"TLSOptions.SecurityMode option")
425425
}
426426

427+
if opts.TLSOptions.ServerName != "" {
428+
secSources = append(secSources, "TLSOptions.ServerName")
429+
err = r.setTLSServerName(
430+
opts.TLSOptions.ServerName,
431+
"TLSOptions.ServerName options",
432+
)
433+
}
434+
427435
if len(secSources) > 1 {
428436
return fmt.Errorf(
429437
"mutually exclusive options set in Options: %v",
@@ -502,50 +510,24 @@ func (r *configResolver) resolveDSN(
502510
"cannot be present at the same time")
503511
}
504512

505-
if r.database.val != nil {
506-
return fmt.Errorf(
507-
"`branch` in DSN and %s are mutually exclusive options",
508-
r.database.source,
509-
)
510-
}
511-
512-
val, err = popDSNValue(query, db, "branch", r.branch.val == nil)
513+
val, err = popDSNValue(query, db, "branch", r.database.val == nil)
513514
if err != nil {
514515
return err
515516
} else if val.val != nil {
516517
br := strings.TrimPrefix(val.val.(string), "/")
517-
if e := r.setBranch(br, source+val.source); e != nil {
518+
if e := r.setDatabase(br, source+val.source); e != nil {
518519
return e
519520
}
520521
}
521522
} else {
522-
if r.branch.val != nil {
523-
if queryContains("database", query) {
524-
return fmt.Errorf(
525-
"`database` in DSN and %s are mutually exclusive options",
526-
r.branch.source,
527-
)
528-
}
529-
530-
val, err = popDSNValue(query, db, "branch", r.branch.val == nil)
531-
if err != nil {
532-
return err
533-
} else if val.val != nil {
534-
br := strings.TrimPrefix(val.val.(string), "/")
535-
if e := r.setBranch(br, source+val.source); e != nil {
536-
return e
537-
}
538-
}
539-
} else {
540-
val, err = popDSNValue(
541-
query, db, "database", r.database.val == nil)
542-
if err != nil {
543-
return err
544-
} else if val.val != nil {
545-
db := strings.TrimPrefix(val.val.(string), "/")
546-
if e := r.setDatabase(db, source+val.source); e != nil {
547-
return e
548-
}
523+
val, err = popDSNValue(
524+
query, db, "database", r.database.val == nil)
525+
if err != nil {
526+
return err
527+
} else if val.val != nil {
528+
db := strings.TrimPrefix(val.val.(string), "/")
529+
if e := r.setDatabase(db, source+val.source); e != nil {
530+
return e
549531
}
550532
}
551533
}
@@ -614,6 +596,22 @@ func (r *configResolver) resolveDSN(
614596
}
615597
}
616598

599+
val, err = popDSNValue(
600+
query,
601+
"",
602+
"tls_server_name",
603+
r.tlsServerName.val == nil,
604+
)
605+
if err != nil {
606+
return err
607+
}
608+
if val.val != nil {
609+
err = r.setTLSServerName(val.val.(string), source+val.source)
610+
if err != nil {
611+
return err
612+
}
613+
}
614+
617615
val, err = popDSNValue(
618616
query,
619617
"",
@@ -707,7 +705,7 @@ func (r *configResolver) applyCredentials(
707705
}
708706

709707
if br, ok := creds.branch.Get(); ok && br != "" {
710-
if e := r.setBranch(br, source); e != nil {
708+
if e := r.setDatabase(br, source); e != nil {
711709
return e
712710
}
713711
}
@@ -734,15 +732,21 @@ func (r *configResolver) applyCredentials(
734732
}
735733

736734
func (r *configResolver) resolveEnvVars(paths *cfgPaths) (bool, error) {
737-
if db, ok := os.LookupEnv("EDGEDB_DATABASE"); ok {
735+
db, dbOk := os.LookupEnv("EDGEDB_DATABASE")
736+
if dbOk {
738737
err := r.setDatabase(db, "EDGEDB_DATABASE environment variable")
739738
if err != nil {
740739
return false, err
741740
}
742741
}
743742

744-
if db, ok := os.LookupEnv("EDGEDB_BRANCH"); ok {
745-
err := r.setBranch(db, "EDGEDB_BRANCH environment variable")
743+
if branch, ok := os.LookupEnv("EDGEDB_BRANCH"); ok {
744+
if dbOk {
745+
return false, errors.New(
746+
"mutually exclusive options EDGEDB_DATABASE and " +
747+
"EDGEDB_BRANCH environment variables are set")
748+
}
749+
err := r.setDatabase(branch, "EDGEDB_BRANCH environment variable")
746750
if err != nil {
747751
return false, err
748752
}
@@ -784,6 +788,16 @@ func (r *configResolver) resolveEnvVars(paths *cfgPaths) (bool, error) {
784788
}
785789
}
786790

791+
if val, ok := os.LookupEnv("EDGEDB_TLS_SERVER_NAME"); ok {
792+
e := r.setTLSServerName(
793+
val,
794+
"EDGEDB_TLS_SERVER_NAME environment variable",
795+
)
796+
if e != nil {
797+
return false, e
798+
}
799+
}
800+
787801
if len(tlsCaSources) > 1 {
788802
return false, fmt.Errorf(
789803
"mutually exclusive environment variables set: %v",
@@ -946,18 +960,8 @@ func (r *configResolver) config(opts *Options) (*connConfig, error) {
946960
database := "edgedb"
947961
branch := "__default__"
948962
if r.database.val != nil {
949-
if r.branch.val != nil {
950-
return nil, fmt.Errorf(
951-
"%s and %s are mutually exclusive options",
952-
r.database.source,
953-
r.branch.source,
954-
)
955-
}
956963
database = r.database.val.(string)
957964
branch = database
958-
} else if r.branch.val != nil {
959-
branch = r.branch.val.(string)
960-
database = branch
961965
}
962966

963967
user := "edgedb"
@@ -980,6 +984,11 @@ func (r *configResolver) config(opts *Options) (*connConfig, error) {
980984
tlsSecurity = r.tlsSecurity.val.(string)
981985
}
982986

987+
tlsServerName := ""
988+
if r.tlsServerName.val != nil {
989+
tlsServerName = r.tlsServerName.val.(string)
990+
}
991+
983992
secretKey := ""
984993
if r.secretKey.val != nil {
985994
secretKey = r.secretKey.val.(string)
@@ -1033,6 +1042,7 @@ func (r *configResolver) config(opts *Options) (*connConfig, error) {
10331042
serverSettings: r.serverSettings,
10341043
tlsCAData: certData,
10351044
tlsSecurity: tlsSecurity,
1045+
tlsServerName: tlsServerName,
10361046
secretKey: secretKey,
10371047
}, nil
10381048
}
@@ -1268,6 +1278,11 @@ var dsnKeyLookup = map[string][]string{
12681278
"password": {"password", "password_env", "password_file"},
12691279
"tls_ca_file": {"tls_ca_file", "tls_ca_file_env"},
12701280
"tls_security": {"tls_security", "tls_security_env", "tls_security_file"},
1281+
"tls_server_name": {
1282+
"tls_server_name",
1283+
"tls_server_name_env",
1284+
"tls_server_name_file",
1285+
},
12711286
"tls_verify_hostname": {
12721287
"tls_verify_hostname",
12731288
"tls_verify_hostname_env",

internal/client/connutils_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,8 @@ func TestConnectionParameterResolution(t *testing.T) {
626626
options.TLSOptions.CA = getBytes(t, opts, "tlsCA")
627627
options.TLSOptions.SecurityMode = TLSSecurityMode(
628628
getStr(t, opts, "tlsSecurity"))
629+
options.TLSOptions.ServerName = getStr(
630+
t, opts, "tlsServerName")
629631
if opts["serverSettings"] != nil {
630632
ss := opts["serverSettings"].(map[string]interface{})
631633
options.ServerSettings = make(map[string][]byte, len(ss))
@@ -673,6 +675,10 @@ func TestConnectionParameterResolution(t *testing.T) {
673675
expectedResult.secretKey = key.(string)
674676
}
675677

678+
if key := res["tlsServerName"]; key != nil {
679+
expectedResult.tlsServerName = key.(string)
680+
}
681+
676682
ss := res["serverSettings"].(map[string]interface{})
677683
for k, v := range ss {
678684
expectedResult.serverSettings.Set(k, []byte(v.(string)))

internal/client/credentials.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,11 @@ func validateCredentials(data map[string]interface{}) (*credentials, error) {
106106
result.host.Set(h)
107107
}
108108

109-
if inMap("database", data) && inMap("branch", data) {
109+
if inMap("database", data) &&
110+
inMap("branch", data) &&
111+
data["database"] != data["branch"] {
110112
return nil, errors.New(
111-
"`database` and `branch` are mutually exclusive")
113+
"`database` and `branch` are both set but do not match")
112114
}
113115

114116
if database, ok := data["database"]; ok {

internal/client/options.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ type TLSOptions struct {
127127
CAFile string
128128
// Determines how strict we are with TLS checks
129129
SecurityMode TLSSecurityMode
130+
// Used to verify the hostname on the returned certificates
131+
ServerName string
130132
}
131133

132134
// TLSSecurityMode specifies how strict TLS validation is.

shared-client-testcases

0 commit comments

Comments
 (0)