Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor device login tests #2959

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions cmd/credentialUtil.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,13 @@ func GetUserOAuthTokenManagerInstance() *common.UserOAuthTokenManager {
if common.AzcopyJobPlanFolder == "" {
panic("invalid state, AzcopyJobPlanFolder should not be an empty string")
}
cacheName := common.GetEnvironmentVariable(common.EEnvironmentVariable.LoginCacheName())

currentUserOAuthTokenManager = common.NewUserOAuthTokenManagerInstance(common.CredCacheOptions{
DPAPIFilePath: common.AzcopyJobPlanFolder,
KeyName: oauthLoginSessionCacheKeyName,
KeyName: common.Iff(cacheName != "", cacheName, oauthLoginSessionCacheKeyName),
ServiceName: oauthLoginSessionCacheServiceName,
AccountName: oauthLoginSessionCacheAccountName,
AccountName: common.Iff(cacheName != "", cacheName, oauthLoginSessionCacheAccountName),
})
})

Expand Down
4 changes: 2 additions & 2 deletions cmd/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,12 @@ func (lca loginCmdArgs) process() error {
// For MSI login, info success message to user.
glcm.Info("Login with identity succeeded.")
case common.EAutoLoginType.AzCLI().String():
if err := uotm.AzCliLogin(lca.tenantID); err != nil {
if err := uotm.AzCliLogin(lca.tenantID, lca.persistToken); err != nil {
return err
}
glcm.Info("Login with AzCliCreds succeeded")
case common.EAutoLoginType.PsCred().String():
if err := uotm.PSContextToken(lca.tenantID); err != nil {
if err := uotm.PSContextToken(lca.tenantID, lca.persistToken); err != nil {
return err
}
glcm.Info("Login with Powershell context succeeded")
Expand Down
53 changes: 46 additions & 7 deletions cmd/loginStatus.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,25 @@ package cmd

import (
"context"
"encoding/json"
"fmt"
"github.com/Azure/azure-storage-azcopy/v10/common"
"github.com/Azure/azure-storage-azcopy/v10/ste"
"github.com/spf13/cobra"
)

type LoginStatusOutput struct {
Valid bool `json:"valid"`
TenantID *string `json:"tenantID,omitempty"`
AADEndpoint *string `json:"AADEndpoint,omitempty"`
AuthMethod *string `json:"authMethod,omitempty"`
}

func init() {
type loginStatus struct {
tenantID bool
endpoint bool
method bool
}
commandLineInput := loginStatus{}

Expand All @@ -51,26 +60,56 @@ func init() {
uotm := GetUserOAuthTokenManagerInstance()
tokenInfo, err := uotm.GetTokenInfo(ctx)

if err == nil && !tokenInfo.IsExpired() {
glcm.Info("You have successfully refreshed your token. Your login session is still active")
var Info = LoginStatusOutput{
Valid: err == nil && !tokenInfo.IsExpired(),
}

logText := func(format string, a ...any) {
if azcopyOutputFormat == common.EOutputFormat.None() || azcopyOutputFormat == common.EOutputFormat.Text() {
glcm.Info(fmt.Sprintf(format, a...))
}
}

if Info.Valid {
logText("You have successfully refreshed your token. Your login session is still active")

if commandLineInput.tenantID {
glcm.Info(fmt.Sprintf("Tenant ID: %v", tokenInfo.Tenant))
logText("Tenant ID: %v", tokenInfo.Tenant)
Info.TenantID = &tokenInfo.Tenant
}

if commandLineInput.endpoint {
glcm.Info(fmt.Sprintf("Active directory endpoint: %v", tokenInfo.ActiveDirectoryEndpoint))
logText(fmt.Sprintf("Active directory endpoint: %v", tokenInfo.ActiveDirectoryEndpoint))
Info.AADEndpoint = &tokenInfo.ActiveDirectoryEndpoint
}

if commandLineInput.method {
logText(fmt.Sprintf("Authorized using %s", tokenInfo.LoginType))
method := tokenInfo.LoginType.String()
Info.AuthMethod = &method
}
} else {
logText("You are currently not logged in. Please login using 'azcopy login'")
}

if azcopyOutputFormat == common.EOutputFormat.Json() {
glcm.Output(
func(_ common.OutputFormat) string {
buf, err := json.Marshal(Info)
if err != nil {
panic(err)
}

glcm.Exit(nil, common.EExitCode.Success())
return string(buf)
}, common.EOutputMessageType.LoginStatusInfo())
}

glcm.Info("You are currently not logged in. Please login using 'azcopy login'")
glcm.Exit(nil, common.EExitCode.Error())
glcm.Exit(nil, common.Iff(Info.Valid, common.EExitCode.Success(), common.EExitCode.Error()))
},
}

lgCmd.AddCommand(lgStatus)
lgStatus.PersistentFlags().BoolVar(&commandLineInput.tenantID, "tenant", false, "Prints the Microsoft Entra tenant ID that is currently being used in session.")
lgStatus.PersistentFlags().BoolVar(&commandLineInput.endpoint, "endpoint", false, "Prints the Microsoft Entra endpoint that is being used in the current session.")
lgStatus.PersistentFlags().BoolVar(&commandLineInput.method, "method", false, "Prints the authorization method used in the current session.")
}
37 changes: 26 additions & 11 deletions common/azure_ps_context_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"os"
"os/exec"
"regexp"
"strings"
"sync"
"time"

Expand All @@ -24,7 +25,8 @@ import (

const credNamePSContext = "PSContextCredential"

type PSTokenProvider func(ctx context.Context, resource string, tenant string) ([]byte, error)
type PSTokenProvider func(ctx context.Context, options policy.TokenRequestOptions) ([]byte, error)

func validTenantID(tenantID string) bool {
match, err := regexp.MatchString("^[0-9a-zA-Z-.]+$", tenantID)
if err != nil {
Expand All @@ -50,6 +52,7 @@ func resolveTenant(defaultTenant, specified, credName string, additionalTenants
}
return "", fmt.Errorf(`%s isn't configured to acquire tokens for tenant %q. To enable acquiring tokens for this tenant add it to the AdditionallyAllowedTenants on the credential options, or add "*" to allow acquiring tokens for any tenant`, credName, specified)
}

// PowershellContextCredentialOptions contains optional parameters for AzureDeveloperCLICredential.
type PowershellContextCredentialOptions struct {
// TenantID identifies the tenant the credential should authenticate in. Defaults to the azd environment,
Expand Down Expand Up @@ -96,7 +99,10 @@ func (c *PowershellContextCredential) GetToken(ctx context.Context, opts policy.
}
c.mu.Lock()
defer c.mu.Unlock()
b, err := c.opts.tokenProvider(ctx, opts.Scopes[0], tenant)

opts.TenantID = tenant

b, err := c.opts.tokenProvider(ctx, opts)
if err == nil {
at, err = c.createAccessToken(b)
}
Expand All @@ -109,21 +115,30 @@ func (c *PowershellContextCredential) GetToken(ctx context.Context, opts policy.

// We ignore resource because PS does not support all Resources. Disk scope is not supported
// and we are here only with Storage scope
var defaultAzdTokenProvider PSTokenProvider = func(ctx context.Context, _ string, tenantID string) ([]byte, error) {
var defaultAzdTokenProvider PSTokenProvider = func(ctx context.Context, opts policy.TokenRequestOptions) ([]byte, error) {
// set a default timeout for this authentication iff the application hasn't done so already
var cancel context.CancelFunc
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
ctx, cancel = context.WithTimeout(ctx, 10 * time.Minute)
ctx, cancel = context.WithTimeout(ctx, 10*time.Minute)
defer cancel()
}

r := regexp.MustCompile("(?s){.*Token.*ExpiresOn.*}")

if tenantID != "" {
tenantID += " -TenantId" + tenantID
cmd := "Get-AzAccessToken"
// set options
if len(opts.Scopes) != 1 {
return nil, errors.New("exactly one scope must be specified")
} else {
cmd += fmt.Sprintf(" -ResourceUrl \"%s\"", strings.TrimSuffix(opts.Scopes[0], "/.default"))
}
cmd := "Get-AzAccessToken -ResourceUrl https://storage.azure.com" + tenantID + " | ConvertTo-Json"


if opts.TenantID != "" {
cmd += fmt.Sprintf(" -TenantId \"%s\"", opts.TenantID)
}

// We're going to get broken on this in Az 14.0 and Az.Accounts 5.0, so we may as well fix it now.
cmd += " -AsSecureString | Foreach-Object {[PSCustomObject]@{Token= $($_.Token | ConvertFrom-SecureString -AsPlainText); ExpiresOn = $_.ExpiresOn}} | ConvertTo-Json"

cliCmd := exec.CommandContext(ctx, "pwsh", "-Command", cmd)
cliCmd.Env = os.Environ()
Expand All @@ -142,7 +157,7 @@ var defaultAzdTokenProvider PSTokenProvider = func(ctx context.Context, _ string
output = []byte(r.FindString(string(output)))
if string(output) == "" {
invalidTokenMsg := " Invalid output received while retrieving token with Powershell. Run command \"" + cmd + "\"" +
" on powershell and verify that the output is indeed a valid token."
" on powershell and verify that the output is indeed a valid token."
return nil, errors.New(credNamePSContext + invalidTokenMsg)
}
return output, nil
Expand All @@ -158,7 +173,7 @@ func (c *PowershellContextCredential) createAccessToken(tk []byte) (azcore.Acces
if err != nil {
return azcore.AccessToken{}, errors.New(err.Error())
}

parseErr := "error parsing token expiration time %q: %v"
exp, err := time.Parse(time.RFC3339, t.ExpiresOn)
if err != nil {
Expand All @@ -170,4 +185,4 @@ func (c *PowershellContextCredential) createAccessToken(tk []byte) (azcore.Acces
}, nil
}

var _ azcore.TokenCredential = (*PowershellContextCredential)(nil)
var _ azcore.TokenCredential = (*PowershellContextCredential)(nil)
4 changes: 4 additions & 0 deletions common/credCache_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,10 @@ func (c *CredCache) saveTokenInternal(token OAuthTokenInfo) error {
}

func (c *CredCache) tokenFilePath() string {
if cacheFile := GetEnvironmentVariable(EEnvironmentVariable.LoginCacheName()); cacheFile != "" {
return path.Join(c.dpapiFilePath, "/", cacheFile)
}

return path.Join(c.dpapiFilePath, "/", defaultTokenFileName)
}

Expand Down
11 changes: 11 additions & 0 deletions common/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ func (AutoLoginType) PsCred() AutoLoginType { return AutoLoginType(4) }
func (AutoLoginType) Workload() AutoLoginType { return AutoLoginType(5) }
func (AutoLoginType) TokenStore() AutoLoginType { return AutoLoginType(255) } // Storage Explorer internal integration only. Do not add this to ValidAutoLoginTypes.

func (d AutoLoginType) IsInteractive() bool {
return d == d.Device()
}

func (d AutoLoginType) String() string {
return strings.ToLower(enum.StringInt(d, reflect.TypeOf(d)))
}
Expand Down Expand Up @@ -306,6 +310,13 @@ func (EnvironmentVariable) CacheProxyLookup() EnvironmentVariable {
}
}

func (EnvironmentVariable) LoginCacheName() EnvironmentVariable {
return EnvironmentVariable{
Name: "AZCOPY_LOGIN_CACHE_NAME",
Description: "Do not use in production. Overrides the file name or key name used to cache azcopy's token. Do not use in production. This feature is not documented, intended for testing, and may break. Do not use in production.",
}
}

func (EnvironmentVariable) LogLocation() EnvironmentVariable {
return EnvironmentVariable{
Name: "AZCOPY_LOG_LOCATION",
Expand Down
18 changes: 13 additions & 5 deletions common/oauthTokenManager.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,21 +175,21 @@ func (uotm *UserOAuthTokenManager) WorkloadIdentityLogin(persist bool) error {
return uotm.validateAndPersistLogin(oAuthTokenInfo)
}

func (uotm *UserOAuthTokenManager) AzCliLogin(tenantID string) error {
func (uotm *UserOAuthTokenManager) AzCliLogin(tenantID string, persist bool) error {
oAuthTokenInfo := &OAuthTokenInfo{
LoginType: EAutoLoginType.AzCLI(),
Tenant: tenantID,
Persist: false, // AzCLI creds do not need to be persisted, AzCLI handles persistence.
Persist: persist, // AzCLI creds do not need to be persisted, AzCLI handles persistence.
}

return uotm.validateAndPersistLogin(oAuthTokenInfo)
}

func (uotm *UserOAuthTokenManager) PSContextToken(tenantID string) error {
func (uotm *UserOAuthTokenManager) PSContextToken(tenantID string, persist bool) error {
oAuthTokenInfo := &OAuthTokenInfo{
LoginType: EAutoLoginType.PsCred(),
Tenant: tenantID,
Persist: false, // Powershell creds do not need to be persisted, Powershell handles persistence.
Persist: persist, // Powershell creds do not need to be persisted, Powershell handles persistence.
}

return uotm.validateAndPersistLogin(oAuthTokenInfo)
Expand Down Expand Up @@ -645,6 +645,10 @@ func (credInfo *OAuthTokenInfo) GetClientSecretCredential() (azcore.TokenCredent
}

func (credInfo *OAuthTokenInfo) GetAzCliCredential() (azcore.TokenCredential, error) {
if credInfo.Tenant == DefaultTenantID {
credInfo.Tenant = ""
}

tc, err := azidentity.NewAzureCLICredential(&azidentity.AzureCLICredentialOptions{TenantID: credInfo.Tenant})
if err != nil {
return nil, err
Expand All @@ -654,7 +658,11 @@ func (credInfo *OAuthTokenInfo) GetAzCliCredential() (azcore.TokenCredential, er
}

func (credInfo *OAuthTokenInfo) GetPSContextCredential() (azcore.TokenCredential, error) {
tc, err := NewPowershellContextCredential(nil)
if credInfo.Tenant == DefaultTenantID {
credInfo.Tenant = ""
}

tc, err := NewPowershellContextCredential(&PowershellContextCredentialOptions{TenantID: credInfo.Tenant})
if err != nil {
return nil, err
}
Expand Down
2 changes: 2 additions & 0 deletions common/output.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ func (OutputMessageType) Response() OutputMessageType { return OutputMessageType
func (OutputMessageType) ListObject() OutputMessageType { return OutputMessageType(8) }
func (OutputMessageType) ListSummary() OutputMessageType { return OutputMessageType(9) }

func (OutputMessageType) LoginStatusInfo() OutputMessageType { return OutputMessageType(10) }

func (o OutputMessageType) String() string {
return enum.StringInt(o, reflect.TypeOf(o))
}
Expand Down
47 changes: 43 additions & 4 deletions e2etest/newe2e_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ All immediate fields of a mutually exclusive struct will be treated as required,
Structs that are not marked "required" will present Environment errors from "required" fields when one or more options are successfully set
*/

const AzurePipeline = "AzurePipeline"
const TestEnvironmentAzurePipelines = "TestEnvironmentAzurePipelines"

type NewE2EConfig struct {
E2EAuthConfig struct { // mutually exclusive
Expand All @@ -51,9 +51,18 @@ type NewE2EConfig struct {

StaticStgAcctInfo struct {
StaticOAuth struct {
TenantID string `env:"NEW_E2E_STATIC_TENANT_ID"`
ApplicationID string `env:"NEW_E2E_STATIC_APPLICATION_ID,required"`
ClientSecret string `env:"NEW_E2E_STATIC_CLIENT_SECRET,required"`
TenantID string `env:"NEW_E2E_STATIC_TENANT_ID"`

OAuthSource struct { // mutually exclusive
SPNSecret struct {
ApplicationID string `env:"NEW_E2E_STATIC_APPLICATION_ID,required"`
ClientSecret string `env:"NEW_E2E_STATIC_CLIENT_SECRET,required"`
} `env:",required"`

PSInherit bool `env:"NEW_E2E_STATIC_PS_INHERIT,required"`

CLIInherit bool `env:"NEW_E2E_STATIC_CLI_INHERIT,required"`
} `env:",required,mutually_exclusive"`
}

// todo: should we automate this somehow? Currently each of these accounts needs some marginal boilerplate.
Expand Down Expand Up @@ -84,6 +93,36 @@ func (e NewE2EConfig) StaticResources() bool {
return e.E2EAuthConfig.SubscriptionLoginInfo.SubscriptionID == "" // all subscriptionlogininfo options would have to be filled due to required
}

func (e NewE2EConfig) GetSPNOptions() (present bool, tenant, applicationId, secret string) {
staticInfo := e.E2EAuthConfig.StaticStgAcctInfo.StaticOAuth
dynamicInfo := e.E2EAuthConfig.SubscriptionLoginInfo.DynamicOAuth.SPNSecret

if e.StaticResources() {
return staticInfo.OAuthSource.SPNSecret.ApplicationID != "",
staticInfo.TenantID,
staticInfo.OAuthSource.SPNSecret.ApplicationID,
staticInfo.OAuthSource.SPNSecret.ClientSecret
} else {
return dynamicInfo.ApplicationID != "",
dynamicInfo.TenantID,
dynamicInfo.ApplicationID,
dynamicInfo.ApplicationID
}
}

func (e NewE2EConfig) GetTenantID() string {
if e.StaticResources() {
return e.E2EAuthConfig.StaticStgAcctInfo.StaticOAuth.TenantID
} else {
dynamicInfo := e.E2EAuthConfig.SubscriptionLoginInfo.DynamicOAuth
if tid := dynamicInfo.SPNSecret.TenantID; tid != "" {
return tid
} else {
return dynamicInfo.Workload.TenantId // worst case if it bubbles down and it's all zero, that's OK.
}
}
}

// ========= Tag Definition ==========

type EnvTag struct {
Expand Down
Loading
Loading