Skip to content

Commit 2fb05f8

Browse files
committed
fix(api): Google OAuth provider
1 parent 0b83f5c commit 2fb05f8

6 files changed

Lines changed: 338 additions & 23 deletions

File tree

api_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,10 @@ func TestServerWrapHandler(t *testing.T) {
286286
t.Errorf("Expected WWW-Authenticate header with error, got: %s", authHeader)
287287
}
288288

289+
if !strings.Contains(authHeader, "resource_metadata=") {
290+
t.Errorf("Expected WWW-Authenticate header to include resource_metadata, got: %s", authHeader)
291+
}
292+
289293
if !strings.Contains(w.Body.String(), "invalid_token") {
290294
t.Errorf("Expected JSON error response, got: %s", w.Body.String())
291295
}

handlers.go

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -610,37 +610,57 @@ func (h *OAuth2Handler) HandleToken(w http.ResponseWriter, r *http.Request) {
610610

611611
h.logger.Info("OAuth2: Token exchange successful")
612612

613-
// Build response
613+
response := h.buildTokenResponse(token)
614+
615+
// Send response
616+
w.Header().Set("Content-Type", "application/json")
617+
w.Header().Set("Cache-Control", "no-store")
618+
w.Header().Set("Pragma", "no-cache")
619+
w.WriteHeader(http.StatusOK)
620+
621+
if err := json.NewEncoder(w).Encode(response); err != nil {
622+
h.logger.Error("OAuth2: Failed to encode token response: %v", err)
623+
}
624+
}
625+
626+
// buildTokenResponse builds an RFC 6749-style token response with provider-specific behavior.
627+
func (h *OAuth2Handler) buildTokenResponse(token *oauth2.Token) map[string]interface{} {
628+
accessToken := token.AccessToken
629+
630+
if h.config != nil && h.config.Provider == "google" {
631+
if idToken, ok := token.Extra("id_token").(string); ok && idToken != "" {
632+
if !looksLikeJWT(token.AccessToken) {
633+
if h.logger != nil {
634+
h.logger.Info("OAuth2: Google provider detected opaque access token, using id_token as access_token for downstream compatibility")
635+
}
636+
accessToken = idToken
637+
}
638+
}
639+
}
640+
614641
response := map[string]interface{}{
615-
"access_token": token.AccessToken,
642+
"access_token": accessToken,
616643
"token_type": token.TokenType,
617644
"expires_in": int(time.Until(token.Expiry).Seconds()),
618645
}
619646

620-
// Add optional fields
621647
if token.RefreshToken != "" {
622648
response["refresh_token"] = token.RefreshToken
623649
}
624650

625-
// Add ID token if present
626651
if idToken, ok := token.Extra("id_token").(string); ok {
627652
response["id_token"] = idToken
628653
}
629654

630-
// Add scope if present
631655
if scope, ok := token.Extra("scope").(string); ok {
632656
response["scope"] = scope
633657
}
634658

635-
// Send response
636-
w.Header().Set("Content-Type", "application/json")
637-
w.Header().Set("Cache-Control", "no-store")
638-
w.Header().Set("Pragma", "no-cache")
639-
w.WriteHeader(http.StatusOK)
659+
return response
660+
}
640661

641-
if err := json.NewEncoder(w).Encode(response); err != nil {
642-
h.logger.Error("OAuth2: Failed to encode token response: %v", err)
643-
}
662+
func looksLikeJWT(token string) bool {
663+
return strings.Count(token, ".") == 2
644664
}
645665

646666
// showSuccessPage displays a success page after OAuth completion

