diff --git a/README.md b/README.md index c212c8d82..5c809a758 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,11 @@ +# Signal Louketo Fork + +Signal AI specific additions to Louketo Proxy + +## Releasing + +To release changes follow the instructions in [the releasing documentation](docs/release.md). This uses GitHub Actions. + # EOL notice Louketo Proxy reached end of line in November 21, 2020. This means that we no longer support, or update it. The details are available [here](https://www.keycloak.org/2020/08/sunsetting-louketo-project.adoc). diff --git a/config.go b/config.go index 080af6ad3..14d6d1474 100644 --- a/config.go +++ b/config.go @@ -43,6 +43,7 @@ func newDefaultConfig() *Config { EnableDefaultDeny: true, EnableSessionCookies: true, EnableTokenHeader: true, + EnableCSRFCheck: true, HTTPOnlyCookie: true, Headers: make(map[string]string), LetsEncryptCacheDir: "./cache/", diff --git a/cookies.go b/cookies.go index bcd7dd249..9e2ef7215 100644 --- a/cookies.go +++ b/cookies.go @@ -17,6 +17,7 @@ package main import ( "encoding/base64" + "encoding/json" "net/http" "strconv" "strings" @@ -119,8 +120,13 @@ func (r *oauthProxy) dropRefreshTokenCookie(req *http.Request, w http.ResponseWr r.dropCookieWithChunks(req, w, r.config.CookieRefreshName, value, duration) } +type StateParameter struct { + Token string `json:"token"` + Url string `json:"url"` +} + // writeStateParameterCookie sets a state parameter cookie into the response -func (r *oauthProxy) writeStateParameterCookie(req *http.Request, w http.ResponseWriter) string { +func (r *oauthProxy) writeStateParameterCookie(req *http.Request, w http.ResponseWriter) (string, error) { uuid, err := uuid.NewV4() if err != nil { w.WriteHeader(http.StatusInternalServerError) @@ -128,7 +134,12 @@ func (r *oauthProxy) writeStateParameterCookie(req *http.Request, w http.Respons requestURI := base64.StdEncoding.EncodeToString([]byte(req.URL.RequestURI())) r.dropCookie(w, req.Host, requestURICookie, requestURI, 0) r.dropCookie(w, req.Host, requestStateCookie, uuid.String(), 0) - return uuid.String() + + stateParam := StateParameter{Token: uuid.String(), + Url: req.URL.RequestURI()} + output, err := json.Marshal(stateParam) + + return string(output), err } // clearAllCookies is just a helper function for the below diff --git a/doc.go b/doc.go index c42c53f0e..1aad5f2a7 100644 --- a/doc.go +++ b/doc.go @@ -252,6 +252,8 @@ type Config struct { LocalhostMetrics bool `json:"localhost-metrics" yaml:"localhost-metrics" usage:"enforces the metrics page can only been requested from 127.0.0.1"` // EnableCompression enables gzip compression for response EnableCompression bool `json:"enable-compression" yaml:"enable-compression" usage:"enable gzip compression for response"` + // EnableCSRFCheck enables CSRF protection between authorise/callback using cookies and state parameter + EnableCSRFCheck bool `json:"enable-csrf-check" yaml:"enable-csrf-check" usage:"enable crsf check between authorise and callback"` // AccessTokenDuration is default duration applied to the access token cookie AccessTokenDuration time.Duration `json:"access-token-duration" yaml:"access-token-duration" usage:"fallback cookie duration for the access token when using refresh tokens"` diff --git a/forwarding.go b/forwarding.go index a9547aba2..1ab13c6e1 100644 --- a/forwarding.go +++ b/forwarding.go @@ -88,8 +88,9 @@ func (r *oauthProxy) proxyMiddleware(next http.Handler) http.Handler { // forwardProxyHandler is responsible for signing outbound requests func (r *oauthProxy) forwardProxyHandler() func(*http.Request, *http.Response) { ctx := context.Background() + fmt.Printf("%+v\n", r.config.RedirectionURL) conf := r.newOAuth2Config(r.config.RedirectionURL) - + fmt.Printf("%+v\n", conf) // the loop state var state struct { // the access token diff --git a/go.mod b/go.mod index 9e0b2a3b7..9d9e70db6 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,7 @@ module github.com/louketo/louketo-proxy require ( + github.com/MicahParks/keyfunc v1.9.0 github.com/PuerkitoBio/purell v1.1.0 github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect github.com/armon/go-proxyproto v0.0.0-20180202201750-5b7edb60ff5f @@ -14,6 +15,7 @@ require ( github.com/garyburd/redigo v1.6.0 // indirect github.com/go-chi/chi v3.3.3+incompatible github.com/gofrs/uuid v3.3.0+incompatible + github.com/golang-jwt/jwt/v4 v4.5.1 github.com/jonboulle/clockwork v0.1.0 // indirect github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348 // indirect github.com/onsi/ginkgo v1.8.0 // indirect diff --git a/go.sum b/go.sum index 437e413c2..ec5f29588 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,6 @@ cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/MicahParks/keyfunc v1.9.0 h1:lhKd5xrFHLNOWrDc4Tyb/Q1AJ4LCzQ48GVJyVIID3+o= +github.com/MicahParks/keyfunc v1.9.0/go.mod h1:IdnCilugA0O/99dW+/MkvlyrsX8+L8+x95xuVNtM5jw= github.com/PuerkitoBio/purell v1.1.0 h1:rmGxhojJlM0tuKtfdvliR84CFHljx9ag64t2xmVkjK4= github.com/PuerkitoBio/purell v1.1.0/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 h1:d+Bc7a5rLufV/sSk/8dngufqelfh6jnri85riMAaF/M= @@ -28,6 +30,9 @@ github.com/go-chi/chi v3.3.3+incompatible h1:KHkmBEMNkwKuK4FdQL7N2wOeB9jnIx7jR5w github.com/go-chi/chi v3.3.3+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ= github.com/gofrs/uuid v3.3.0+incompatible h1:8K4tyRfvU1CYPgJsveYFQMhpFd/wXNM7iK6rR7UHz84= github.com/gofrs/uuid v3.3.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/golang-jwt/jwt/v4 v4.4.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v4 v4.5.1 h1:JdqV9zKUdtaa9gdPlywC3aeoEsR681PlKC+4F5gQgeo= +github.com/golang-jwt/jwt/v4 v4.5.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= @@ -78,7 +83,6 @@ golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8U golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181011144130-49bb7cea24b1/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181201002055-351d144fa1fc h1:a3CU5tJYVj92DY2LaA1kUkrsqD5/3mLDhx2NcNqyW+0= golang.org/x/net v0.0.0-20181201002055-351d144fa1fc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ= @@ -86,7 +90,6 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d h1:TzXSXBo42m9gQenoE3b9BGiEpg5IG2JkU5FkPIawgtw= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f h1:Bl/8QSvNqXvPGPGXa2z5xUTmV7VDcZyvRZ+QQXkXTZQ= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= diff --git a/handlers.go b/handlers.go index 87b55b73b..4f5866bfb 100644 --- a/handlers.go +++ b/handlers.go @@ -58,11 +58,23 @@ func (r *oauthProxy) getRedirectionURL(w http.ResponseWriter, req *http.Request) redirect = r.config.RedirectionURL } - state, _ := req.Cookie(requestStateCookie) - if state != nil && req.URL.Query().Get("state") != state.Value { - r.log.Error("state parameter mismatch") - w.WriteHeader(http.StatusForbidden) - return "" + if r.config.EnableCSRFCheck { + state, _ := req.Cookie(requestStateCookie) + + stateParameter := req.URL.Query().Get("state") + stateParam := StateParameter{} + if stateParameter != "" { + err := json.Unmarshal([]byte(stateParameter), &stateParam) + if err != nil { + r.log.Warn("failed to deserialise state parameter from json") + } + } + + if state != nil && stateParam.Token != state.Value { + r.log.Error("state parameter mismatch") + w.WriteHeader(http.StatusForbidden) + return "" + } } return fmt.Sprintf("%s%s", redirect, r.config.WithOAuthURI("callback")) } @@ -209,8 +221,17 @@ func (r *oauthProxy) oauthCallbackHandler(w http.ResponseWriter, req *http.Reque // step: decode the request variable redirectURI := "/" - if req.URL.Query().Get("state") != "" { - if encodedRequestURI, _ := req.Cookie(requestURICookie); encodedRequestURI != nil { + stateParameter := req.URL.Query().Get("state") + if stateParameter != "" { + stateParam := StateParameter{} + err := json.Unmarshal([]byte(stateParameter), &stateParam) + if err != nil { + r.log.Warn("failed to deserialise state parameter from json") + } + + if stateParam.Url != "" { + redirectURI = stateParam.Url + } else if encodedRequestURI, _ := req.Cookie(requestURICookie); encodedRequestURI != nil { decoded, _ := base64.StdEncoding.DecodeString(encodedRequestURI.Value) redirectURI = string(decoded) } @@ -292,9 +313,9 @@ func (r *oauthProxy) loginHandler(w http.ResponseWriter, req *http.Request) { func emptyHandler(w http.ResponseWriter, req *http.Request) {} // logoutHandler performs a logout -// - if it's just a access token, the cookie is deleted -// - if the user has a refresh token, the token is invalidated by the provider -// - optionally, the user can be redirected by to a url +// - if it's just a access token, the cookie is deleted +// - if the user has a refresh token, the token is invalidated by the provider +// - optionally, the user can be redirected by to a url func (r *oauthProxy) logoutHandler(w http.ResponseWriter, req *http.Request) { // @check if the redirection is there var redirectURL string @@ -309,7 +330,7 @@ func (r *oauthProxy) logoutHandler(w http.ResponseWriter, req *http.Request) { } // @step: drop the access token - user, err := r.getIdentity(req) + user, err := r.getIdentityFromRequest(req) if err != nil { w.WriteHeader(http.StatusBadRequest) return @@ -413,7 +434,7 @@ func (r *oauthProxy) logoutHandler(w http.ResponseWriter, req *http.Request) { // expirationHandler checks if the token has expired func (r *oauthProxy) expirationHandler(w http.ResponseWriter, req *http.Request) { - user, err := r.getIdentity(req) + user, err := r.getIdentityFromRequest(req) if err != nil { w.WriteHeader(http.StatusUnauthorized) return @@ -428,7 +449,7 @@ func (r *oauthProxy) expirationHandler(w http.ResponseWriter, req *http.Request) // tokenHandler display access token to screen func (r *oauthProxy) tokenHandler(w http.ResponseWriter, req *http.Request) { - user, err := r.getIdentity(req) + user, err := r.getIdentityFromRequest(req) if err != nil { w.WriteHeader(http.StatusBadRequest) return diff --git a/handlers_test.go b/handlers_test.go index 60d1723b8..b2513ab80 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -16,7 +16,9 @@ limitations under the License. package main import ( + "encoding/json" "net/http" + "net/url" "testing" "time" ) @@ -282,6 +284,16 @@ func TestAuthorizationURL(t *testing.T) { newFakeProxy(nil).RunTests(t, requests) } +func getStateQueryParameter(t *testing.T, redirectUrl string) string { + stateParameter := StateParameter{Url: redirectUrl} + res, err := json.Marshal(stateParameter) + if err != nil { + t.Fatalf("Failed to marshall state parameter to string: %s", err) + return "" + } + return url.QueryEscape(string(res)) +} + func TestCallbackURL(t *testing.T) { cfg := newFakeKeycloakConfig() requests := []fakeRequest{ @@ -301,13 +313,13 @@ func TestCallbackURL(t *testing.T) { ExpectedCode: http.StatusSeeOther, }, { - URI: cfg.WithOAuthURI(callbackURL) + "?code=fake&state=/admin", + URI: cfg.WithOAuthURI(callbackURL) + "?code=fake&state=" + getStateQueryParameter(t, "/admin?some-param=true&some-other=false"), ExpectedCookies: map[string]string{cfg.CookieAccessName: ""}, - ExpectedLocation: "/", + ExpectedLocation: "/admin?some-param=true&some-other=false", ExpectedCode: http.StatusSeeOther, }, { - URI: cfg.WithOAuthURI(callbackURL) + "?code=fake&state=L2FkbWlu", + URI: cfg.WithOAuthURI(callbackURL) + "?code=fake&state=" + getStateQueryParameter(t, "/"), ExpectedCookies: map[string]string{cfg.CookieAccessName: ""}, ExpectedLocation: "/", ExpectedCode: http.StatusSeeOther, diff --git a/middleware.go b/middleware.go index 3d3bf5aca..1572f6b1d 100644 --- a/middleware.go +++ b/middleware.go @@ -169,7 +169,7 @@ func (r *oauthProxy) authenticationMiddleware() func(http.Handler) http.Handler return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { clientIP := req.RemoteAddr // grab the user identity from the request - user, err := r.getIdentity(req) + user, err := r.getIdentityFromRequest(req) if err != nil { r.log.Error("no session found in request, redirecting for authorization", zap.Error(err)) next.ServeHTTP(w, req.WithContext(r.redirectToAuthorization(w, req))) @@ -315,7 +315,13 @@ func (r *oauthProxy) authenticationMiddleware() func(http.Handler) http.Handler } // update the with the new access token and inject into the context - user.token = token + user, err = r.getIdentityFromToken(token, user.bearerToken) + if err != nil { + r.log.Error("regenerated token is invalid, redirecting for authorization", zap.Error(err)) + next.ServeHTTP(w, req.WithContext(r.redirectToAuthorization(w, req))) + return + } + scope.Identity = user ctx = context.WithValue(req.Context(), contextScopeName, scope) } } diff --git a/middleware_test.go b/middleware_test.go index 774bc629d..0b4cbb0bc 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -16,6 +16,7 @@ limitations under the License. package main import ( + "encoding/json" "fmt" "io/ioutil" "log" @@ -73,6 +74,8 @@ type fakeRequest struct { // advanced test cases ExpectedCookiesValidator map[string]func(string) bool + + ExpectedStateUrl string } type fakeProxy struct { @@ -223,11 +226,24 @@ func (f *fakeProxy) RunTests(t *testing.T, requests []fakeRequest) { if c.ExpectedCode != 0 { assert.Equal(t, c.ExpectedCode, status, "case %d, expected status code: %d, got: %d", i, c.ExpectedCode, status) } - if c.ExpectedLocation != "" { + if c.ExpectedLocation != "" || c.ExpectedStateUrl != "" { l, _ := url.Parse(resp.Header().Get("Location")) - assert.True(t, strings.Contains(l.String(), c.ExpectedLocation), "expected location to contain %s", l.String()) - if l.Query().Get("state") != "" { - state, err := uuid.FromString(l.Query().Get("state")) + if c.ExpectedLocation != "" { + assert.True(t, strings.Contains(l.String(), c.ExpectedLocation), "expected location to contain %s", l.String()) + } + stateStr := l.Query().Get("state") + if stateStr != "" { + stateParam := StateParameter{} + err := json.Unmarshal([]byte(stateStr), &stateParam) + if err != nil { + assert.Fail(t, "failed to deserialise state parameter from json, got: %s with error %s", stateStr, err) + } + + if c.ExpectedStateUrl != "" { + assert.Equal(t, c.ExpectedStateUrl, stateParam.Url, "expected state url to contain %s", stateParam.Url) + } + + state, err := uuid.FromString(stateParam.Token) if err != nil { assert.Fail(t, "expected state parameter with valid UUID, got: %s with error %s", state.String(), err) } diff --git a/misc.go b/misc.go index a03113567..2cc6619f7 100644 --- a/misc.go +++ b/misc.go @@ -19,6 +19,7 @@ import ( "context" "fmt" "net/http" + "net/url" "path" "strings" "time" @@ -97,8 +98,13 @@ func (r *oauthProxy) redirectToAuthorization(w http.ResponseWriter, req *http.Re } // step: add a state referrer to the authorization page - uuid := r.writeStateParameterCookie(req, w) - authQuery := fmt.Sprintf("?state=%s", uuid) + state, err := r.writeStateParameterCookie(req, w) + if err != nil { + r.log.Error("failed to create state parameter") + w.WriteHeader(http.StatusInternalServerError) + return r.revokeProxy(w, req) + } + authQuery := fmt.Sprintf("?state=%s", url.QueryEscape(state)) // step: if verification is switched off, we can't authorization if r.config.SkipTokenVerification { diff --git a/misc_test.go b/misc_test.go index c416bd273..15daac9ca 100644 --- a/misc_test.go +++ b/misc_test.go @@ -33,9 +33,10 @@ func TestRedirectToAuthorizationUnauthorized(t *testing.T) { func TestRedirectToAuthorization(t *testing.T) { requests := []fakeRequest{ { - URI: "/admin", + URI: "/admin?blah=1", Redirects: true, ExpectedLocation: "/oauth/authorize?state", + ExpectedStateUrl: "/admin?blah=1", ExpectedCode: http.StatusSeeOther, }, } diff --git a/oauth.go b/oauth.go index 230b4ab7e..f9847989b 100644 --- a/oauth.go +++ b/oauth.go @@ -28,11 +28,12 @@ import ( "golang.org/x/oauth2" "github.com/coreos/go-oidc/jose" + "github.com/golang-jwt/jwt/v4" "github.com/coreos/go-oidc/oidc" ) -//FIXME remove constants in the future which hopefully won't be necessary in the next releases +// FIXME remove constants in the future which hopefully won't be necessary in the next releases const ( GrantTypeAuthCode = "authorization_code" GrantTypeUserCreds = "password" @@ -67,6 +68,15 @@ func verifyToken(client *oidc.Client, token jose.JWT) error { return nil } +// verifyToken verify that the token in the user context is valid +func verifyTokenSignature(keyfunc jwt.Keyfunc, token string) error { + if _, err := jwt.Parse(token, keyfunc, jwt.WithoutClaimsValidation()); err != nil { + return err + } + + return nil +} + // getRefreshedToken attempts to refresh the access token, returning the parsed token, optionally with a renewed // refresh token and the time the access and refresh tokens expire // diff --git a/server.go b/server.go index b2ae8bbf6..1920728b2 100644 --- a/server.go +++ b/server.go @@ -38,11 +38,13 @@ import ( httplog "log" + "github.com/MicahParks/keyfunc" proxyproto "github.com/armon/go-proxyproto" "github.com/coreos/go-oidc/oidc" "github.com/elazarl/goproxy" "github.com/go-chi/chi" "github.com/go-chi/chi/middleware" + "github.com/golang-jwt/jwt/v4" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/rs/cors" @@ -63,6 +65,7 @@ type oauthProxy struct { store storage templates *template.Template upstream reverseProxy + keyFunc jwt.Keyfunc } func init() { @@ -75,7 +78,7 @@ func init() { prometheus.MustRegister(statusMetric) } -// newProxy create's a new proxy from configuration +// newProxy creates a new proxy from configuration func newProxy(config *Config) (*oauthProxy, error) { // create the service logger log, err := createLogger(config) @@ -126,6 +129,22 @@ func newProxy(config *Config) (*oauthProxy, error) { } } + keyfuncOptions := keyfunc.Options{ + Ctx: context.Background(), + RefreshErrorHandler: func(err error) { + log.Warn(fmt.Sprintf("There was an error with the jwt.Keyfunc\nError: %s", err.Error())) + }, + RefreshInterval: time.Hour, + RefreshRateLimit: time.Minute * 5, + RefreshTimeout: time.Second * 10, + RefreshUnknownKID: true, + } + jwks, err := keyfunc.Get(svc.idp.KeysEndpoint.String(), keyfuncOptions) + if err != nil { + log.Fatal(fmt.Sprintf("Failed to create JWKS from resource at the given URL.\nError: %s", err.Error())) + } + svc.keyFunc = jwks.Keyfunc + return svc, nil } diff --git a/session.go b/session.go index d58ce67b9..4f2caa085 100644 --- a/session.go +++ b/session.go @@ -17,6 +17,7 @@ package main import ( "bytes" + "fmt" "net/http" "strconv" "strings" @@ -25,8 +26,8 @@ import ( "go.uber.org/zap" ) -// getIdentity retrieves the user identity from a request, either from a session cookie or a bearer token -func (r *oauthProxy) getIdentity(req *http.Request) (*userContext, error) { +// getIdentityFromRequest retrieves the user identity from a request, either from a session cookie or a bearer token +func (r *oauthProxy) getIdentityFromRequest(req *http.Request) (*userContext, error) { var isBearer bool // step: check for a bearer token or cookie with jwt token access, isBearer, err := getTokenInRequest(req, r.config.CookieAccessName) @@ -38,15 +39,14 @@ func (r *oauthProxy) getIdentity(req *http.Request) (*userContext, error) { return nil, ErrDecryption } } - token, err := jose.ParseJWT(access) + parsedToken, err := jose.ParseJWT(access) if err != nil { return nil, err } - user, err := extractIdentity(token) + user, err := r.getIdentityFromToken(parsedToken, isBearer) if err != nil { return nil, err } - user.bearerToken = isBearer r.log.Debug("found the user identity", zap.String("id", user.id), @@ -58,6 +58,19 @@ func (r *oauthProxy) getIdentity(req *http.Request) (*userContext, error) { return user, nil } +// getIdentityFromToken retrieves the user identity from a session cookie or a bearer token +func (r *oauthProxy) getIdentityFromToken(token jose.JWT, isBearer bool) (*userContext, error) { + if err := verifyTokenSignature(r.keyFunc, token.Encode()); err != nil { + return nil, fmt.Errorf("failed to verify the token signature: %w", err) + } + user, err := extractIdentity(token) + if err != nil { + return nil, err + } + user.bearerToken = isBearer + return user, nil +} + // getRefreshTokenFromCookie returns the refresh token from the cookie if any func (r *oauthProxy) getRefreshTokenFromCookie(req *http.Request) (string, error) { token, err := getTokenInCookie(req, r.config.CookieRefreshName) diff --git a/session_test.go b/session_test.go index ea91ab697..ee9729185 100644 --- a/session_test.go +++ b/session_test.go @@ -62,7 +62,7 @@ func TestGetIndentity(t *testing.T) { } for i, c := range testCases { - user, err := p.getIdentity(c.Request) + user, err := p.getIdentityFromRequest(c.Request) if err != nil && c.Ok { t.Errorf("test case %d should not have errored", i) continue diff --git a/websocket_test.go b/websocket_test.go index 7da05a97c..ca226a69b 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -26,7 +26,7 @@ import ( "golang.org/x/net/websocket" ) -//TestWebSocket is used to validate that the proxy reverse proxy WebSocket connections. +// TestWebSocket is used to validate that the proxy reverse proxy WebSocket connections. func TestWebSocket(t *testing.T) { // Setup an upstream service. upstream := &fakeUpstreamService{}