Skip to content

Commit 3f6375a

Browse files
fix(auth): enforce remote manual auth state
1 parent 438da4b commit 3f6375a

5 files changed

Lines changed: 240 additions & 13 deletions

File tree

internal/cmd/auth.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -484,8 +484,8 @@ type AuthAddCmd struct {
484484
Manual bool `name:"manual" help:"Browserless auth flow (paste redirect URL)"`
485485
Remote bool `name:"remote" help:"Remote/server-friendly manual flow (print URL, then exchange code)"`
486486
Step int `name:"step" help:"Remote auth step: 1=print URL, 2=exchange code"`
487-
AuthURL string `name:"auth-url" help:"Redirect URL from browser (manual flow)"`
488-
AuthCode string `name:"auth-code" help:"Authorization code from browser (manual flow; skips state check)"`
487+
AuthURL string `name:"auth-url" help:"Redirect URL from browser (manual flow; required for --remote --step 2)"`
488+
AuthCode string `name:"auth-code" help:"Authorization code from browser (manual flow; skips state check; not valid with --remote)"`
489489
Timeout time.Duration `name:"timeout" help:"Authorization timeout (manual flows default to 5m)"`
490490
ForceConsent bool `name:"force-consent" help:"Force consent screen to obtain a refresh token"`
491491
ServicesCSV string `name:"services" help:"Services to authorize: user|all or comma-separated ${auth_services} (Keep uses service account: gog auth service-account set)" default:"user"`
@@ -567,12 +567,15 @@ func (c *AuthAddCmd) Run(ctx context.Context) error {
567567
}
568568
u.Out().Printf("auth_url\t%s", result.URL)
569569
u.Out().Printf("state_reused\t%t", result.StateReused)
570-
u.Err().Println("Run again with --remote --step 2 --auth-url <redirect-url> (or --auth-code <code>)")
570+
u.Err().Println("Run again with --remote --step 2 --auth-url <redirect-url>")
571571
return nil
572572
case 2:
573573
if authURL == "" && authCode == "" {
574574
return usage("remote step 2 requires --auth-url or --auth-code")
575575
}
576+
if authCode != "" {
577+
return usage("remote step 2 requires --auth-url (state check is mandatory)")
578+
}
576579
}
577580
}
578581

@@ -595,6 +598,7 @@ func (c *AuthAddCmd) Run(ctx context.Context) error {
595598
Client: client,
596599
AuthURL: authURL,
597600
AuthCode: authCode,
601+
RequireState: c.Remote,
598602
})
599603
if err != nil {
600604
return err

internal/cmd/auth_add_test.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,81 @@ func TestAuthAddCmd_RemoteStep1_PrintsAuthURL(t *testing.T) {
528528
}
529529
}
530530

