Skip to content

Commit c16dfe1

Browse files
committed
auth: handle token audience mismatch
1 parent b5dcdae commit c16dfe1

2 files changed

Lines changed: 283 additions & 22 deletions

File tree

internal/commands/auth.go

Lines changed: 138 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package commands
33
import (
44
"bufio"
55
"context"
6+
"encoding/base64"
67
"encoding/json"
78
"errors"
89
"fmt"
@@ -290,10 +291,14 @@ func runDeviceFlow(ctx context.Context, opts deviceFlowOpts) error {
290291
}
291292

292293
// exchangeForCellaToken trades an auth-issued user JWT for a cella
293-
// bearer token. The CLI mints a short-TTL actor token at auth, then
294-
// exchanges it at cella's /v1/tokens/exchange. Either step may fail
295-
// (older auth without /actor-tokens, older cella without the token
296-
// catalog); the caller falls back to the auth token in that case.
294+
// bearer token. The preferred path mints a short-TTL actor token at
295+
// auth, then exchanges it at cella's /v1/tokens/exchange.
296+
//
297+
// Some deployed auth versions stamp device-code tokens with sandboxd's
298+
// audience, then reject those same tokens on /actor-tokens because the
299+
// auth middleware expects the auth issuer as audience. In that case the
300+
// device token is still accepted by sandboxd, so use it directly for the
301+
// cella exchange instead of persisting the short-lived auth token.
297302
func exchangeForCellaToken(ctx context.Context, opts deviceFlowOpts, authToken string) (string, error) {
298303
authBase := opts.AuthURL
299304
if authBase == "" {
@@ -324,6 +329,9 @@ func exchangeForCellaToken(ctx context.Context, opts deviceFlowOpts, authToken s
324329
defer func() { _ = resp.Body.Close() }()
325330
if resp.StatusCode/100 != 2 {
326331
b, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<14))
332+
if resp.StatusCode == http.StatusUnauthorized && strings.Contains(string(b), "audience mismatch") {
333+
return exchangeAtCella(ctx, httpc, apiBase, authToken)
334+
}
327335
return "", fmt.Errorf("actor-tokens %d: %s", resp.StatusCode, b)
328336
}
329337
var actor struct {
@@ -334,17 +342,21 @@ func exchangeForCellaToken(ctx context.Context, opts deviceFlowOpts, authToken s
334342
}
335343

336344
// 2. Exchange the actor token at cella.
345+
return exchangeAtCella(ctx, httpc, apiBase, actor.ActorToken)
346+
}
347+
348+
func exchangeAtCella(ctx context.Context, httpc *http.Client, apiBase, bearer string) (string, error) {
337349
hostname, _ := os.Hostname()
338350
if hostname == "" {
339351
hostname = "CLI"
340352
}
341-
body, _ = json.Marshal(map[string]any{"label": "CLI on " + hostname})
353+
body, _ := json.Marshal(map[string]any{"label": "CLI on " + hostname})
342354
req2, err := http.NewRequestWithContext(ctx, http.MethodPost, apiBase+"/v1/tokens/exchange", strings.NewReader(string(body)))
343355
if err != nil {
344356
return "", err
345357
}
346358
req2.Header.Set("Content-Type", "application/json")
347-
req2.Header.Set("Authorization", "Bearer "+actor.ActorToken)
359+
req2.Header.Set("Authorization", "Bearer "+bearer)
348360
resp2, err := httpc.Do(req2)
349361
if err != nil {
350362
return "", err
@@ -410,33 +422,137 @@ func newAuthWhoamiCmd() *cobra.Command {
410422
Scopes []string `json:"scopes"`
411423
ClientID string `json:"client_id,omitempty"`
412424
}
413-
if err := req.GetJSON(cmd.Context(), "/tokeninfo", &info); err != nil {
414-
return err
415-
}
416-
fmt.Printf("sub: %s\n", info.Sub)
417-
if info.Email != nil && *info.Email != "" {
418-
fmt.Printf("email: %s\n", *info.Email)
419-
}
420-
fmt.Printf("principal: %s\n", info.PrincipalType)
421-
if info.OrgID != nil && *info.OrgID != "" {
422-
fmt.Printf("context: org\n")
423-
fmt.Printf("org_id: %s\n", *info.OrgID)
425+
if err := req.GetJSON(cmd.Context(), "/tokeninfo", &info); err == nil {
426+
printPrincipal(principalInfo{
427+
Sub: info.Sub,
428+
Email: deref(info.Email),
429+
PrincipalType: info.PrincipalType,
430+
OrgID: deref(info.OrgID),
431+
Scopes: info.Scopes,
432+
ClientID: info.ClientID,
433+
})
434+
return nil
424435
} else {
425-
fmt.Printf("context: personal\n")
436+
var apiErr *api.APIError
437+
if !errors.As(err, &apiErr) || apiErr.Status != http.StatusUnauthorized {
438+
return err
439+
}
426440
}
427-
if info.ClientID != "" {
428-
fmt.Printf("client_id: %s\n", info.ClientID)
441+
442+
// Auth cannot introspect cella-issued tokens, and current
443+
// auth deployments also reject sandbox-audience device
444+
// tokens on /tokeninfo. Confirm sandboxd accepts the bearer,
445+
// then print the identity claims embedded in the JWT.
446+
var ignored any
447+
if err := c.GetJSON(cmd.Context(), "/v1/sandboxes", &ignored); err != nil {
448+
return err
429449
}
430-
if len(info.Scopes) > 0 {
431-
fmt.Printf("scopes: %s\n", strings.Join(info.Scopes, " "))
450+
local, err := principalFromJWT(c.Token)
451+
if err != nil {
452+
return err
432453
}
454+
printPrincipal(local)
433455
return nil
434456
},
435457
}
436458
cmd.Flags().StringVar(&apiURL, "api-url", "", "override sandboxd base URL")
437459
return cmd
438460
}
439461

462+
type principalInfo struct {
463+
Sub string
464+
Email string
465+
PrincipalType string
466+
OrgID string
467+
Scopes []string
468+
ClientID string
469+
}
470+
471+
func printPrincipal(info principalInfo) {
472+
fmt.Printf("sub: %s\n", info.Sub)
473+
if info.Email != "" {
474+
fmt.Printf("email: %s\n", info.Email)
475+
}
476+
fmt.Printf("principal: %s\n", info.PrincipalType)
477+
if info.OrgID != "" {
478+
fmt.Printf("context: org\n")
479+
fmt.Printf("org_id: %s\n", info.OrgID)
480+
} else {
481+
fmt.Printf("context: personal\n")
482+
}
483+
if info.ClientID != "" {
484+
fmt.Printf("client_id: %s\n", info.ClientID)
485+
}
486+
if len(info.Scopes) > 0 {
487+
fmt.Printf("scopes: %s\n", strings.Join(info.Scopes, " "))
488+
}
489+
}
490+
491+
func deref(s *string) string {
492+
if s == nil {
493+
return ""
494+
}
495+
return *s
496+
}
497+
498+
func principalFromJWT(raw string) (principalInfo, error) {
499+
parts := strings.Split(raw, ".")
500+
if len(parts) < 2 {
501+
return principalInfo{}, errors.New("saved token is not a JWT")
502+
}
503+
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
504+
if err != nil {
505+
return principalInfo{}, fmt.Errorf("decode token payload: %w", err)
506+
}
507+
var claims map[string]any
508+
if err := json.Unmarshal(payload, &claims); err != nil {
509+
return principalInfo{}, fmt.Errorf("parse token payload: %w", err)
510+
}
511+
info := principalInfo{
512+
Sub: stringClaim(claims, "sub"),
513+
Email: stringClaim(claims, "email"),
514+
PrincipalType: stringClaim(claims, "principal_type"),
515+
OrgID: stringClaim(claims, "org_id"),
516+
Scopes: scopesClaim(claims),
517+
ClientID: stringClaim(claims, "client_id"),
518+
}
519+
if info.Sub == "" {
520+
return principalInfo{}, errors.New("saved token is missing sub")
521+
}
522+
if info.PrincipalType == "" {
523+
info.PrincipalType = "user"
524+
}
525+
return info, nil
526+
}
527+
528+
func stringClaim(claims map[string]any, key string) string {
529+
v, _ := claims[key].(string)
530+
return v
531+
}
532+
533+
func scopesClaim(claims map[string]any) []string {
534+
if scope, _ := claims["scope"].(string); scope != "" {
535+
return strings.Fields(scope)
536+
}
537+
raw, ok := claims["scp"]
538+
if !ok {
539+
return nil
540+
}
541+
switch v := raw.(type) {
542+
case string:
543+
return strings.Fields(v)
544+
case []any:
545+
out := make([]string, 0, len(v))
546+
for _, item := range v {
547+
if s, ok := item.(string); ok && s != "" {
548+
out = append(out, s)
549+
}
550+
}
551+
return out
552+
}
553+
return nil
554+
}
555+
440556
func newAuthLogoutCmd() *cobra.Command {
441557
return &cobra.Command{
442558
Use: "logout",

internal/commands/auth_test.go

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
package commands
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/base64"
7+
"encoding/json"
8+
"io"
9+
"net/http"
10+
"net/http/httptest"
11+
"os"
12+
"path/filepath"
13+
"strings"
14+
"testing"
15+
16+
"github.com/latere-ai/latere-cli/internal/api"
17+
)
18+
19+
func TestExchangeForCellaTokenFallsBackToDirectExchangeOnActorAudienceMismatch(t *testing.T) {
20+
authSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
21+
if r.URL.Path != "/actor-tokens" {
22+
http.NotFound(w, r)
23+
return
24+
}
25+
if got := r.Header.Get("Authorization"); got != "Bearer auth-token" {
26+
t.Errorf("auth Authorization = %q", got)
27+
}
28+
w.WriteHeader(http.StatusUnauthorized)
29+
_, _ = w.Write([]byte(`{"error":"unauthorized","message":"invalid token: audience mismatch"}`))
30+
}))
31+
defer authSrv.Close()
32+
33+
apiSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
34+
if r.URL.Path != "/v1/tokens/exchange" {
35+
http.NotFound(w, r)
36+
return
37+
}
38+
if got := r.Header.Get("Authorization"); got != "Bearer auth-token" {
39+
t.Errorf("cella Authorization = %q", got)
40+
}
41+
var body map[string]any
42+
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
43+
t.Errorf("decode body: %v", err)
44+
}
45+
if label, _ := body["label"].(string); !strings.HasPrefix(label, "CLI on ") {
46+
t.Errorf("label = %q", label)
47+
}
48+
_, _ = w.Write([]byte(`{"access_token":"cella-token"}`))
49+
}))
50+
defer apiSrv.Close()
51+
52+
got, err := exchangeForCellaToken(context.Background(), deviceFlowOpts{
53+
AuthURL: authSrv.URL,
54+
APIURL: apiSrv.URL,
55+
}, "auth-token")
56+
if err != nil {
57+
t.Fatalf("exchangeForCellaToken: %v", err)
58+
}
59+
if got != "cella-token" {
60+
t.Fatalf("token = %q, want cella-token", got)
61+
}
62+
}
63+
64+
func TestAuthWhoamiFallsBackToVerifiedJWTClaims(t *testing.T) {
65+
token := fakeJWT(t, map[string]any{
66+
"sub": "user-123",
67+
"email": "dev@example.com",
68+
"principal_type": "user",
69+
"org_id": "org-456",
70+
"client_id": "latere-cli",
71+
"scp": []string{"read:sandbox", "write:sandbox"},
72+
})
73+
tokenPath := filepath.Join(t.TempDir(), "token.json")
74+
t.Setenv("LATERE_TOKEN_FILE", tokenPath)
75+
if err := api.SaveToken(tokenPath, api.Token{AccessToken: token, TokenType: "Bearer"}); err != nil {
76+
t.Fatalf("SaveToken: %v", err)
77+
}
78+
79+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
80+
if got := r.Header.Get("Authorization"); got != "Bearer "+token {
81+
t.Errorf("Authorization = %q", got)
82+
}
83+
switch r.URL.Path {
84+
case "/tokeninfo":
85+
w.WriteHeader(http.StatusUnauthorized)
86+
_, _ = w.Write([]byte(`{"error":"unauthorized","message":"invalid token: audience mismatch"}`))
87+
case "/v1/sandboxes":
88+
_, _ = w.Write([]byte(`[]`))
89+
default:
90+
http.NotFound(w, r)
91+
}
92+
}))
93+
defer srv.Close()
94+
95+
cmd := newAuthWhoamiCmd()
96+
cmd.SetArgs([]string{"--api-url", srv.URL})
97+
out, err := captureStdout(func() error { return cmd.Execute() })
98+
if err != nil {
99+
t.Fatalf("whoami: %v", err)
100+
}
101+
for _, want := range []string{
102+
"sub: user-123",
103+
"email: dev@example.com",
104+
"principal: user",
105+
"context: org",
106+
"org_id: org-456",
107+
"client_id: latere-cli",
108+
"scopes: read:sandbox write:sandbox",
109+
} {
110+
if !strings.Contains(out, want) {
111+
t.Fatalf("output missing %q:\n%s", want, out)
112+
}
113+
}
114+
}
115+
116+
func fakeJWT(t *testing.T, payload map[string]any) string {
117+
t.Helper()
118+
enc := func(v any) string {
119+
b, err := json.Marshal(v)
120+
if err != nil {
121+
t.Fatalf("marshal JWT part: %v", err)
122+
}
123+
return base64.RawURLEncoding.EncodeToString(b)
124+
}
125+
return enc(map[string]any{"alg": "none"}) + "." + enc(payload) + ".sig"
126+
}
127+
128+
func captureStdout(fn func() error) (string, error) {
129+
orig := os.Stdout
130+
r, w, err := os.Pipe()
131+
if err != nil {
132+
return "", err
133+
}
134+
os.Stdout = w
135+
runErr := fn()
136+
_ = w.Close()
137+
os.Stdout = orig
138+
var buf bytes.Buffer
139+
_, copyErr := io.Copy(&buf, r)
140+
_ = r.Close()
141+
if runErr != nil {
142+
return buf.String(), runErr
143+
}
144+
return buf.String(), copyErr
145+
}

0 commit comments

Comments
 (0)