handlers_token_response_test.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package oauth
2+
3+
import (
4+
"testing"
5+
"time"
6+
7+
"golang.org/x/oauth2"
8+
)
9+
10+
func TestBuildTokenResponseGoogleOpaqueAccessTokenUsesIDToken(t *testing.T) {
11+
handler := &OAuth2Handler{
12+
config: &OAuth2Config{Provider: "google"},
13+
logger: &defaultLogger{},
14+
}
15+
16+
idToken := "header.payload.signature"
17+
token := (&oauth2.Token{
18+
AccessToken: "ya29.a0ARW5m7Opaque",
19+
TokenType: "Bearer",
20+
RefreshToken: "refresh-token",
21+
Expiry: time.Now().Add(time.Hour),
22+
}).WithExtra(map[string]interface{}{
23+
"id_token": idToken,
24+
"scope": "openid profile email",
25+
})
26+
27+
response := handler.buildTokenResponse(token)
28+
29+
if response["access_token"] != idToken {
30+
t.Fatalf("expected access_token to be mapped to id_token for Google opaque token")
31+
}
32+
33+
if response["id_token"] != idToken {
34+
t.Fatalf("expected id_token in response")
35+
}
36+
37+
if response["refresh_token"] != "refresh-token" {
38+
t.Fatalf("expected refresh_token in response")
39+
}
40+
}
41+
42+
func TestBuildTokenResponseGoogleJWTAccessTokenPreserved(t *testing.T) {
43+
handler := &OAuth2Handler{
44+
config: &OAuth2Config{Provider: "google"},
45+
logger: &defaultLogger{},
46+
}
47+
48+
jwtAccessToken := "jwt.access.token"
49+
token := (&oauth2.Token{
50+
AccessToken: jwtAccessToken,
51+
TokenType: "Bearer",
52+
Expiry: time.Now().Add(time.Hour),
53+
}).WithExtra(map[string]interface{}{
54+
"id_token": "id.token.value",
55+
})
56+
57+
response := handler.buildTokenResponse(token)
58+
59+
if response["access_token"] != jwtAccessToken {
60+
t.Fatalf("expected JWT access_token to remain unchanged")
61+
}
62+
}
63+
64+
func TestBuildTokenResponseNonGoogleAccessTokenPreserved(t *testing.T) {
65+
handler := &OAuth2Handler{
66+
config: &OAuth2Config{Provider: "okta"},
67+
logger: &defaultLogger{},
68+
}
69+
70+
token := (&oauth2.Token{
71+
AccessToken: "opaque-token-value",
72+
TokenType: "Bearer",
73+
Expiry: time.Now().Add(time.Hour),
74+
}).WithExtra(map[string]interface{}{
75+
"id_token": "id.token.value",
76+
})
77+
78+
response := handler.buildTokenResponse(token)
79+
80+
if response["access_token"] != "opaque-token-value" {
81+
t.Fatalf("expected non-Google access_token to remain unchanged")
82+
}
83+
}

