Skip to content

Commit dcd76c3

Browse files
committed
support refresh token
1 parent c64f389 commit dcd76c3

1 file changed

Lines changed: 61 additions & 37 deletions

File tree

handlers.go

Lines changed: 61 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -524,56 +524,80 @@ func (h *OAuth2Handler) HandleToken(w http.ResponseWriter, r *http.Request) {
524524
// Extract parameters
525525
grantType := r.FormValue("grant_type")
526526
code := r.FormValue("code")
527+
refreshToken := r.FormValue("refresh_token")
527528
clientRedirectURI := r.FormValue("redirect_uri")
528529
clientID := r.FormValue("client_id")
529530
codeVerifier := r.FormValue("code_verifier")
530531

531532
h.logger.Info("OAuth2: Token request - grant_type: %s, client_id: %s, redirect_uri: %s, code: %s",
532533
grantType, clientID, clientRedirectURI, truncateString(code, 10))
533534

534-
// Validate parameters
535-
if code == "" {
536-
h.logger.Error("OAuth2: Missing authorization code")
537-
http.Error(w, "Missing authorization code", http.StatusBadRequest)
538-
return
539-
}
535+
var token *oauth2.Token
536+
var err error
540537

541-
if grantType != "authorization_code" {
542-
h.logger.Error("OAuth2: Unsupported grant type: %s", grantType)
543-
http.Error(w, "Unsupported grant type", http.StatusBadRequest)
544-
return
545-
}
538+
switch grantType {
539+
case "authorization_code":
540+
// Validate parameters for authorization_code flow
541+
if code == "" {
542+
h.logger.Error("OAuth2: Missing authorization code")
543+
http.Error(w, "Missing authorization code", http.StatusBadRequest)
544+
return
545+
}
546546

547-
// Set redirect URI for token exchange
548-
redirectURI := clientRedirectURI
549-
if h.config.RedirectURIs != "" && !strings.Contains(h.config.RedirectURIs, ",") {
550-
redirectURI = strings.TrimSpace(h.config.RedirectURIs)
551-
h.logger.Info("OAuth2: Token exchange using fixed redirect URI: %s", redirectURI)
552-
}
547+
// Set redirect URI for token exchange
548+
redirectURI := clientRedirectURI
549+
if h.config.RedirectURIs != "" && !strings.Contains(h.config.RedirectURIs, ",") {
550+
redirectURI = strings.TrimSpace(h.config.RedirectURIs)
551+
h.logger.Info("OAuth2: Token exchange using fixed redirect URI: %s", redirectURI)
552+
}
553553

554-
h.oauth2Config.RedirectURL = redirectURI
554+
h.oauth2Config.RedirectURL = redirectURI
555555

556-
// For PKCE, we need to manually add the code_verifier to the token exchange
557-
// Since oauth2 library doesn't support PKCE directly, we'll use a custom approach
558-
ctx := context.Background()
559-
560-
// Create custom HTTP client for token exchange with PKCE
561-
if codeVerifier != "" {
562-
// Create a custom client that adds code_verifier to the token request
563-
customClient := &http.Client{
564-
Transport: &pkceTransport{
565-
base: http.DefaultTransport,
566-
codeVerifier: codeVerifier,
567-
},
556+
// For PKCE, we need to manually add the code_verifier to the token exchange
557+
// Since oauth2 library doesn't support PKCE directly, we'll use a custom approach
558+
ctx := context.Background()
559+
560+
// Create custom HTTP client for token exchange with PKCE
561+
if codeVerifier != "" {
562+
// Create a custom client that adds code_verifier to the token request
563+
customClient := &http.Client{
564+
Transport: &pkceTransport{
565+
base: http.DefaultTransport,
566+
codeVerifier: codeVerifier,
567+
},
568+
}
569+
ctx = context.WithValue(ctx, oauth2.HTTPClient, customClient)
568570
}
569-
ctx = context.WithValue(ctx, oauth2.HTTPClient, customClient)
570-
}
571571

572-
// Exchange code for tokens
573-
token, err := h.oauth2Config.Exchange(ctx, code)
574-
if err != nil {
575-
h.logger.Error("OAuth2: Token exchange failed: %v", err)
576-
http.Error(w, "Token exchange failed", http.StatusInternalServerError)
572+
// Exchange code for tokens
573+
token, err = h.oauth2Config.Exchange(ctx, code)
574+
if err != nil {
575+
h.logger.Error("OAuth2: Token exchange failed: %v", err)
576+
http.Error(w, "Token exchange failed", http.StatusInternalServerError)
577+
return
578+
}
579+
case "refresh_token":
580+
// Validate parameters for refresh_token flow
581+
if refreshToken == "" {
582+
h.logger.Error("OAuth2: Missing refresh token")
583+
http.Error(w, "Missing refresh token", http.StatusBadRequest)
584+
return
585+
}
586+
587+
ctx := context.Background()
588+
src := h.oauth2Config.TokenSource(ctx, &oauth2.Token{
589+
RefreshToken: refreshToken,
590+
})
591+
592+
token, err = src.Token()
593+
if err != nil {
594+
h.logger.Error("OAuth2: Refresh token exchange failed: %v", err)
595+
http.Error(w, "Token refresh failed", http.StatusBadGateway)
596+
return
597+
}
598+
default:
599+
h.logger.Error("OAuth2: Unsupported grant type: %s", grantType)
600+
http.Error(w, "Unsupported grant type", http.StatusBadRequest)
577601
return
578602
}
579603

0 commit comments

Comments
 (0)