Skip to content

Commit fdbe2ec

Browse files
authored
Accept null branch & database in credentials files (#385)
The gel CLI recently started writing credentials files that use `null` to indicate that the default branch name should be used instead of specifying a value explicitly. This caused the following error in gel-go. ``` `database` must be a string ``` Closes #382
1 parent 6206f47 commit fdbe2ec

File tree

5 files changed

+102
-36
lines changed

5 files changed

+102
-36
lines changed

internal/client/connutils.go

Lines changed: 78 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ type configResolver struct {
133133
host cfgVal // string
134134
port cfgVal // int
135135
database cfgVal // string
136+
branch cfgVal // string
136137
user cfgVal // string
137138
password cfgVal // OptionalStr
138139
tlsCAData cfgVal // []byte
@@ -222,6 +223,19 @@ func (r *configResolver) setDatabase(val, source string) error {
222223
return nil
223224
}
224225

226+
func (r *configResolver) setBranch(val, source string) error {
227+
if r.branch.val != nil {
228+
return nil
229+
}
230+
231+
if val == "" {
232+
return errors.New(`invalid branch name: ""`)
233+
}
234+
235+
r.branch = cfgVal{val: val, source: source}
236+
return nil
237+
}
238+
225239
func (r *configResolver) setUser(val, source string) error {
226240
if r.user.val != nil {
227241
return nil
@@ -351,17 +365,23 @@ func (r *configResolver) resolveOptions(
351365
}
352366

353367
if opts.Database != "" {
354-
e := r.setDatabase(
355-
opts.Database,
356-
"Database options",
357-
)
358-
if e != nil {
368+
source := "Database options"
369+
if e := r.setDatabase(opts.Database, source); e != nil {
359370
return e
360371
}
372+
if opts.Branch == "" {
373+
if e := r.setBranch(opts.Database, source); e != nil {
374+
return e
375+
}
376+
}
361377
}
362378

363379
if opts.Branch != "" {
364-
if e := r.setDatabase(opts.Branch, "Branch options"); e != nil {
380+
source := "Branch options"
381+
if e := r.setBranch(opts.Branch, source); e != nil {
382+
return e
383+
}
384+
if e := r.setDatabase(opts.Branch, source); e != nil {
365385
return e
366386
}
367387
}
@@ -534,7 +554,7 @@ func (r *configResolver) resolveDSN(
534554
return err
535555
} else if val.val != nil {
536556
br := strings.TrimPrefix(val.val.(string), "/")
537-
if e := r.setDatabase(br, source+val.source); e != nil {
557+
if e := r.setBranch(br, source+val.source); e != nil {
538558
return e
539559
}
540560
}
@@ -545,7 +565,7 @@ func (r *configResolver) resolveDSN(
545565
return err
546566
} else if val.val != nil {
547567
db := strings.TrimPrefix(val.val.(string), "/")
548-
if e := r.setDatabase(db, source+val.source); e != nil {
568+
if e := r.setBranch(db, source+val.source); e != nil {
549569
return e
550570
}
551571
}
@@ -744,13 +764,15 @@ func (r *configResolver) applyCredentials(
744764
}
745765

746766
if br, ok := creds.branch.Get(); ok && br != "" {
747-
if e := r.setDatabase(br, source); e != nil {
767+
if e := r.setBranch(br, source); e != nil {
748768
return e
749769
}
750770
}
751771

752-
if e := r.setUser(creds.user, source); e != nil {
753-
return e
772+
if user, ok := creds.user.Get(); ok && user != "" {
773+
if e := r.setUser(user, source); e != nil {
774+
return e
775+
}
754776
}
755777

756778
if pwd, ok := creds.password.Get(); ok {
@@ -767,13 +789,25 @@ func (r *configResolver) applyCredentials(
767789
}
768790
}
769791

792+
if key, ok := creds.secretKey.Get(); ok {
793+
if e := r.setSecretKey(key, source); e != nil {
794+
return e
795+
}
796+
}
797+
770798
return nil
771799
}
772800

773801
func (r *configResolver) resolveEnvVars(paths *cfgPaths) (bool, error) {
774802
db, dbOk := os.LookupEnv("EDGEDB_DATABASE")
775803
if dbOk {
776-
err := r.setDatabase(db, "EDGEDB_DATABASE environment variable")
804+
source := "EDGEDB_DATABASE environment variable"
805+
err := r.setDatabase(db, source)
806+
if err != nil {
807+
return false, err
808+
}
809+
810+
err = r.setBranch(db, source)
777811
if err != nil {
778812
return false, err
779813
}
@@ -788,10 +822,13 @@ func (r *configResolver) resolveEnvVars(paths *cfgPaths) (bool, error) {
788822
branchEnvVarName,
789823
)
790824
}
791-
err := r.setDatabase(
792-
branch,
793-
fmt.Sprintf("%s environment variable", branchEnvVarName),
794-
)
825+
source := fmt.Sprintf("%s environment variable", branchEnvVarName)
826+
err := r.setBranch(branch, source)
827+
if err != nil {
828+
return false, err
829+
}
830+
831+
err = r.setDatabase(branch, source)
795832
if err != nil {
796833
return false, err
797834
}
@@ -1005,6 +1042,7 @@ func (r *configResolver) resolveTOML(paths *cfgPaths) error {
10051042
)
10061043
}
10071044

1045+
// Returns envvar name and value if found.
10081046
func lookupGelOrEdgedbEnv(name string) (string, string, bool) {
10091047
gelName := fmt.Sprintf("GEL%s", name)
10101048
edbName := fmt.Sprintf("EDGEDB%s", name)
@@ -1038,11 +1076,20 @@ func (r *configResolver) config(opts *gelcfg.Options) (*connConfig, error) {
10381076
port = r.port.val.(int)
10391077
}
10401078

1041-
database := "edgedb"
10421079
branch := "__default__"
1080+
database := "edgedb"
1081+
if r.branch.val != nil {
1082+
branch = r.branch.val.(string)
1083+
if r.database.val == nil {
1084+
database = branch
1085+
}
1086+
}
1087+
10431088
if r.database.val != nil {
10441089
database = r.database.val.(string)
1045-
branch = database
1090+
if r.branch.val == nil && database != "edgedb" {
1091+
branch = database
1092+
}
10461093
}
10471094

10481095
user := "edgedb"
@@ -1332,7 +1379,7 @@ func parseDSN(dsn string) (*url.URL, map[string]string, error) {
13321379
}
13331380

13341381
db := strings.TrimPrefix(uri.Path, "/")
1335-
if e := validateQueryArg(vals, "database", db); e != nil {
1382+
if e := validateQueryArg(vals, "branch", db); e != nil {
13361383
return nil, nil, e
13371384
}
13381385

@@ -1646,6 +1693,18 @@ func (r *configResolver) parseCloudInstanceNameIntoConfig(
16461693
)
16471694
}
16481695

1696+
if r.secretKey.val == nil {
1697+
if name, key, ok := lookupGelOrEdgedbEnv("_SECRET_KEY"); ok {
1698+
err := r.setSecretKey(
1699+
key,
1700+
fmt.Sprintf("%s environment variable", name),
1701+
)
1702+
if err != nil {
1703+
return err
1704+
}
1705+
}
1706+
}
1707+
16491708
var secretKey string
16501709
if r.secretKey.val != nil {
16511710
secretKey = r.secretKey.val.(string)

internal/client/connutils_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ var testcaseErrorMapping = map[string]string{
431431
"invalid_port": "invalid port",
432432
"invalid_host": "invalid host",
433433
"invalid_user": "invalid user",
434-
"invalid_database": "invalid database",
434+
"invalid_database": "invalid database|invalid branch",
435435
"exclusive_options": "mutually exclusive options",
436436
"multiple_compound_opts": "mutually exclusive connection options",
437437
"multiple_compound_env": "mutually exclusive environment variables",

internal/client/credentials.go

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,13 @@ import (
3131
type credentials struct {
3232
host types.OptionalStr
3333
port types.OptionalInt32
34-
user string
34+
user types.OptionalStr
3535
database types.OptionalStr
3636
branch types.OptionalStr
3737
password types.OptionalStr
3838
ca types.OptionalBytes
3939
tlsSecurity types.OptionalStr
40+
secretKey types.OptionalStr
4041
}
4142

4243
func readCredentials(path string) (*credentials, error) {
@@ -92,11 +93,11 @@ func validateCredentials(data map[string]interface{}) (*credentials, error) {
9293
}
9394

9495
if user, ok := data["user"]; ok {
95-
if result.user, ok = user.(string); !ok {
96+
str, ok := user.(string)
97+
if !ok {
9698
return nil, errors.New("`user` must be a string")
9799
}
98-
} else {
99-
return nil, errors.New("`user` key is required")
100+
result.user.Set(str)
100101
}
101102

102103
if host, ok := data["host"]; ok && host != "" {
@@ -108,10 +109,13 @@ func validateCredentials(data map[string]interface{}) (*credentials, error) {
108109
}
109110

110111
if inMap("database", data) &&
111-
inMap("branch", data) &&
112-
data["database"] != data["branch"] {
113-
return nil, errors.New(
114-
"`database` and `branch` are both set but do not match")
112+
inMap("branch", data) {
113+
if data["database"] != data["branch"] &&
114+
data["database"] != "edgedb" &&
115+
data["branch"] != "__default__" {
116+
return nil, errors.New(
117+
"`database` and `branch` are both set but do not match")
118+
}
115119
}
116120

117121
if database, ok := data["database"]; ok {
@@ -138,6 +142,15 @@ func validateCredentials(data map[string]interface{}) (*credentials, error) {
138142
result.password.Set(pwd)
139143
}
140144

145+
if key, ok := data["secret_key"]; ok && key != nil {
146+
str, ok := key.(string)
147+
if !ok {
148+
return nil, errors.New("`secret_key` must be a string")
149+
}
150+
151+
result.secretKey.Set(str)
152+
}
153+
141154
if ca, ok := data["tls_ca"]; ok {
142155
str, ok := ca.(string)
143156
if !ok {

internal/client/credentials_test.go

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,12 @@ func TestCredentialsRead(t *testing.T) {
3232
database: types.NewOptionalStr("test3n"),
3333
password: types.NewOptionalStr("lZTBy1RVCfOpBAOwSCwIyBIR"),
3434
port: types.NewOptionalInt32(10702),
35-
user: "test3n",
35+
user: types.NewOptionalStr("test3n"),
3636
}
3737

3838
assert.Equal(t, expected, creds)
3939
}
4040

41-
func TestCredentialsEmpty(t *testing.T) {
42-
creds, err := validateCredentials(map[string]interface{}{})
43-
assert.EqualError(t, err, "`user` key is required")
44-
assert.Nil(t, creds)
45-
}
46-
4741
func TestCredentialsPort(t *testing.T) {
4842
creds, err := validateCredentials(map[string]interface{}{
4943
"user": "u1",

shared-client-testcases

0 commit comments

Comments
 (0)