531+
func TestAuthAddCmd_RemoteStep2_RejectsAuthCode(t *testing.T) {
532+
err := Execute([]string{
533+
"auth",
534+
"add",
535+
"user@example.com",
536+
"--services",
537+
"gmail",
538+
"--remote",
539+
"--step",
540+
"2",
541+
"--auth-code",
542+
"abc123",
543+
})
544+
if err == nil {
545+
t.Fatalf("expected error")
546+
}
547+
var ee *ExitError
548+
if !errors.As(err, &ee) || ee.Code != 2 {
549+
t.Fatalf("expected exit code 2, got %T %#v", err, err)
550+
}
551+
if !strings.Contains(err.Error(), "remote step 2 requires --auth-url") {
552+
t.Fatalf("unexpected error: %v", err)
553+
}
554+
}
555+
556+
func TestAuthAddCmd_RemoteStep2_PassesAuthURL(t *testing.T) {
557+
origAuth := authorizeGoogle
558+
origOpen := openSecretsStore
559+
origKeychain := ensureKeychainAccess
560+
origFetch := fetchAuthorizedEmail
561+
t.Cleanup(func() {
562+
authorizeGoogle = origAuth
563+
openSecretsStore = origOpen
564+
ensureKeychainAccess = origKeychain
565+
fetchAuthorizedEmail = origFetch
566+
})
567+
568+
ensureKeychainAccess = func() error { return nil }
569+
openSecretsStore = func() (secrets.Store, error) { return newMemSecretsStore(), nil }
570+
571+
var gotOpts googleauth.AuthorizeOptions
572+
authorizeGoogle = func(ctx context.Context, opts googleauth.AuthorizeOptions) (string, error) {
573+
gotOpts = opts
574+
return "rt", nil
575+
}
576+
fetchAuthorizedEmail = func(context.Context, string, string, []string, time.Duration) (string, error) {
577+
return "user@example.com", nil
578+
}
579+
580+
if err := Execute([]string{
581+
"auth",
582+
"add",
583+
"user@example.com",
584+
"--services",
585+
"gmail",
586+
"--remote",
587+
"--step",
588+
"2",
589+
"--auth-url",
590+
"http://localhost:1/?code=abc&state=state123",
591+
}); err != nil {
592+
t.Fatalf("Execute: %v", err)
593+
}
594+
595+
if !gotOpts.Manual {
596+
t.Fatalf("expected manual auth in remote step 2")
597+
}
598+
if !gotOpts.RequireState {
599+
t.Fatalf("expected require state in remote step 2")
600+
}
601+
if gotOpts.AuthURL == "" {
602+
t.Fatalf("expected auth URL to be passed through")
603+
}
604+
}
605+
531606
func TestAuthAddCmd_AuthCode_PassesThrough(t *testing.T) {
532607
origAuth := authorizeGoogle
533608
origOpen := openSecretsStore

internal/googleauth/manual_state.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,40 @@ func loadManualState(client string, scopes []string, forceConsent bool) (string,
7373
return st.State, true, nil
7474
}
7575

76+
func loadManualStateStrict(client string, scopes []string, forceConsent bool) (string, error) {
77+
path, err := manualStatePathFn()
78+
if err != nil {
79+
return "", err
80+
}
81+
82+
data, err := os.ReadFile(path) //nolint:gosec // config path
83+
if err != nil {
84+
if os.IsNotExist(err) {
85+
return "", errManualStateMissing
86+
}
87+
return "", fmt.Errorf("read manual auth state: %w", err)
88+
}
89+
90+
var st manualState
91+
if err := json.Unmarshal(data, &st); err != nil {
92+
_ = os.Remove(path)
93+
return "", errManualStateMissing
94+
}
95+
if st.State == "" {
96+
_ = os.Remove(path)
97+
return "", errManualStateMissing
98+
}
99+
if manualStateNowFn().Sub(st.CreatedAt) > manualStateTTL {
100+
_ = os.Remove(path)
101+
return "", errManualStateMissing
102+
}
103+
if st.Client != client || st.ForceConsent != forceConsent || !scopesEqual(st.Scopes, scopes) {
104+
return "", errManualStateMismatch
105+
}
106+
107+
return st.State, nil
108+
}
109+
76110
func saveManualState(client string, scopes []string, forceConsent bool, state string) error {
77111
path, err := manualStatePathFn()
78112
if err != nil {

internal/googleauth/oauth_flow.go

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ type AuthorizeOptions struct {
3131
Client string
3232
AuthCode string
3333
AuthURL string
34+
RequireState bool
3435
}
3536

3637
type ManualAuthURLResult struct {
@@ -58,12 +59,15 @@ var (
5859
)
5960

6061
var (
61-
errAuthorization = errors.New("authorization error")
62-
errMissingCode = errors.New("missing code")
63-
errMissingScopes = errors.New("missing scopes")
64-
errNoCodeInURL = errors.New("no code found in URL")
65-
errNoRefreshToken = errors.New("no refresh token received; try again with --force-consent")
66-
errStateMismatch = errors.New("state mismatch")
62+
errAuthorization = errors.New("authorization error")
63+
errMissingCode = errors.New("missing code")
64+
errMissingState = errors.New("missing state in redirect URL")
65+
errMissingScopes = errors.New("missing scopes")
66+
errNoCodeInURL = errors.New("no code found in URL")
67+
errNoRefreshToken = errors.New("no refresh token received; try again with --force-consent")
68+
errManualStateMissing = errors.New("manual auth state missing; run remote step 1 again")
69+
errManualStateMismatch = errors.New("manual auth state mismatch; run remote step 1 again")
70+
errStateMismatch = errors.New("state mismatch")
6771
)
6872

6973
func Authorize(ctx context.Context, opts AuthorizeOptions) (string, error) {
@@ -106,16 +110,29 @@ func Authorize(ctx context.Context, opts AuthorizeOptions) (string, error) {
106110
if parseErr != nil {
107111
return "", parseErr
108112
}
113+
if opts.RequireState && gotState == "" {
114+
return "", errMissingState
115+
}
109116
}
110117
if strings.TrimSpace(code) == "" {
111118
return "", errMissingCode
112119
}
113120

114121
if gotState != "" {
115-
if cachedState, ok, cacheErr := loadManualState(opts.Client, opts.Scopes, opts.ForceConsent); cacheErr != nil {
116-
return "", cacheErr
117-
} else if ok && gotState != cachedState {
118-
return "", errStateMismatch
122+
if opts.RequireState {
123+
cachedState, cacheErr := loadManualStateStrict(opts.Client, opts.Scopes, opts.ForceConsent)
124+
if cacheErr != nil {
125+
return "", cacheErr
126+
}
127+
if gotState != cachedState {
128+
return "", errManualStateMismatch
129+
}
130+
} else {
131+
if cachedState, ok, cacheErr := loadManualState(opts.Client, opts.Scopes, opts.ForceConsent); cacheErr != nil {
132+
return "", cacheErr
133+
} else if ok && gotState != cachedState {
134+
return "", errStateMismatch
135+
}
119136
}
120137
}
121138

internal/googleauth/oauth_flow_authorize_test.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,103 @@ func TestAuthorize_Manual_AuthCode(t *testing.T) {
314314
}
315315
}
316316

317+
func TestAuthorize_Manual_AuthURL_RequireStateMissing(t *testing.T) {
318+
origRead := readClientCredentials
319+
origEndpoint := oauthEndpoint
320+
321+
t.Cleanup(func() {
322+
readClientCredentials = origRead
323+
oauthEndpoint = origEndpoint
324+
})
325+
useTempManualStatePath(t)
326+
327+
readClientCredentials = func(string) (config.ClientCredentials, error) {
328+
return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil
329+
}
330+
oauthEndpoint = oauth2EndpointForTest("http://example.com")
331+
332+
_, err := Authorize(context.Background(), AuthorizeOptions{
333+
Scopes: []string{"s1"},
334+
Manual: true,
335+
AuthURL: "http://localhost:1/?code=abc",
336+
RequireState: true,
337+
Client: "default",
338+
Timeout: 2 * time.Second,
339+
})
340+
if err == nil {
341+
t.Fatalf("expected error")
342+
}
343+
if !errors.Is(err, errMissingState) {
344+
t.Fatalf("expected missing state error, got: %v", err)
345+
}
346+
}
347+
348+
func TestAuthorize_Manual_AuthURL_RequireStateMissingCache(t *testing.T) {
349+
origRead := readClientCredentials
350+
origEndpoint := oauthEndpoint
351+
352+
t.Cleanup(func() {
353+
readClientCredentials = origRead
354+
oauthEndpoint = origEndpoint
355+
})
356+
useTempManualStatePath(t)
357+
358+
readClientCredentials = func(string) (config.ClientCredentials, error) {
359+
return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil
360+
}
361+
oauthEndpoint = oauth2EndpointForTest("http://example.com")
362+
363+
_, err := Authorize(context.Background(), AuthorizeOptions{
364+
Scopes: []string{"s1"},
365+
Manual: true,
366+
AuthURL: "http://localhost:1/?code=abc&state=state123",
367+
RequireState: true,
368+
Client: "default",
369+
Timeout: 2 * time.Second,
370+
})
371+
if err == nil {
372+
t.Fatalf("expected error")
373+
}
374+
if !errors.Is(err, errManualStateMissing) {
375+
t.Fatalf("expected manual state missing error, got: %v", err)
376+
}
377+
}
378+
379+
func TestAuthorize_Manual_AuthURL_RequireStateMismatch(t *testing.T) {
380+
origRead := readClientCredentials
381+
origEndpoint := oauthEndpoint
382+
383+
t.Cleanup(func() {
384+
readClientCredentials = origRead
385+
oauthEndpoint = origEndpoint
386+
})
387+
useTempManualStatePath(t)
388+
389+
readClientCredentials = func(string) (config.ClientCredentials, error) {
390+
return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil
391+
}
392+
oauthEndpoint = oauth2EndpointForTest("http://example.com")
393+
394+
if err := saveManualState("default", []string{"s1"}, false, "state123"); err != nil {
395+
t.Fatalf("save manual state: %v", err)
396+
}
397+
398+
_, err := Authorize(context.Background(), AuthorizeOptions{
399+
Scopes: []string{"s1"},
400+
Manual: true,
401+
AuthURL: "http://localhost:1/?code=abc&state=DIFFERENT",
402+
RequireState: true,
403+
Client: "default",
404+
Timeout: 2 * time.Second,
405+
})
406+
if err == nil {
407+
t.Fatalf("expected error")
408+
}
409+
if !errors.Is(err, errManualStateMismatch) {
410+
t.Fatalf("expected manual state mismatch error, got: %v", err)
411+
}
412+
}
413+
317414
func TestAuthorize_ServerFlow_Success(t *testing.T) {
318415
origRead := readClientCredentials
319416
origEndpoint := oauthEndpoint

0 commit comments

Comments
 (0)