Skip to content

Commit c4a7948

Browse files
authored
Fix Bug: Prevent Empty Region in WithAzureRegion from Overriding MSAL_FORCE_REGION (#545)
* Fixed a bug where if empty region is passed in WithAzureRegion it would override the MSAL_FORCE_REGION * Updated the first tests * Removed dead code. * Update confidential_test.go * Cleaned up test
1 parent 06ce6ba commit c4a7948

File tree

2 files changed

+58
-41
lines changed

2 files changed

+58
-41
lines changed

apps/confidential/confidential.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,9 @@ func WithInstanceDiscovery(enabled bool) Option {
305305
// If an invalid region name is provided, the non-regional endpoint MIGHT be used or the token request MIGHT fail.
306306
func WithAzureRegion(val string) Option {
307307
return func(o *clientOptions) {
308-
o.azureRegion = val
308+
if val != "" {
309+
o.azureRegion = val
310+
}
309311
}
310312
}
311313

apps/confidential/confidential_test.go

+55-40
Original file line numberDiff line numberDiff line change
@@ -195,50 +195,65 @@ func TestRegionAutoEnable_EmptyRegion_EnvRegion(t *testing.T) {
195195
}
196196
}
197197

198-
func TestRegionAutoEnable_SpecifiedRegion_EnvRegion(t *testing.T) {
199-
cred, err := NewCredFromSecret(fakeSecret)
200-
if err != nil {
201-
t.Fatal(err)
202-
}
203-
204-
envRegion := "envRegion"
205-
err = os.Setenv("MSAL_FORCE_REGION", envRegion)
206-
if err != nil {
207-
t.Fatal(err)
208-
}
209-
defer os.Unsetenv("MSAL_FORCE_REGION")
210-
211-
lmo := "login.microsoftonline.com"
212-
tenant := "tenant"
213-
mockClient := mock.Client{}
214-
testRegion := "region"
215-
client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient), WithAzureRegion(testRegion))
216-
if err != nil {
217-
t.Fatal(err)
218-
}
219-
220-
if client.base.AuthParams.AuthorityInfo.Region != testRegion {
221-
t.Fatalf("wanted %q, got %q", testRegion, client.base.AuthParams.AuthorityInfo.Region)
198+
func TestRegionAutoEnable_SpecifiedEmptyRegion_EnvRegion(t *testing.T) {
199+
tests := []struct {
200+
name string
201+
envRegion string
202+
region string
203+
resultRegion string
204+
}{
205+
{
206+
name: "Region is empty, envRegion is set",
207+
envRegion: "region",
208+
region: "",
209+
resultRegion: "region",
210+
},
211+
{
212+
name: "Region is set, envRegion is set",
213+
envRegion: "region",
214+
region: "setRegion",
215+
resultRegion: "setRegion",
216+
},
217+
{
218+
name: "Region is set, envRegion is empty",
219+
envRegion: "",
220+
region: "setRegion",
221+
resultRegion: "setRegion",
222+
},
223+
{
224+
name: "Disable region is set, envRegion is set",
225+
envRegion: "region",
226+
region: "DisableMsalForceRegion",
227+
resultRegion: "",
228+
},
222229
}
223-
}
224230

225-
func TestRegionAutoEnable_DisableMsalForceRegion(t *testing.T) {
226-
cred, err := NewCredFromSecret(fakeSecret)
227-
if err != nil {
228-
t.Fatal(err)
229-
}
231+
for _, test := range tests {
232+
t.Run(test.name, func(t *testing.T) {
233+
cred, err := NewCredFromSecret(fakeSecret)
234+
if err != nil {
235+
t.Fatal(err)
236+
}
237+
if test.envRegion != "" {
238+
t.Setenv("MSAL_FORCE_REGION", test.envRegion)
239+
}
240+
lmo := "login.microsoftonline.com"
241+
tenant := "tenant"
242+
mockClient := mock.Client{}
230243

231-
lmo := "login.microsoftonline.com"
232-
tenant := "tenant"
233-
mockClient := mock.Client{}
234-
testRegion := "DisableMsalForceRegion"
235-
client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient), WithAzureRegion(testRegion))
236-
if err != nil {
237-
t.Fatal(err)
238-
}
244+
client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient), WithAzureRegion(test.region))
245+
if err != nil {
246+
t.Fatal(err)
247+
}
239248

240-
if client.base.AuthParams.AuthorityInfo.Region != "" {
241-
t.Fatalf("wanted empty, got %q", client.base.AuthParams.AuthorityInfo.Region)
249+
if test.resultRegion == "DisableMsalForceRegion" {
250+
if client.base.AuthParams.AuthorityInfo.Region != "" {
251+
t.Fatalf("wanted %q, got %q", test.resultRegion, client.base.AuthParams.AuthorityInfo.Region)
252+
}
253+
} else if client.base.AuthParams.AuthorityInfo.Region != test.resultRegion {
254+
t.Fatalf("wanted %q, got %q", test.resultRegion, client.base.AuthParams.AuthorityInfo.Region)
255+
}
256+
})
242257
}
243258
}
244259

0 commit comments

Comments
 (0)