diff --git a/cli/azd/.vscode/cspell.yaml b/cli/azd/.vscode/cspell.yaml index 6ae1d0c06b5..9d8d0522d48 100644 --- a/cli/azd/.vscode/cspell.yaml +++ b/cli/azd/.vscode/cspell.yaml @@ -75,6 +75,8 @@ words: - yarnpkg - azconfig - hostnames + - managedhsm + - microsoftazure - seekable - seekability languageSettings: diff --git a/cli/azd/cmd/extensions.go b/cli/azd/cmd/extensions.go index 32a241342f8..15b313bba36 100644 --- a/cli/azd/cmd/extensions.go +++ b/cli/azd/cmd/extensions.go @@ -19,6 +19,7 @@ import ( "github.com/azure/azure-dev/cli/azd/pkg/exec" "github.com/azure/azure-dev/cli/azd/pkg/extensions" "github.com/azure/azure-dev/cli/azd/pkg/input" + kv "github.com/azure/azure-dev/cli/azd/pkg/keyvault" "github.com/azure/azure-dev/cli/azd/pkg/lazy" "github.com/azure/azure-dev/cli/azd/pkg/output/ux" pkgux "github.com/azure/azure-dev/cli/azd/pkg/ux" @@ -119,6 +120,7 @@ type extensionAction struct { extensionManager *extensions.Manager azdServer *grpcserver.Server globalOptions *internal.GlobalCommandOptions + kvService kv.KeyVaultService cmd *cobra.Command args []string } @@ -132,6 +134,7 @@ func newExtensionAction( cmd *cobra.Command, azdServer *grpcserver.Server, globalOptions *internal.GlobalCommandOptions, + kvService kv.KeyVaultService, args []string, ) actions.Action { return &extensionAction{ @@ -141,6 +144,7 @@ func newExtensionAction( extensionManager: extensionManager, azdServer: azdServer, globalOptions: globalOptions, + kvService: kvService, cmd: cmd, args: args, } @@ -216,7 +220,18 @@ func (a *extensionAction) Run(ctx context.Context) (*actions.ActionResult, error env, err := a.lazyEnv.GetValue() if err == nil && env != nil { - allEnv = append(allEnv, env.Environ()...) + // Resolve Key Vault secret references only in azd-managed environment + // variables (akvs:// and @Microsoft.KeyVault formats). System env vars + // from os.Environ() are NOT processed — only the azd environment's + // variables may contain KV references. + azdEnvVars := env.Environ() + subId := env.Getenv("AZURE_SUBSCRIPTION_ID") + azdEnvVars, kvErr := kv.ResolveSecretEnvironment(ctx, a.kvService, azdEnvVars, subId) + if kvErr != nil { + log.Printf("warning: %v", kvErr) + } + + allEnv = append(allEnv, azdEnvVars...) } serverInfo, err := a.azdServer.Start() diff --git a/cli/azd/pkg/azdext/keyvault_resolver.go b/cli/azd/pkg/azdext/keyvault_resolver.go new file mode 100644 index 00000000000..637e1c6fb9a --- /dev/null +++ b/cli/azd/pkg/azdext/keyvault_resolver.go @@ -0,0 +1,412 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "context" + "errors" + "fmt" + "maps" + "net/http" + "regexp" + "slices" + "strings" + "sync" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets" + "github.com/azure/azure-dev/cli/azd/pkg/keyvault" +) + +// KeyVaultResolver resolves Azure Key Vault secret references for extension +// scenarios. It uses the extension's [TokenProvider] for authentication and +// the Azure SDK data-plane client for secret retrieval. +// +// Two reference formats are supported: +// +// akvs://// +// @Microsoft.KeyVault(SecretUri=https://.vault.azure.net/secrets/[/]) +// +// The akvs:// scheme is the preferred compact form. The @Microsoft.KeyVault +// format supports only the SecretUri= variant; the VaultName/SecretName form +// is not currently implemented. +// +// Usage: +// +// tp, _ := azdext.NewTokenProvider(ctx, client, nil) +// resolver, _ := azdext.NewKeyVaultResolver(tp, nil) +// value, err := resolver.Resolve(ctx, "akvs://sub-id/my-vault/my-secret") +type KeyVaultResolver struct { + credential azcore.TokenCredential + clientFactory secretClientFactory + opts KeyVaultResolverOptions + clientCache sync.Map // map[vaultURL]secretGetter — per-vault client cache +} + +// secretClientFactory abstracts secret client creation for testability. +type secretClientFactory func(vaultURL string, credential azcore.TokenCredential) (secretGetter, error) + +// secretGetter abstracts the Azure SDK secret client's GetSecret method. +type secretGetter interface { + GetSecret( + ctx context.Context, + name string, + version string, + options *azsecrets.GetSecretOptions, + ) (azsecrets.GetSecretResponse, error) +} + +// KeyVaultResolverOptions configures a [KeyVaultResolver]. +type KeyVaultResolverOptions struct { + // VaultSuffix overrides the default Key Vault DNS suffix. + // Defaults to "vault.azure.net" (Azure public cloud). + VaultSuffix string + + // ClientFactory overrides the default secret client constructor. + // Useful for testing. When nil, the production [azsecrets.NewClient] is used. + ClientFactory func(vaultURL string, credential azcore.TokenCredential) (secretGetter, error) +} + +// NewKeyVaultResolver creates a [KeyVaultResolver] with the given credential. +// +// credential must not be nil; it is typically a [*TokenProvider]. +// If opts is nil, production defaults are used. +func NewKeyVaultResolver(credential azcore.TokenCredential, opts *KeyVaultResolverOptions) (*KeyVaultResolver, error) { + if credential == nil { + return nil, errors.New("azdext.NewKeyVaultResolver: credential must not be nil") + } + + if opts == nil { + opts = &KeyVaultResolverOptions{} + } + + if opts.VaultSuffix == "" { + opts.VaultSuffix = "vault.azure.net" + } + + factory := defaultSecretClientFactory + if opts.ClientFactory != nil { + factory = opts.ClientFactory + } + + return &KeyVaultResolver{ + credential: credential, + clientFactory: factory, + opts: *opts, + }, nil +} + +// defaultSecretClientFactory creates a real Azure SDK secrets client. +func defaultSecretClientFactory(vaultURL string, credential azcore.TokenCredential) (secretGetter, error) { + client, err := azsecrets.NewClient(vaultURL, credential, &azsecrets.ClientOptions{ + DisableChallengeResourceVerification: false, + }) + if err != nil { + return nil, err + } + + return client, nil +} + +// Resolve fetches the secret value for a Key Vault secret reference. +// +// Both akvs:// and @Microsoft.KeyVault(SecretUri=...) formats are accepted. +// +// Returns a [*KeyVaultResolveError] for all domain errors (invalid reference, +// secret not found, authentication failure). No silent fallbacks or hidden retries. +func (r *KeyVaultResolver) Resolve(ctx context.Context, ref string) (string, error) { + if ctx == nil { + return "", errors.New("azdext.KeyVaultResolver.Resolve: context must not be nil") + } + + parsed, err := ParseSecretReference(ref) + if err != nil { + return "", &KeyVaultResolveError{ + Reference: ref, + Reason: ResolveReasonInvalidReference, + Err: err, + } + } + + vaultURL := parsed.VaultURL + if vaultURL == "" { + vaultURL = fmt.Sprintf("https://%s.%s", parsed.VaultName, r.opts.VaultSuffix) + } + + secretVersion := parsed.SecretVersion + + client, err := r.getOrCreateClient(vaultURL) + if err != nil { + return "", &KeyVaultResolveError{ + Reference: ref, + Reason: ResolveReasonClientCreation, + Err: fmt.Errorf("failed to create Key Vault client for %s: %w", vaultURL, err), + } + } + + resp, err := client.GetSecret(ctx, parsed.SecretName, secretVersion, nil) + if err != nil { + // Default to ServiceError for non-ResponseError failures (e.g., network + // timeouts, DNS resolution failures). AccessDenied is only used when the + // server explicitly returns 401/403. + reason := ResolveReasonServiceError + + var respErr *azcore.ResponseError + if errors.As(err, &respErr) { + switch respErr.StatusCode { + case http.StatusNotFound: + reason = ResolveReasonNotFound + case http.StatusForbidden, http.StatusUnauthorized: + reason = ResolveReasonAccessDenied + default: + reason = ResolveReasonServiceError + } + } + + return "", &KeyVaultResolveError{ + Reference: ref, + Reason: reason, + Err: fmt.Errorf( + "failed to retrieve secret %q from vault %q: %w", + parsed.SecretName, + parsed.VaultName, + err, + ), + } + } + + if resp.Value == nil { + return "", &KeyVaultResolveError{ + Reference: ref, + Reason: ResolveReasonNotFound, + Err: fmt.Errorf("secret %q in vault %q has a nil value", parsed.SecretName, parsed.VaultName), + } + } + + return *resp.Value, nil +} + +// getOrCreateClient returns a cached client for the given vault URL, creating +// one via the client factory if no cached entry exists. The cache is safe for +// concurrent use. +func (r *KeyVaultResolver) getOrCreateClient(vaultURL string) (secretGetter, error) { + if cached, ok := r.clientCache.Load(vaultURL); ok { + return cached.(secretGetter), nil + } + + client, err := r.clientFactory(vaultURL, r.credential) + if err != nil { + return nil, err + } + + // Store-or-load to handle concurrent creation for the same vault. + actual, _ := r.clientCache.LoadOrStore(vaultURL, client) + return actual.(secretGetter), nil +} + +// ResolveMap resolves a map of key → secret references, returning a map of +// key → resolved secret values. Both akvs:// and @Microsoft.KeyVault formats +// are accepted. All entries are attempted; errors are collected and returned +// together via [errors.Join] so that callers see every failure at once. +// +// Non-secret values are passed through unchanged, so callers can safely +// resolve a mixed map of plain values and secret references. +// +// Keys are processed in sorted order so that error messages are deterministic. +func (r *KeyVaultResolver) ResolveMap(ctx context.Context, refs map[string]string) (map[string]string, error) { + if ctx == nil { + return nil, errors.New("azdext.KeyVaultResolver.ResolveMap: context must not be nil") + } + + result := make(map[string]string, len(refs)) + + // Sort keys for deterministic iteration and error reporting. + var errs []error + + for _, key := range slices.Sorted(maps.Keys(refs)) { + value := refs[key] + + if !IsSecretReference(value) { + result[key] = value + continue + } + + resolved, err := r.Resolve(ctx, value) + if err != nil { + errs = append(errs, fmt.Errorf("key %q: %w", key, err)) + result[key] = value // preserve original reference so callers see all keys + continue + } + + result[key] = resolved + } + + if len(errs) > 0 { + return result, fmt.Errorf("azdext.KeyVaultResolver.ResolveMap: %w", errors.Join(errs...)) + } + + return result, nil +} + +// SecretReference represents a parsed Key Vault secret reference. +// It may be populated from either the akvs:// or @Microsoft.KeyVault format. +type SecretReference struct { + // SubscriptionID is the Azure subscription containing the Key Vault. + // Present for akvs:// references; empty for @Microsoft.KeyVault references. + SubscriptionID string + + // VaultName is the Key Vault name (not the full URL). + VaultName string + + // SecretName is the name of the secret within the vault. + SecretName string + + // SecretVersion is the specific secret version to retrieve. + // Empty string means latest version. + SecretVersion string + + // VaultURL is the full vault URL (e.g., "https://my-vault.vault.azure.net"). + // Present for @Microsoft.KeyVault references; empty for akvs:// references + // (where the URL is constructed from VaultName + VaultSuffix). + VaultURL string +} + +// IsSecretReference reports whether s is a Key Vault secret reference +// in either the akvs:// or @Microsoft.KeyVault(SecretUri=...) format. +func IsSecretReference(s string) bool { + return keyvault.IsSecretReference(s) +} + +// vaultNameRe validates Azure Key Vault names per Azure naming rules: +// - 3–24 characters +// - starts with a letter +// - contains only alphanumeric and hyphens +// - does not end with a hyphen +var vaultNameRe = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9-]{1,22}[a-zA-Z0-9]$`) + +// ParseSecretReference parses a Key Vault secret reference into its components. +// +// Two formats are supported: +// +// akvs://// +// @Microsoft.KeyVault(SecretUri=https://.vault.azure.net/secrets/[/]) +// +// For the akvs:// format, the vault name is validated against Azure Key Vault +// naming rules (3–24 characters, starts with letter, alphanumeric and hyphens +// only, does not end with a hyphen). +func ParseSecretReference(ref string) (*SecretReference, error) { + if keyvault.IsKeyVaultAppReference(ref) { + return parseKeyVaultAppReference(ref) + } + + return parseAkvsReference(ref) +} + +// parseAkvsReference parses an akvs:// URI into its components. +func parseAkvsReference(ref string) (*SecretReference, error) { + parsed, err := keyvault.ParseAzureKeyVaultSecret(ref) + if err != nil { + return nil, err + } + + if strings.TrimSpace(parsed.SubscriptionId) == "" { + return nil, fmt.Errorf("invalid akvs:// reference %q: subscription-id must not be empty", ref) + } + if strings.TrimSpace(parsed.VaultName) == "" { + return nil, fmt.Errorf("invalid akvs:// reference %q: vault-name must not be empty", ref) + } + if !vaultNameRe.MatchString(parsed.VaultName) { + return nil, fmt.Errorf( + "invalid akvs:// reference %q: vault name %q must be 3-24 characters, "+ + "start with a letter, and contain only alphanumeric characters and hyphens", + ref, parsed.VaultName, + ) + } + if strings.TrimSpace(parsed.SecretName) == "" { + return nil, fmt.Errorf("invalid akvs:// reference %q: secret-name must not be empty", ref) + } + + return &SecretReference{ + SubscriptionID: parsed.SubscriptionId, + VaultName: parsed.VaultName, + SecretName: parsed.SecretName, + }, nil +} + +// parseKeyVaultAppReference parses an @Microsoft.KeyVault(SecretUri=...) reference +// by delegating to the core keyvault package. +func parseKeyVaultAppReference(ref string) (*SecretReference, error) { + parsed, err := keyvault.ParseKeyVaultAppReference(ref) + if err != nil { + return nil, err + } + + return &SecretReference{ + VaultName: parsed.VaultName, + SecretName: parsed.SecretName, + SecretVersion: parsed.SecretVersion, + VaultURL: parsed.VaultURL, + }, nil +} + +// ResolveReason classifies the cause of a [KeyVaultResolveError]. +type ResolveReason int + +const ( + // ResolveReasonInvalidReference indicates the secret reference is malformed. + ResolveReasonInvalidReference ResolveReason = iota + + // ResolveReasonClientCreation indicates failure to create the Key Vault client. + ResolveReasonClientCreation + + // ResolveReasonNotFound indicates the secret does not exist. + ResolveReasonNotFound + + // ResolveReasonAccessDenied indicates an authentication or authorization failure. + ResolveReasonAccessDenied + + // ResolveReasonServiceError indicates an unexpected Key Vault service error. + ResolveReasonServiceError +) + +// String returns a human-readable label for the reason. +func (r ResolveReason) String() string { + switch r { + case ResolveReasonInvalidReference: + return "invalid_reference" + case ResolveReasonClientCreation: + return "client_creation" + case ResolveReasonNotFound: + return "not_found" + case ResolveReasonAccessDenied: + return "access_denied" + case ResolveReasonServiceError: + return "service_error" + default: + return "unknown" + } +} + +// KeyVaultResolveError is returned when [KeyVaultResolver.Resolve] fails. +type KeyVaultResolveError struct { + // Reference is the original akvs:// URI that was being resolved. + Reference string + + // Reason classifies the failure. + Reason ResolveReason + + // Err is the underlying error. + Err error +} + +func (e *KeyVaultResolveError) Error() string { + return fmt.Sprintf( + "azdext.KeyVaultResolver: %s (ref=%s): %v", + e.Reason, e.Reference, e.Err, + ) +} + +func (e *KeyVaultResolveError) Unwrap() error { + return e.Err +} diff --git a/cli/azd/pkg/azdext/keyvault_resolver_test.go b/cli/azd/pkg/azdext/keyvault_resolver_test.go new file mode 100644 index 00000000000..3423113fb8a --- /dev/null +++ b/cli/azd/pkg/azdext/keyvault_resolver_test.go @@ -0,0 +1,745 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets" +) + +// stubSecretGetter is a test double for the Key Vault data-plane client. +// It records the name and version args it receives for verification. +type stubSecretGetter struct { + resp azsecrets.GetSecretResponse + err error + + // Recorded call args (set on each GetSecret call). + calledName string + calledVersion string +} + +func (s *stubSecretGetter) GetSecret( + _ context.Context, name string, version string, _ *azsecrets.GetSecretOptions, +) (azsecrets.GetSecretResponse, error) { + s.calledName = name + s.calledVersion = version + return s.resp, s.err +} + +// stubSecretFactory returns a factory that always returns the given stubSecretGetter. +func stubSecretFactory(g secretGetter, factoryErr error) func(string, azcore.TokenCredential) (secretGetter, error) { + return func(_ string, _ azcore.TokenCredential) (secretGetter, error) { + if factoryErr != nil { + return nil, factoryErr + } + return g, nil + } +} + +// --- NewKeyVaultResolver --- + +func TestNewKeyVaultResolver_NilCredential(t *testing.T) { + t.Parallel() + + _, err := NewKeyVaultResolver(nil, nil) + if err == nil { + t.Fatal("expected error for nil credential") + } +} + +func TestNewKeyVaultResolver_Defaults(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, err := NewKeyVaultResolver(cred, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if resolver.opts.VaultSuffix != "vault.azure.net" { + t.Errorf("VaultSuffix = %q, want %q", resolver.opts.VaultSuffix, "vault.azure.net") + } +} + +func TestNewKeyVaultResolver_CustomSuffix(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, err := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + VaultSuffix: "vault.azure.cn", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if resolver.opts.VaultSuffix != "vault.azure.cn" { + t.Errorf("VaultSuffix = %q, want %q", resolver.opts.VaultSuffix, "vault.azure.cn") + } +} + +// --- IsSecretReference --- + +func TestIsSecretReference(t *testing.T) { + t.Parallel() + + tests := []struct { + input string + want bool + }{ + {"akvs://sub/vault/secret", true}, + {"akvs://", true}, + {"AKVS://sub/vault/secret", false}, // case-sensitive + {"https://vault.azure.net", false}, + {"", false}, + // @Microsoft.KeyVault format + {"@Microsoft.KeyVault(SecretUri=https://v.vault.azure.net/secrets/s)", true}, + // case-insensitive prefix (matches Azure App Service behavior) + {"@microsoft.keyvault(secreturi=https://v.vault.azure.net/secrets/s)", true}, + // VaultName/SecretName form is not supported + {"@Microsoft.KeyVault(VaultName=v;SecretName=s)", false}, + } + + for _, tt := range tests { + if got := IsSecretReference(tt.input); got != tt.want { + t.Errorf("IsSecretReference(%q) = %v, want %v", tt.input, got, tt.want) + } + } +} + +// --- ParseSecretReference --- + +func TestParseSecretReference_Valid(t *testing.T) { + t.Parallel() + + ref, err := ParseSecretReference("akvs://sub-123/my-vault/my-secret") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if ref.SubscriptionID != "sub-123" { + t.Errorf("SubscriptionID = %q, want %q", ref.SubscriptionID, "sub-123") + } + if ref.VaultName != "my-vault" { + t.Errorf("VaultName = %q, want %q", ref.VaultName, "my-vault") + } + if ref.SecretName != "my-secret" { + t.Errorf("SecretName = %q, want %q", ref.SecretName, "my-secret") + } +} + +func TestParseSecretReference_NotAkvsScheme(t *testing.T) { + t.Parallel() + + _, err := ParseSecretReference("https://vault.azure.net/secrets/x") + if err == nil { + t.Fatal("expected error for non-akvs scheme") + } +} + +func TestParseSecretReference_TooFewParts(t *testing.T) { + t.Parallel() + + _, err := ParseSecretReference("akvs://sub/vault") + if err == nil { + t.Fatal("expected error for two-part ref") + } +} + +func TestParseSecretReference_TooManyParts(t *testing.T) { + t.Parallel() + + _, err := ParseSecretReference("akvs://sub/vault/secret/extra") + if err == nil { + t.Fatal("expected error for four-part ref") + } +} + +func TestParseSecretReference_EmptyComponent(t *testing.T) { + t.Parallel() + + cases := []string{ + "akvs:///vault/secret", // empty subscription + "akvs://sub//secret", // empty vault + "akvs://sub/vault/", // empty secret + "akvs:// /vault/secret", // whitespace subscription + "akvs://sub/ /secret", // whitespace vault + "akvs://sub/vault/ ", // whitespace secret + } + + for _, ref := range cases { + _, err := ParseSecretReference(ref) + if err == nil { + t.Errorf("ParseSecretReference(%q) expected error, got nil", ref) + } + } +} + +// --- Resolve --- + +func TestResolve_Success(t *testing.T) { + t.Parallel() + + secretValue := "super-secret-value" + getter := &stubSecretGetter{ + resp: azsecrets.GetSecretResponse{ + Secret: azsecrets.Secret{ + Value: &secretValue, + }, + }, + } + + cred := &stubCredential{} + resolver, err := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + val, err := resolver.Resolve(t.Context(), "akvs://sub-id/my-vault/my-secret") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if val != secretValue { + t.Errorf("Resolve() = %q, want %q", val, secretValue) + } + + // Verify the stub received the correct name and version args. + if getter.calledName != "my-secret" { + t.Errorf("stub received name = %q, want %q", getter.calledName, "my-secret") + } + if getter.calledVersion != "" { + t.Errorf("stub received version = %q, want empty", getter.calledVersion) + } +} + +func TestResolve_NilContext(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(&stubSecretGetter{}, nil), + }) + + //nolint:staticcheck // intentionally testing nil context + //lint:ignore SA1012 intentionally testing nil context handling + _, err := resolver.Resolve(nil, "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for nil context") + } +} + +func TestResolve_InvalidReference(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(&stubSecretGetter{}, nil), + }) + + _, err := resolver.Resolve(t.Context(), "not-akvs://x") + if err == nil { + t.Fatal("expected error for invalid reference") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonInvalidReference { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonInvalidReference) + } +} + +func TestResolve_ClientCreationFailure(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(nil, errors.New("connection refused")), + }) + + _, err := resolver.Resolve(t.Context(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for client creation failure") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonClientCreation { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonClientCreation) + } +} + +func TestResolve_HTTPErrorClassification(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + statusCode int + wantReason ResolveReason + }{ + {"NotFound", http.StatusNotFound, ResolveReasonNotFound}, + {"Forbidden", http.StatusForbidden, ResolveReasonAccessDenied}, + {"Unauthorized", http.StatusUnauthorized, ResolveReasonAccessDenied}, + {"InternalServerError", http.StatusInternalServerError, ResolveReasonServiceError}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: tt.statusCode}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(t.Context(), "akvs://sub/vault/secret") + if err == nil { + t.Fatalf("expected error for HTTP %d", tt.statusCode) + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != tt.wantReason { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, tt.wantReason) + } + }) + } +} + +func TestResolve_NilValue(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + resp: azsecrets.GetSecretResponse{ + Secret: azsecrets.Secret{ + Value: nil, + }, + }, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(t.Context(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for nil secret value") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonNotFound { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonNotFound) + } +} + +func TestResolve_NonResponseError(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: errors.New("network timeout"), + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + _, err := resolver.Resolve(t.Context(), "akvs://sub/vault/secret") + if err == nil { + t.Fatal("expected error for network failure") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + // Non-ResponseError defaults to service_error (not access_denied), + // since non-HTTP errors are typically connectivity/network issues. + if resolveErr.Reason != ResolveReasonServiceError { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonServiceError) + } +} + +// --- ResolveMap --- + +func TestResolveMap_MixedValues(t *testing.T) { + t.Parallel() + + secretValue := "resolved-secret" + getter := &stubSecretGetter{ + resp: azsecrets.GetSecretResponse{ + Secret: azsecrets.Secret{ + Value: &secretValue, + }, + }, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + input := map[string]string{ //nolint:gosec // G101 false positive: test fixture, not real credentials + "plain": "hello-world", + "secret": "akvs://sub/vault/secret", + } + + result, err := resolver.ResolveMap(t.Context(), input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result["plain"] != "hello-world" { + t.Errorf("result[plain] = %q, want %q", result["plain"], "hello-world") + } + + if result["secret"] != secretValue { + t.Errorf("result[secret] = %q, want %q", result["secret"], secretValue) + } +} + +func TestResolveMap_Empty(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(&stubSecretGetter{}, nil), + }) + + result, err := resolver.ResolveMap(t.Context(), map[string]string{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result) != 0 { + t.Errorf("len(result) = %d, want 0", len(result)) + } +} + +func TestResolveMap_ErrorCollectsAllFailures(t *testing.T) { + t.Parallel() + + getter := &stubSecretGetter{ + err: &azcore.ResponseError{StatusCode: http.StatusNotFound}, + } + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + + input := map[string]string{ //nolint:gosec // G101 false positive: test fixture, not real credentials + "secret1": "akvs://sub/vault/missing1", + "secret2": "akvs://sub/vault/missing2", + "secret3": "akvs://sub/vault/missing3", + "plain": "not-a-secret-ref", + } + + // ResolveMap collects errors instead of stopping at the first one. + result, err := resolver.ResolveMap(t.Context(), input) + if err == nil { + t.Fatal("expected error when resolution fails") + } + + // Partial result should be non-nil and contain the plain value. + if result == nil { + t.Fatal("expected non-nil partial result") + } + + if result["plain"] != "not-a-secret-ref" { + t.Errorf("result[plain] = %q, want %q", result["plain"], "not-a-secret-ref") + } + + // The error should mention all 3 failing keys. + errMsg := err.Error() + for _, key := range []string{"secret1", "secret2", "secret3"} { + if !strings.Contains(errMsg, key) { + t.Errorf("error should mention %q, got: %s", key, errMsg) + } + } +} + +func TestResolveMap_NilContext(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(&stubSecretGetter{}, nil), + }) + + //nolint:staticcheck // intentionally testing nil context + //lint:ignore SA1012 intentionally testing nil context handling + _, err := resolver.ResolveMap(nil, map[string]string{"k": "v"}) + if err == nil { + t.Fatal("expected error for nil context") + } +} + +// --- @Microsoft.KeyVault format --- + +func TestParseSecretReference_AppRefValid(t *testing.T) { + t.Parallel() + + ref, err := ParseSecretReference( + "@Microsoft.KeyVault(SecretUri=https://myvault.vault.azure.net/secrets/mysecret)") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if ref.VaultName != "myvault" { + t.Errorf("VaultName = %q, want %q", ref.VaultName, "myvault") + } + if ref.SecretName != "mysecret" { + t.Errorf("SecretName = %q, want %q", ref.SecretName, "mysecret") + } + if ref.SecretVersion != "" { + t.Errorf("SecretVersion = %q, want empty", ref.SecretVersion) + } + if ref.VaultURL != "https://myvault.vault.azure.net" { + t.Errorf("VaultURL = %q, want %q", ref.VaultURL, "https://myvault.vault.azure.net") + } +} + +func TestParseSecretReference_AppRefValidWithVersion(t *testing.T) { + t.Parallel() + + ref, err := ParseSecretReference( + "@Microsoft.KeyVault(SecretUri=https://myvault.vault.azure.net/secrets/mysecret/version123)") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if ref.SecretName != "mysecret" { + t.Errorf("SecretName = %q, want %q", ref.SecretName, "mysecret") + } + if ref.SecretVersion != "version123" { + t.Errorf("SecretVersion = %q, want %q", ref.SecretVersion, "version123") + } +} + +func TestParseSecretReference_AppRefInvalidHost(t *testing.T) { + t.Parallel() + + _, err := ParseSecretReference( + "@Microsoft.KeyVault(SecretUri=https://evil.com/secrets/foo)") + if err == nil { + t.Fatal("expected error for non-Azure Key Vault host") + } + + if !strings.Contains(err.Error(), "not a known Azure Key Vault endpoint") { + t.Errorf("error = %q, want mention of 'not a known Azure Key Vault endpoint'", err.Error()) + } +} + +func TestParseSecretReference_AppRefMalformedURI(t *testing.T) { + t.Parallel() + + _, err := ParseSecretReference( + "@Microsoft.KeyVault(SecretUri=not-a-url)") + if err == nil { + t.Fatal("expected error for malformed SecretUri") + } +} + +func TestParseSecretReference_AppRefSovereignClouds(t *testing.T) { + t.Parallel() + + validHosts := []struct { + name string + uri string + }{ + {"AzureChina", "https://myvault.vault.azure.cn/secrets/s"}, + {"AzureGov", "https://myvault.vault.usgovcloudapi.net/secrets/s"}, + {"AzureGermany", "https://myvault.vault.microsoftazure.de/secrets/s"}, + {"ManagedHSM", "https://myvault.managedhsm.azure.net/secrets/s"}, + } + + for _, tc := range validHosts { + t.Run(tc.name, func(t *testing.T) { + ref, err := ParseSecretReference( + fmt.Sprintf("@Microsoft.KeyVault(SecretUri=%s)", tc.uri)) + if err != nil { + t.Fatalf("unexpected error for %s: %v", tc.name, err) + } + if ref.SecretName != "s" { + t.Errorf("SecretName = %q, want %q", ref.SecretName, "s") + } + }) + } +} + +func TestResolve_AppRefSuccess(t *testing.T) { + t.Parallel() + + secretValue := "app-ref-secret-value" //nolint:gosec // test data, not a real credential + getter := &stubSecretGetter{ + resp: azsecrets.GetSecretResponse{ + Secret: azsecrets.Secret{ + Value: &secretValue, + }, + }, + } + + cred := &stubCredential{} + resolver, err := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + val, err := resolver.Resolve(t.Context(), + "@Microsoft.KeyVault(SecretUri=https://myvault.vault.azure.net/secrets/mysecret)") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if val != secretValue { + t.Errorf("Resolve() = %q, want %q", val, secretValue) + } +} + +func TestResolve_AppRefWithVersion(t *testing.T) { + t.Parallel() + + secretValue := "versioned-value" + getter := &stubSecretGetter{ + resp: azsecrets.GetSecretResponse{ + Secret: azsecrets.Secret{ + Value: &secretValue, + }, + }, + } + + cred := &stubCredential{} + resolver, err := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(getter, nil), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + val, err := resolver.Resolve(t.Context(), + "@Microsoft.KeyVault(SecretUri=https://myvault.vault.azure.net/secrets/mysecret/v1)") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if val != secretValue { + t.Errorf("Resolve() = %q, want %q", val, secretValue) + } + + // Verify name and version were dispatched correctly. + if getter.calledName != "mysecret" { + t.Errorf("stub received name = %q, want %q", getter.calledName, "mysecret") + } + if getter.calledVersion != "v1" { + t.Errorf("stub received version = %q, want %q", getter.calledVersion, "v1") + } +} + +func TestResolve_AppRefInvalidHostReturnsError(t *testing.T) { + t.Parallel() + + cred := &stubCredential{} + resolver, _ := NewKeyVaultResolver(cred, &KeyVaultResolverOptions{ + ClientFactory: stubSecretFactory(&stubSecretGetter{}, nil), + }) + + _, err := resolver.Resolve(t.Context(), + "@Microsoft.KeyVault(SecretUri=https://evil.com/secrets/foo)") + if err == nil { + t.Fatal("expected error for invalid vault host") + } + + var resolveErr *KeyVaultResolveError + if !errors.As(err, &resolveErr) { + t.Fatalf("error type = %T, want *KeyVaultResolveError", err) + } + + if resolveErr.Reason != ResolveReasonInvalidReference { + t.Errorf("Reason = %v, want %v", resolveErr.Reason, ResolveReasonInvalidReference) + } +} + +// --- Error types --- + +func TestKeyVaultResolveError_Error(t *testing.T) { + t.Parallel() + + err := &KeyVaultResolveError{ + Reference: "akvs://sub/vault/secret", + Reason: ResolveReasonNotFound, + Err: errors.New("secret not found"), + } + + got := err.Error() + if got == "" { + t.Fatal("Error() returned empty string") + } +} + +func TestKeyVaultResolveError_Unwrap(t *testing.T) { + t.Parallel() + + inner := errors.New("inner error") + err := &KeyVaultResolveError{ + Reference: "akvs://sub/vault/secret", + Reason: ResolveReasonServiceError, + Err: inner, + } + + if !errors.Is(err, inner) { + t.Error("Unwrap should expose inner error via errors.Is") + } +} + +func TestResolveReason_String(t *testing.T) { + t.Parallel() + + tests := []struct { + reason ResolveReason + want string + }{ + {ResolveReasonInvalidReference, "invalid_reference"}, + {ResolveReasonClientCreation, "client_creation"}, + {ResolveReasonNotFound, "not_found"}, + {ResolveReasonAccessDenied, "access_denied"}, + {ResolveReasonServiceError, "service_error"}, + {ResolveReason(99), "unknown"}, + } + + for _, tt := range tests { + if got := tt.reason.String(); got != tt.want { + t.Errorf("ResolveReason(%d).String() = %q, want %q", tt.reason, got, tt.want) + } + } +} diff --git a/cli/azd/pkg/keyvault/keyvault.go b/cli/azd/pkg/keyvault/keyvault.go index 0eb3424cdf6..c2cc59beb55 100644 --- a/cli/azd/pkg/keyvault/keyvault.go +++ b/cli/azd/pkg/keyvault/keyvault.go @@ -9,6 +9,7 @@ import ( "fmt" "log" "net/http" + "net/url" "strings" "github.com/Azure/azure-sdk-for-go/sdk/azcore" @@ -75,6 +76,12 @@ type KeyVaultService interface { secretValue string, ) error SecretFromAkvs(ctx context.Context, akvs string) (string, error) + // SecretFromKeyVaultReference resolves a secret reference in either the + // akvs:// or @Microsoft.KeyVault(SecretUri=...) format. The subscriptionId + // is required for credential scoping; for @Microsoft.KeyVault references + // (which lack a subscription), the caller should provide the environment's + // default subscription. + SecretFromKeyVaultReference(ctx context.Context, ref string, defaultSubscriptionId string) (string, error) } type keyVaultService struct { @@ -373,6 +380,45 @@ func (kvs *keyVaultService) SecretFromAkvs(ctx context.Context, akvs string) (st return secretValue.Value, nil } +func (kvs *keyVaultService) SecretFromKeyVaultReference( + ctx context.Context, ref string, defaultSubscriptionId string, +) (string, error) { + // Try akvs:// first (includes its own subscription ID) + if IsAzureKeyVaultSecret(ref) { + return kvs.SecretFromAkvs(ctx, ref) + } + + // Try @Microsoft.KeyVault(SecretUri=...) + if IsKeyVaultAppReference(ref) { + parsed, err := ParseKeyVaultAppReference(ref) + if err != nil { + return "", err + } + + // Use the vault URL directly. The subscription ID is only needed + // for credential scoping (tenant lookup), so we use the default. + client, err := kvs.createSecretsDataClient(ctx, defaultSubscriptionId, parsed.VaultURL) + if err != nil { + return "", fmt.Errorf("creating Key Vault client for %s: %w", parsed.VaultURL, err) + } + + resp, err := client.GetSecret(ctx, parsed.SecretName, parsed.SecretVersion, nil) + if err != nil { + return "", fmt.Errorf("fetching secret %q from vault %q: %w", + parsed.SecretName, parsed.VaultName, err) + } + + if resp.Value == nil { + return "", fmt.Errorf("secret %q in vault %q has a nil value", + parsed.SecretName, parsed.VaultName) + } + + return *resp.Value, nil + } + + return "", fmt.Errorf("unrecognized Key Vault reference format: %s", ref) +} + // AzureKeyVaultSecret represents a secret stored in an Azure Key Vault. // It contains the necessary information to identify and access the secret. // @@ -418,3 +464,210 @@ func ParseAzureKeyVaultSecret(akvs string) (AzureKeyVaultSecret, error) { SecretName: vaultParts[2], }, nil } + +const keyVaultAppRefPrefix = "@Microsoft.KeyVault(" + +// IsKeyVaultAppReference reports whether s uses the @Microsoft.KeyVault(SecretUri=...) format +// used by Azure App Service and App Configuration for Key Vault references. +// The prefix check is case-insensitive to match Azure App Service behavior. +// Only the SecretUri= variant is supported; other forms (e.g., VaultName/SecretName) return false. +func IsKeyVaultAppReference(s string) bool { + if len(s) < len(keyVaultAppRefPrefix) || + !strings.EqualFold(s[:len(keyVaultAppRefPrefix)], keyVaultAppRefPrefix) || + !strings.HasSuffix(s, ")") { + return false + } + + inner := strings.TrimSpace(s[len(keyVaultAppRefPrefix) : len(s)-1]) + return len(inner) > len("SecretUri=") && + strings.EqualFold(inner[:len("SecretUri=")], "SecretUri=") +} + +// IsSecretReference reports whether s is a Key Vault secret reference in either +// the akvs:// or @Microsoft.KeyVault(SecretUri=...) format. +func IsSecretReference(s string) bool { + return IsAzureKeyVaultSecret(s) || IsKeyVaultAppReference(s) +} + +// validVaultHostSuffixes lists the known Azure Key Vault DNS suffixes. +// Used to validate SecretUri hostnames and prevent SSRF attacks via +// @Microsoft.KeyVault(SecretUri=https://evil.com/...) references. +var validVaultHostSuffixes = []string{ + ".vault.azure.net", + ".vault.azure.cn", + ".vault.usgovcloudapi.net", + ".vault.microsoftazure.de", + ".managedhsm.azure.net", +} + +// isValidVaultHost reports whether host is a known Azure Key Vault endpoint. +func isValidVaultHost(host string) bool { + host = strings.ToLower(host) + for _, suffix := range validVaultHostSuffixes { + if strings.HasSuffix(host, suffix) { + return true + } + } + return false +} + +// KeyVaultAppReference represents a parsed @Microsoft.KeyVault(SecretUri=...) reference. +type KeyVaultAppReference struct { + // VaultURL is the full vault URL (e.g., "https://my-vault.vault.azure.net"). + VaultURL string + + // VaultName is the vault name extracted from the host. + VaultName string + + // SecretName is the name of the secret. + SecretName string + + // SecretVersion is the specific version, or empty for latest. + SecretVersion string +} + +// ParseKeyVaultAppReference parses an @Microsoft.KeyVault(SecretUri=...) reference. +// +// Expected format: +// +// @Microsoft.KeyVault(SecretUri=https://.vault.azure.net/secrets/[/]) +func ParseKeyVaultAppReference(ref string) (KeyVaultAppReference, error) { + if !IsKeyVaultAppReference(ref) { + return KeyVaultAppReference{}, fmt.Errorf("invalid @Microsoft.KeyVault reference: %s", ref) + } + + inner := strings.TrimSpace(ref[len(keyVaultAppRefPrefix) : len(ref)-1]) + + const secretURIPrefix = "SecretUri=" + if len(inner) < len(secretURIPrefix) || + !strings.EqualFold(inner[:len(secretURIPrefix)], secretURIPrefix) { + return KeyVaultAppReference{}, fmt.Errorf( + "invalid @Microsoft.KeyVault reference %q: expected SecretUri= parameter", ref) + } + + secretURI := strings.TrimSpace(inner[len(secretURIPrefix):]) + if secretURI == "" { + return KeyVaultAppReference{}, fmt.Errorf( + "invalid @Microsoft.KeyVault reference %q: SecretUri value must not be empty", ref) + } + + u, err := url.Parse(secretURI) + if err != nil { + return KeyVaultAppReference{}, fmt.Errorf( + "invalid @Microsoft.KeyVault reference %q: malformed SecretUri: %w", ref, err) + } + + if u.Scheme != "https" { + return KeyVaultAppReference{}, fmt.Errorf( + "invalid @Microsoft.KeyVault reference %q: SecretUri must use https scheme", ref) + } + + host := u.Hostname() + if host == "" { + return KeyVaultAppReference{}, fmt.Errorf( + "invalid @Microsoft.KeyVault reference %q: SecretUri must include a host", ref) + } + + // Reject non-standard ports to prevent SSRF — bearer tokens should only be + // sent to the default HTTPS port (443). + if port := u.Port(); port != "" && port != "443" { + return KeyVaultAppReference{}, fmt.Errorf( + "invalid @Microsoft.KeyVault reference %q: non-standard port %q is not allowed", ref, port) + } + + if !isValidVaultHost(host) { + return KeyVaultAppReference{}, fmt.Errorf( + "invalid @Microsoft.KeyVault reference %q: host %q is not a known Azure Key Vault endpoint", ref, host) + } + + parts := strings.Split(strings.TrimPrefix(u.Path, "/"), "/") + if len(parts) < 2 || parts[0] != "secrets" { + return KeyVaultAppReference{}, fmt.Errorf( + "invalid @Microsoft.KeyVault reference %q: SecretUri path must be /secrets/[/]", ref) + } + + secretName := parts[1] + if secretName == "" { + return KeyVaultAppReference{}, fmt.Errorf( + "invalid @Microsoft.KeyVault reference %q: secret name must not be empty", ref) + } + + var secretVersion string + if len(parts) >= 3 && parts[2] != "" { + secretVersion = parts[2] + } + + vaultName := host + if idx := strings.Index(vaultName, "."); idx > 0 { + vaultName = vaultName[:idx] + } + + if vaultName == "" || vaultName == host { + // Either the host had no subdomain (e.g., "vault.azure.net") or the + // dot was at position 0 — both indicate a missing vault name. + return KeyVaultAppReference{}, fmt.Errorf( + "invalid @Microsoft.KeyVault reference %q: could not extract vault name from host %q", ref, host) + } + + return KeyVaultAppReference{ + VaultURL: fmt.Sprintf("https://%s", u.Host), + VaultName: vaultName, + SecretName: secretName, + SecretVersion: secretVersion, + }, nil +} + +// ResolveSecretEnvironment resolves Key Vault secret references in a list of +// environment variables (in "KEY=VALUE" format). Any value that matches the +// akvs:// or @Microsoft.KeyVault(SecretUri=...) format is replaced with the +// resolved secret value. Non-secret values are passed through unchanged. +// +// On failure, individual variables are set to empty values (to avoid leaking +// raw references), and all errors are collected and returned via [errors.Join]. +// The returned env slice is always valid — callers can choose to proceed with +// partial results or fail based on the error. +func ResolveSecretEnvironment( + ctx context.Context, + kvService KeyVaultService, + envVars []string, + defaultSubscriptionId string, +) ([]string, error) { + if kvService == nil { + return envVars, nil + } + + result := make([]string, len(envVars)) + var errs []error + + for i, envVar := range envVars { + before, after, ok := strings.Cut(envVar, "=") + if !ok { + result[i] = envVar + continue + } + + key := before + value := after + + if !IsSecretReference(value) { + result[i] = envVar + continue + } + + resolved, err := kvService.SecretFromKeyVaultReference(ctx, value, defaultSubscriptionId) + if err != nil { + log.Printf("warning: failed to resolve Key Vault reference for %s: %v", key, err) + errs = append(errs, fmt.Errorf("key %q: %w", key, err)) + result[i] = key + "=" // Empty value — don't leak the raw reference + continue + } + + result[i] = key + "=" + resolved + } + + if len(errs) > 0 { + return result, fmt.Errorf("failed to resolve Key Vault references: %w", errors.Join(errs...)) + } + + return result, nil +} diff --git a/cli/azd/pkg/keyvault/keyvault_test.go b/cli/azd/pkg/keyvault/keyvault_test.go new file mode 100644 index 00000000000..52d38c09514 --- /dev/null +++ b/cli/azd/pkg/keyvault/keyvault_test.go @@ -0,0 +1,338 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package keyvault + +import ( + "context" + "errors" + "strings" + "testing" +) + +// mockKeyVaultService is a minimal mock for the KeyVaultService interface. +// It only implements SecretFromKeyVaultReference (the method under test); +// all other methods panic if called. +type mockKeyVaultService struct { + // resolveFunc, when set, is called by SecretFromKeyVaultReference. + resolveFunc func(ctx context.Context, ref string, defaultSubID string) (string, error) +} + +func (m *mockKeyVaultService) GetKeyVault(context.Context, string, string, string) (*KeyVault, error) { + panic("not implemented") +} +func (m *mockKeyVaultService) GetKeyVaultSecret(context.Context, string, string, string) (*Secret, error) { + panic("not implemented") +} +func (m *mockKeyVaultService) PurgeKeyVault(context.Context, string, string, string) error { + panic("not implemented") +} +func (m *mockKeyVaultService) ListSubscriptionVaults(context.Context, string) ([]Vault, error) { + panic("not implemented") +} +func (m *mockKeyVaultService) CreateVault(context.Context, string, string, string, string, string) (Vault, error) { + panic("not implemented") +} +func (m *mockKeyVaultService) ListKeyVaultSecrets(context.Context, string, string) ([]string, error) { + panic("not implemented") +} +func (m *mockKeyVaultService) CreateKeyVaultSecret(context.Context, string, string, string, string) error { + panic("not implemented") +} +func (m *mockKeyVaultService) SecretFromAkvs(context.Context, string) (string, error) { + panic("not implemented") +} + +func (m *mockKeyVaultService) SecretFromKeyVaultReference( + ctx context.Context, ref string, defaultSubID string, +) (string, error) { + if m.resolveFunc != nil { + return m.resolveFunc(ctx, ref, defaultSubID) + } + return "", errors.New("mockKeyVaultService: resolveFunc not set") +} + +// --- ResolveSecretEnvironment --- + +func TestResolveSecretEnvironment_NilService(t *testing.T) { + t.Parallel() + + input := []string{"FOO=bar", "SECRET=akvs://sub/vault/name"} + result, err := ResolveSecretEnvironment(t.Context(), nil, input, "sub") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // With nil kvService, input is returned unchanged. + if len(result) != len(input) { + t.Fatalf("len(result) = %d, want %d", len(result), len(input)) + } + for i, v := range result { + if v != input[i] { + t.Errorf("result[%d] = %q, want %q", i, v, input[i]) + } + } +} + +func TestResolveSecretEnvironment_PlainValues(t *testing.T) { + t.Parallel() + + mock := &mockKeyVaultService{ + resolveFunc: func(_ context.Context, _ string, _ string) (string, error) { + t.Fatal("resolveFunc should not be called for plain values") + return "", nil + }, + } + + input := []string{"FOO=bar", "BAZ=qux", "EMPTY="} + result, err := ResolveSecretEnvironment(t.Context(), mock, input, "sub") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + for i, v := range result { + if v != input[i] { + t.Errorf("result[%d] = %q, want %q", i, v, input[i]) + } + } +} + +func TestResolveSecretEnvironment_MalformedEnvVar(t *testing.T) { + t.Parallel() + + mock := &mockKeyVaultService{} + // Entries without '=' should be passed through unchanged. + input := []string{"NO_EQUALS_SIGN", "FOO=bar"} + result, err := ResolveSecretEnvironment(t.Context(), mock, input, "sub") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result[0] != "NO_EQUALS_SIGN" { + t.Errorf("result[0] = %q, want %q", result[0], "NO_EQUALS_SIGN") + } + if result[1] != "FOO=bar" { + t.Errorf("result[1] = %q, want %q", result[1], "FOO=bar") + } +} + +func TestResolveSecretEnvironment_MixedAkvsAndAppRef(t *testing.T) { + t.Parallel() + + mock := &mockKeyVaultService{ + resolveFunc: func(_ context.Context, ref string, _ string) (string, error) { + switch { + case strings.HasPrefix(ref, "akvs://"): + return "akvs-resolved", nil + case IsKeyVaultAppReference(ref): + return "appref-resolved", nil + default: + return "", errors.New("unexpected ref: " + ref) + } + }, + } + + input := []string{ + "PLAIN=hello", + "AKVS_SECRET=akvs://sub/vault/secret", + "APPREF_SECRET=@Microsoft.KeyVault(SecretUri=https://v.vault.azure.net/secrets/s)", + } + + result, err := ResolveSecretEnvironment(t.Context(), mock, input, "sub") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := []string{ + "PLAIN=hello", + "AKVS_SECRET=akvs-resolved", + "APPREF_SECRET=appref-resolved", + } + for i, v := range result { + if v != expected[i] { + t.Errorf("result[%d] = %q, want %q", i, v, expected[i]) + } + } +} + +func TestResolveSecretEnvironment_ErrorCollection(t *testing.T) { + t.Parallel() + + mock := &mockKeyVaultService{ + resolveFunc: func(_ context.Context, _ string, _ string) (string, error) { + return "", errors.New("vault unavailable") + }, + } + + input := []string{ + "SECRET1=akvs://sub/vault/s1", + "SECRET2=akvs://sub/vault/s2", + "PLAIN=hello", + } + + result, err := ResolveSecretEnvironment(t.Context(), mock, input, "sub") + if err == nil { + t.Fatal("expected error for failed resolutions") + } + + // Both failing keys should appear in the error message. + errMsg := err.Error() + if !strings.Contains(errMsg, `"SECRET1"`) { + t.Errorf("error should mention SECRET1, got: %s", errMsg) + } + if !strings.Contains(errMsg, `"SECRET2"`) { + t.Errorf("error should mention SECRET2, got: %s", errMsg) + } + + // Failed secrets get empty values; plain value passes through. + if result[0] != "SECRET1=" { + t.Errorf("result[0] = %q, want %q", result[0], "SECRET1=") + } + if result[1] != "SECRET2=" { + t.Errorf("result[1] = %q, want %q", result[1], "SECRET2=") + } + if result[2] != "PLAIN=hello" { + t.Errorf("result[2] = %q, want %q", result[2], "PLAIN=hello") + } +} + +func TestResolveSecretEnvironment_PreservesOrdering(t *testing.T) { + t.Parallel() + + mock := &mockKeyVaultService{ + resolveFunc: func(_ context.Context, _ string, _ string) (string, error) { + return "resolved", nil + }, + } + + // System env first, then azd override — last-wins semantics. + input := []string{ + "PATH=/usr/bin", + "DB_CONN=akvs://sub/vault/db", + "PATH=/override", + } + + result, err := ResolveSecretEnvironment(t.Context(), mock, input, "sub") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Order must be preserved (not sorted alphabetically). + if result[0] != "PATH=/usr/bin" { + t.Errorf("result[0] = %q, want %q", result[0], "PATH=/usr/bin") + } + if result[1] != "DB_CONN=resolved" { + t.Errorf("result[1] = %q, want %q", result[1], "DB_CONN=resolved") + } + if result[2] != "PATH=/override" { + t.Errorf("result[2] = %q, want %q", result[2], "PATH=/override") + } +} + +func TestResolveSecretEnvironment_UnrecognizedFormatError(t *testing.T) { + t.Parallel() + + // A ref that passes IsSecretReference but SecretFromKeyVaultReference + // returns "unrecognized format" — simulates the fallthrough path. + mock := &mockKeyVaultService{ + resolveFunc: func(_ context.Context, ref string, _ string) (string, error) { + return "", errors.New("unrecognized Key Vault reference format: " + ref) + }, + } + + input := []string{ + "SECRET=akvs://sub/vault/secret", + } + + _, err := ResolveSecretEnvironment(t.Context(), mock, input, "sub") + if err == nil { + t.Fatal("expected error for unrecognized format") + } + if !strings.Contains(err.Error(), "unrecognized") { + t.Errorf("error should mention 'unrecognized', got: %s", err.Error()) + } +} + +// --- ParseKeyVaultAppReference additional cases --- + +func TestParseKeyVaultAppReference_NonStandardPort(t *testing.T) { + t.Parallel() + + _, err := ParseKeyVaultAppReference( + "@Microsoft.KeyVault(SecretUri=https://myvault.vault.azure.net:9999/secrets/foo)") + if err == nil { + t.Fatal("expected error for non-standard port") + } + if !strings.Contains(err.Error(), "non-standard port") { + t.Errorf("error = %q, want mention of 'non-standard port'", err.Error()) + } +} + +func TestParseKeyVaultAppReference_Port443Allowed(t *testing.T) { + t.Parallel() + + ref, err := ParseKeyVaultAppReference( + "@Microsoft.KeyVault(SecretUri=https://myvault.vault.azure.net:443/secrets/foo)") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ref.SecretName != "foo" { + t.Errorf("SecretName = %q, want %q", ref.SecretName, "foo") + } +} + +func TestParseKeyVaultAppReference_EmptyVaultName(t *testing.T) { + t.Parallel() + + // "vault.azure.net" is the bare suffix without a vault-name subdomain. + // isValidVaultHost rejects it (needs ".vault.azure.net" suffix with a + // leading dot), so the error reports an unknown endpoint rather than + // reaching the vault-name extraction guard. + _, err := ParseKeyVaultAppReference( + "@Microsoft.KeyVault(SecretUri=https://vault.azure.net/secrets/foo)") + if err == nil { + t.Fatal("expected error for bare suffix hostname") + } + if !strings.Contains(err.Error(), "vault.azure.net") { + t.Errorf("error = %q, want mention of problematic host", err.Error()) + } +} + +func TestParseKeyVaultAppReference_CaseInsensitive(t *testing.T) { + t.Parallel() + + ref, err := ParseKeyVaultAppReference( + "@microsoft.keyvault(secreturi=https://myvault.vault.azure.net/secrets/mysecret)") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ref.VaultName != "myvault" { + t.Errorf("VaultName = %q, want %q", ref.VaultName, "myvault") + } + if ref.SecretName != "mysecret" { + t.Errorf("SecretName = %q, want %q", ref.SecretName, "mysecret") + } +} + +// --- IsSecretReference --- + +func TestIsSecretReference_Comprehensive(t *testing.T) { + t.Parallel() + + tests := []struct { + input string + want bool + }{ + {"akvs://sub/vault/secret", true}, + {"@Microsoft.KeyVault(SecretUri=https://v.vault.azure.net/secrets/s)", true}, + {"@microsoft.keyvault(secreturi=https://v.vault.azure.net/secrets/s)", true}, + {"@Microsoft.KeyVault(VaultName=v;SecretName=s)", false}, + {"plain-value", false}, + {"", false}, + } + + for _, tt := range tests { + if got := IsSecretReference(tt.input); got != tt.want { + t.Errorf("IsSecretReference(%q) = %v, want %v", tt.input, got, tt.want) + } + } +}