Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 28 additions & 10 deletions server/refreshhandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,15 +232,16 @@ func (s *Server) updateOfflineSession(ctx context.Context, refresh *storage.Refr
return nil
}

// updateRefreshToken updates refresh token and offline session in the storage
// updateRefreshToken updates refresh token and offline session in the storage.
// Connector refresh is guarded by a per-refresh-ID mutex so only one concurrent
// caller hits the IdP.
func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) (*internal.RefreshToken, connector.Identity, *refreshError) {
var rerr *refreshError

newToken := &internal.RefreshToken{
Token: rCtx.requestToken.Token,
RefreshId: rCtx.requestToken.RefreshId,
}

lastUsed := s.now()

ident := connector.Identity{
Expand All @@ -250,6 +251,31 @@ func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) (
Email: rCtx.storageToken.Claims.Email,
EmailVerified: rCtx.storageToken.Claims.EmailVerified,
Groups: rCtx.storageToken.Claims.Groups,
ConnectorData: rCtx.connectorData,
}

rotationEnabled := s.refreshTokenPolicy.RotationEnabled()
reusingAllowed := s.refreshTokenPolicy.AllowedToReuse(rCtx.storageToken.LastUsed)
needConnectorRefresh := rotationEnabled && !reusingAllowed

if needConnectorRefresh {
// Serialize concurrent refreshes for the same refresh ID.
lock := s.getRefreshLock(rCtx.storageToken.ID)
lock.Lock()
s.logger.Debug("Acquired refresh lock", "refreshID", rCtx.storageToken.ID)
defer func() {
lock.Unlock()
s.logger.Debug("Released refresh lock", "refreshID", rCtx.storageToken.ID)
}()

// Double-check if another goroutine already refreshed while we waited:
if !s.refreshTokenPolicy.AllowedToReuse(rCtx.storageToken.LastUsed) {
var rerr *refreshError
ident, rerr = s.refreshWithConnector(ctx, rCtx, ident)
if rerr != nil {
return nil, ident, rerr
}
}
}

refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
Expand Down Expand Up @@ -293,14 +319,6 @@ func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) (
// ConnectorData has been moved to OfflineSession
old.ConnectorData = nil

// Call only once if there is a request which is not in the reuse interval.
// This is required to avoid multiple calls to the external IdP for concurrent requests.
// Dex will call the connector's Refresh method only once if request is not in reuse interval.
ident, rerr = s.refreshWithConnector(ctx, rCtx, ident)
if rerr != nil {
return old, rerr
}

// Update the claims of the refresh token.
//
// UserID intentionally ignored for now.
Expand Down
8 changes: 8 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ type Server struct {
deviceRequestsValidFor time.Duration

refreshTokenPolicy *RefreshTokenPolicy
// mutex to refresh the same token only once for concurrent requests
refreshLocks sync.Map
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does it work when Dex is running in HA mode, e.g., two instances of Dex are serving clients at the same time? For the Kubernetes storage, we did locking on the Kubernetes side, but an in-memory lock doesn't help there.


logger *slog.Logger
}
Expand Down Expand Up @@ -758,6 +760,12 @@ func (s *Server) getConnector(ctx context.Context, id string) (Connector, error)
return conn, nil
}

// getRefreshLock returns the mutex associated with refreshID, creating and
// registering a fresh one on first use so all concurrent callers for the
// same refresh ID share a single lock.
//
// NOTE(review): entries are never removed from refreshLocks, so the map
// grows with the number of distinct refresh IDs ever seen — confirm whether
// cleanup on token rotation/revocation is needed.
func (s *Server) getRefreshLock(refreshID string) *sync.Mutex {
	lock, _ := s.refreshLocks.LoadOrStore(refreshID, new(sync.Mutex))
	return lock.(*sync.Mutex)
}

type logRequestKey string

const (
Expand Down