oauth.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -263,11 +263,13 @@ func (s *Server) WrapHandler(next http.Handler) http.Handler {
263263
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
264264
authHeader := r.Header.Get("Authorization")
265265
if authHeader == "" || len(authHeader) < 7 || authHeader[:7] != "Bearer " {
266-
s.logger.Info("OAuth: No bearer token provided, returning 401 with discovery info")
266+
s.logger.Debug("OAuth: No bearer token provided, returning 401 with discovery info")
267267

268268
metadataURL := s.GetProtectedResourceMetadataURL()
269-
w.Header().Add("WWW-Authenticate", `Bearer realm="OAuth", error="invalid_token", error_description="Missing or invalid access token"`)
270-
w.Header().Add("WWW-Authenticate", fmt.Sprintf(`resource_metadata="%s"`, metadataURL))
269+
w.Header().Set("WWW-Authenticate", fmt.Sprintf(
270+
`Bearer realm="OAuth", error="invalid_token", error_description="Missing or invalid access token", resource_metadata="%s"`,
271+
metadataURL,
272+
))
271273
w.Header().Set("Content-Type", "application/json")
272274
w.WriteHeader(http.StatusUnauthorized)
273275

@@ -287,8 +289,10 @@ func (s *Server) WrapHandler(next http.Handler) http.Handler {
287289
s.logger.Info("OAuth: Token validation failed: %v", err)
288290

289291
metadataURL := s.GetProtectedResourceMetadataURL()
290-
w.Header().Add("WWW-Authenticate", `Bearer realm="OAuth", error="invalid_token", error_description="Authentication failed"`)
291-
w.Header().Add("WWW-Authenticate", fmt.Sprintf(`resource_metadata="%s"`, metadataURL))
292+
w.Header().Set("WWW-Authenticate", fmt.Sprintf(
293+
`Bearer realm="OAuth", error="invalid_token", error_description="Authentication failed", resource_metadata="%s"`,
294+
metadataURL,
295+
))
292296
w.Header().Set("Content-Type", "application/json")
293297
w.WriteHeader(http.StatusUnauthorized)
294298

provider/provider.go

Lines changed: 95 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@ package provider
33
import (
44
"context"
55
"crypto/tls"
6+
"encoding/json"
67
"fmt"
8+
"io"
79
"net/http"
10+
"net/url"
811
"strings"
912
"sync"
1013
"time"
@@ -54,13 +57,17 @@ type HMACValidator struct {
5457

5558
// OIDCValidator validates JWT tokens using OIDC/JWKS (Okta, Google, Azure)
5659
type OIDCValidator struct {
57-
verifier *oidc.IDTokenVerifier
58-
provider *oidc.Provider
59-
audience string
60-
TokenValidators []func(claims jwt.MapClaims) error
61-
logger Logger
60+
verifier *oidc.IDTokenVerifier
61+
provider *oidc.Provider
62+
audience string
63+
providerName string
64+
skipAudienceCheck bool
65+
TokenValidators []func(claims jwt.MapClaims) error
66+
logger Logger
6267
}
6368

69+
var googleTokenInfoURL = "https://oauth2.googleapis.com/tokeninfo"
70+
6471
// Initialize sets up the HMAC validator with JWT secret and audience
6572
func (v *HMACValidator) Initialize(cfg *Config) error {
6673
v.secretOnce.Do(func() {
@@ -173,6 +180,8 @@ func (v *OIDCValidator) Initialize(cfg *Config) error {
173180
v.logger = &noOpLogger{}
174181
}
175182
v.audience = cfg.Audience
183+
v.providerName = strings.ToLower(cfg.Provider)
184+
v.skipAudienceCheck = cfg.SkipAudienceCheck
176185

177186
// Use standard library context with timeout
178187
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
@@ -226,14 +235,22 @@ func (v *OIDCValidator) Initialize(cfg *Config) error {
226235
func (v *OIDCValidator) ValidateToken(ctx context.Context, tokenString string) (*User, error) {
227236
// Remove Bearer prefix if present
228237
tokenString = strings.TrimPrefix(tokenString, "Bearer ")
238+
tokenString = strings.TrimSpace(tokenString)
229239

230240
// Use incoming context with timeout for OIDC provider call
231241
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
232242
defer cancel()
233243

244+
if v.providerName == "google" && !looksLikeJWT(tokenString) {
245+
return v.validateGoogleOpaqueToken(ctx, tokenString)
246+
}
247+
234248
// go-oidc handles RSA signature validation, JWKS fetching, and key rotation
235249
idToken, err := v.verifier.Verify(ctx, tokenString)
236250
if err != nil {
251+
if v.providerName == "google" && isMalformedJWTError(err) {
252+
return v.validateGoogleOpaqueToken(ctx, tokenString)
253+
}
237254
return nil, fmt.Errorf("token verification failed: %w", err)
238255
}
239256

@@ -277,6 +294,79 @@ func (v *OIDCValidator) ValidateToken(ctx context.Context, tokenString string) (
277294
}, nil
278295
}
279296

297+
func (v *OIDCValidator) validateGoogleOpaqueToken(ctx context.Context, tokenString string) (*User, error) {
298+
endpoint := fmt.Sprintf("%s?access_token=%s", googleTokenInfoURL, url.QueryEscape(tokenString))
299+
300+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
301+
if err != nil {
302+
return nil, fmt.Errorf("failed to create google tokeninfo request: %w", err)
303+
}
304+
305+
client := &http.Client{Timeout: 10 * time.Second}
306+
resp, err := client.Do(req)
307+
if err != nil {
308+
return nil, fmt.Errorf("google tokeninfo request failed: %w", err)
309+
}
310+
defer func() { _ = resp.Body.Close() }()
311+
312+
body, err := io.ReadAll(resp.Body)
313+
if err != nil {
314+
return nil, fmt.Errorf("failed reading google tokeninfo response: %w", err)
315+
}
316+
317+
if resp.StatusCode != http.StatusOK {
318+
return nil, fmt.Errorf("google tokeninfo validation failed: status %d", resp.StatusCode)
319+
}
320+
321+
var claims map[string]interface{}
322+
if err := json.Unmarshal(body, &claims); err != nil {
323+
return nil, fmt.Errorf("failed parsing google tokeninfo response: %w", err)
324+
}
325+
326+
return v.userFromGoogleTokenInfoClaims(claims)
327+
}
328+
329+
func (v *OIDCValidator) userFromGoogleTokenInfoClaims(claims map[string]interface{}) (*User, error) {
330+
aud, _ := claims["aud"].(string)
331+
if !v.skipAudienceCheck {
332+
if aud == "" {
333+
return nil, fmt.Errorf("missing audience claim")
334+
}
335+
if aud != v.audience {
336+
return nil, fmt.Errorf("invalid audience: expected %s, got %s", v.audience, aud)
337+
}
338+
}
339+
340+
subject, _ := claims["sub"].(string)
341+
if subject == "" {
342+
return nil, fmt.Errorf("missing subject in token")
343+
}
344+
345+
email, _ := claims["email"].(string)
346+
username := email
347+
if username == "" {
348+
username = subject
349+
}
350+
351+
return &User{
352+
Subject: subject,
353+
Username: username,
354+
Email: email,
355+
}, nil
356+
}
357+
358+
func looksLikeJWT(token string) bool {
359+
return strings.Count(token, ".") == 2
360+
}
361+
362+
func isMalformedJWTError(err error) bool {
363+
if err == nil {
364+
return false
365+
}
366+
msg := err.Error()
367+
return strings.Contains(msg, "malformed jwt") || strings.Contains(msg, "compact JWS format must have three parts")
368+
}
369+
280370
// validateAudience validates the audience claim matches the expected value for OIDC tokens
281371
func (v *OIDCValidator) validateAudience(claims jwt.MapClaims) error {
282372
// Extract audience claim (can be string or []string)

0 commit comments

Comments
 (0)