Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions cli/azd/extensions/azure.ai.agents/internal/cmd/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"net/http"
"net/url"
"os"
osExec "os/exec"
"path/filepath"
"strings"
"time"
Expand Down Expand Up @@ -129,6 +130,77 @@ func checkAiModelServiceAvailable(ctx context.Context, azdClient *azdext.AzdClie
return nil
}

// ensureLoggedIn verifies that the user is authenticated before any file-modifying
// operations take place.
//
// We need to parse the JSON output of `azd auth status --output json` because the
// Workflow API's Run method returns EmptyResponse and does not expose command output,
// and `azd auth status` always exits 0 regardless of authentication state — it reports
// the result in its output, not via its exit code or a gRPC error.
// If the Workflow API is extended to return structured command results in the future,
// this subprocess workaround can be replaced with a Workflow API call.
//
// getAuthStatusJSON is the function that runs the command and returns stdout. Production
// callers pass authStatusFromCLI; tests inject a stub.
func ensureLoggedIn(ctx context.Context, getAuthStatusJSON func(ctx context.Context) ([]byte, error)) error {
out, err := getAuthStatusJSON(ctx)

// Context cancellation / deadline always takes priority.
if ctx.Err() != nil {
return ctx.Err()
}

// Try to parse whatever output we got, even if the command returned a non-zero
// exit code (ExitError). azd auth status writes JSON to stdout regardless of
// exit code, so the output may still be usable.
if len(out) > 0 {
authStatus, parseErr := parseAuthStatusJSON(out)
if parseErr == nil {
if authStatus == "unauthenticated" {
return exterrors.Auth(
exterrors.CodeNotLoggedIn,
"not logged in",
"run `azd auth login` to authenticate before running init",
)
}

if authStatus == "authenticated" {
return nil
}

// Unrecognized status value — fall through to best-effort skip.
}
}

// No usable output. If the command itself failed, log and skip so unrelated
// issues (azd not in PATH, network blips) don't block init.
if err != nil {
log.Printf("auth status check skipped: %v", err)
}

return nil
}

// authStatusFromCLI runs `azd auth status --output json --no-prompt` as a subprocess
// and returns the raw stdout bytes.
func authStatusFromCLI(ctx context.Context) ([]byte, error) {
return osExec.CommandContext(ctx, "azd", "auth", "status", "--output", "json", "--no-prompt").Output()
}

// parseAuthStatusJSON extracts the "status" field from `azd auth status --output json`.
func parseAuthStatusJSON(data []byte) (string, error) {
var result struct {
Status string `json:"status"`
}
if err := json.Unmarshal(data, &result); err != nil {
return "", fmt.Errorf("unmarshal auth status: %w", err)
}
if result.Status == "" {
return "", fmt.Errorf("missing \"status\" field in auth status output")
}
return result.Status, nil
}

// runInitFromManifest sets up Azure context, credentials, console, and runs the
// InitAction for a given manifest pointer. This is the shared code path used when
// initializing from a manifest URL/path (the -m flag, agent template, or azd template
Expand Down Expand Up @@ -245,6 +317,10 @@ func newInitCommand(rootFlags *rootFlagsDefinition) *cobra.Command {
return fmt.Errorf("failed waiting for debugger: %w", err)
}

if err := ensureLoggedIn(ctx, authStatusFromCLI); err != nil {
return err
}

var httpClient = &http.Client{
Timeout: 30 * time.Second,
}
Expand Down
167 changes: 167 additions & 0 deletions cli/azd/extensions/azure.ai.agents/internal/cmd/init_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1204,3 +1204,170 @@ func TestResolveCollisions_NoPrompt(t *testing.T) {
})
}
}

func TestEnsureLoggedIn(t *testing.T) {
t.Parallel()

tests := []struct {
name string
output []byte
runErr error
wantErr bool
wantCode string
wantMsg string
}{
{
name: "authenticated returns nil",
output: []byte(`{"status":"authenticated","type":"user","email":"user@example.com"}`),
wantErr: false,
},
{
name: "unauthenticated returns structured auth error",
output: []byte(`{"status":"unauthenticated"}`),
wantErr: true,
wantCode: exterrors.CodeNotLoggedIn,
wantMsg: "not logged in",
},
{
name: "unauthenticated with non-zero exit still detected",
output: []byte(`{"status":"unauthenticated"}`),
runErr: errors.New("exit status 1"),
wantErr: true,
wantCode: exterrors.CodeNotLoggedIn,
wantMsg: "not logged in",
},
{
name: "command failure with no output is skipped",
output: nil,
runErr: errors.New("exec: azd not found"),
wantErr: false,
},
{
name: "malformed JSON is skipped",
output: []byte(`not-json`),
wantErr: false,
},
{
name: "empty status field is skipped",
output: []byte(`{"status":""}`),
wantErr: false,
},
{
name: "missing status field is skipped",
output: []byte(`{"email":"user@example.com"}`),
wantErr: false,
},
{
name: "unrecognised status value is skipped",
output: []byte(`{"status":"unknown-value"}`),
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

stub := func(_ context.Context) ([]byte, error) {
return tt.output, tt.runErr
}

err := ensureLoggedIn(t.Context(), stub)

if !tt.wantErr {
if err != nil {
t.Fatalf("expected nil error, got: %v", err)
}
return
}

if err == nil {
t.Fatal("expected an error, got nil")
}

var localErr *azdext.LocalError
if !errors.As(err, &localErr) {
t.Fatalf("expected *azdext.LocalError, got %T: %v", err, err)
}
if localErr.Code != tt.wantCode {
t.Errorf("code = %q, want %q", localErr.Code, tt.wantCode)
}
if tt.wantMsg != "" && !strings.Contains(localErr.Message, tt.wantMsg) {
t.Errorf("message = %q, want it to contain %q", localErr.Message, tt.wantMsg)
}
})
}
}

func TestEnsureLoggedIn_ContextCancelled(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(t.Context())
cancel() // cancel immediately

stub := func(_ context.Context) ([]byte, error) {
return nil, ctx.Err()
}

err := ensureLoggedIn(ctx, stub)
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context.Canceled, got: %v", err)
}
}

func TestParseAuthStatusJSON(t *testing.T) {
t.Parallel()

tests := []struct {
name string
data []byte
want string
wantErr bool
}{
{
name: "authenticated",
data: []byte(`{"status":"authenticated","type":"user","email":"a@b.com"}`),
want: "authenticated",
},
{
name: "unauthenticated",
data: []byte(`{"status":"unauthenticated"}`),
want: "unauthenticated",
},
{
name: "invalid JSON",
data: []byte(`not json`),
wantErr: true,
},
{
name: "missing status",
data: []byte(`{"email":"a@b.com"}`),
wantErr: true,
},
{
name: "empty status",
data: []byte(`{"status":""}`),
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

got, err := parseAuthStatusJSON(tt.data)
if tt.wantErr {
if err == nil {
t.Fatal("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != tt.want {
t.Errorf("got %q, want %q", got, tt.want)
}
})
}
}
Loading