1- package sdp
1+ package auth
22
33import (
44 "context"
@@ -18,11 +18,11 @@ import (
1818 "go.opentelemetry.io/otel/trace"
1919)
2020
21- // AuthBypassedContextKey is a key that is stored in the request context when auth is
22- // actively being bypassed, e.g. in development. When this is set the
23- // `HasScopes()` function will always return true, and can be set using the
24- // `BypassAuth ()` middleware.
25- type AuthBypassedContextKey struct {}
21+ // ScopeCheckBypassedContextKey is a key that is stored in the request context
22+ // when scope checking is actively being bypassed, e.g. in development. When
23+ // this is set the `HasScopes()` function will always return true, and can be
24+ // set using the `WithBypassScopeCheck ()` middleware.
25+ type ScopeCheckBypassedContextKey struct {}
2626
2727// CustomClaimsContextKey is the key that is used to store the custom claims
2828// from the JWT
@@ -84,7 +84,7 @@ func HasAllScopes(ctx context.Context, requiredScopes ...string) bool {
8484 attribute .StringSlice ("ovm.auth.requiredScopes.all" , requiredScopes ),
8585 )
8686
87- if ctx .Value (AuthBypassedContextKey {}) == true {
87+ if ctx .Value (ScopeCheckBypassedContextKey {}) == true {
8888 // this is always set when auth is bypassed
8989 // set it here again to capture non-standard auth configs
9090 span .SetAttributes (attribute .Bool ("ovm.auth.bypass" , true ))
@@ -116,7 +116,7 @@ func HasAnyScopes(ctx context.Context, requiredScopes ...string) bool {
116116 attribute .StringSlice ("ovm.auth.requiredScopes.any" , requiredScopes ),
117117 )
118118
119- if ctx .Value (AuthBypassedContextKey {}) == true {
119+ if ctx .Value (ScopeCheckBypassedContextKey {}) == true {
120120 // this is always set when auth is bypassed
121121 // set it here again to capture non-standard auth configs
122122 span .SetAttributes (attribute .Bool ("ovm.auth.bypass" , true ))
@@ -169,7 +169,20 @@ func ExtractAccount(ctx context.Context) (string, error) {
169169// must also be set.
170170func NewAuthMiddleware (config AuthConfig , next http.Handler ) http.Handler {
171171 processOverrides := http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
172- ctx := OverrideCustomClaims (r .Context (), config .ScopeOverride , config .AccountOverride )
172+ options := []OverrideAuthOptionFunc {}
173+
174+ if config .ScopeOverride != nil {
175+ options = append (options , WithScope (* config .ScopeOverride ))
176+ }
177+
178+ if config .AccountOverride != nil {
179+ options = append (options , WithAccount (* config .AccountOverride ))
180+ }
181+
182+ ctx := r .Context ()
183+ if len (options ) > 0 {
184+ ctx = OverrideAuth (r .Context (), options ... )
185+ }
173186
174187 r = r .Clone (ctx )
175188
@@ -179,51 +192,83 @@ func NewAuthMiddleware(config AuthConfig, next http.Handler) http.Handler {
179192 return ensureValidTokenHandler (config , processOverrides )
180193}
181194
182- // AddBypassAuthConfig Adds the requires keys to the context so that
183- // authentication is bypassed. This is intended to be used in tests
184- func AddBypassAuthConfig (ctx context.Context ) context.Context {
185- return context .WithValue (ctx , AuthBypassedContextKey {}, true )
186- }
195+ type OverrideAuthOptionFunc func (ctx context.Context ) context.Context
187196
188- // OverrideAuthContext overrides the authentication data and token stored in the context.
189- // This is mostly useful for testing or delegating access locally into a protected API.
190- func OverrideAuthContext (ctx context.Context , claims * validator.ValidatedClaims ) context.Context {
191- customClaims := claims .CustomClaims .(* CustomClaims )
192- ctx = context .WithValue (ctx , jwtmiddleware.ContextKey {}, claims )
193- ctx = context .WithValue (ctx , CustomClaimsContextKey {}, customClaims )
194- ctx = context .WithValue (ctx , CurrentSubjectContextKey {}, claims .RegisteredClaims .Subject )
195- ctx = context .WithValue (ctx , AccountNameContextKey {}, customClaims .AccountName )
196- return ctx
197+ // Sets the scope in the context to the given value. This should be the value
198+ // that would be embedded directly in the token, with each scope being separated
199+ // by a space.
200+ func WithScope (scope string ) OverrideAuthOptionFunc {
201+ return withCustomClaims (func (claims * CustomClaims ) {
202+ claims .Scope = scope
203+ })
197204}
198205
199- // OverrideCustomClaims Overrides the custom claims in the context that have
200- // been set at CustomClaimsContextKey
201- func OverrideCustomClaims (ctx context.Context , scope * string , account * string ) context.Context {
202- // Read existing claims from the context
203- i := ctx .Value (CustomClaimsContextKey {})
204-
205- var claims * CustomClaims
206- var newClaims CustomClaims
207- var ok bool
206+ // Sets the account in the context to the given value.
207+ func WithAccount (account string ) OverrideAuthOptionFunc {
208+ return withCustomClaims (func (claims * CustomClaims ) {
209+ claims .AccountName = account
210+ })
211+ }
208212
209- if claims , ok = i .(* CustomClaims ); ok {
210- // clone out the values to avoid false sharing
211- newClaims = * claims
213+ // Sets the auth info in the context directly from the validated claims produced
214+ // by the `github.com/auth0/go-jwt-middleware/v2/validator` package. This is
215+ // essentially what the middleware already does when receiving a request, and
216+ // therefore should only be used in exceptional circumstances, like testing, when the
217+ // middleware is not being used.
218+ //
219+ // If this is being used, there is no need to use the `WithScope` or `WithAccount`
220+ // options as the claims will be extracted directly from the validated claims.
221+ func WithValidatedClaims (claims * validator.ValidatedClaims ) OverrideAuthOptionFunc {
222+ return func (ctx context.Context ) context.Context {
223+ customClaims := claims .CustomClaims .(* CustomClaims )
224+ ctx = context .WithValue (ctx , jwtmiddleware.ContextKey {}, claims )
225+ ctx = context .WithValue (ctx , CustomClaimsContextKey {}, customClaims )
226+ ctx = context .WithValue (ctx , CurrentSubjectContextKey {}, claims .RegisteredClaims .Subject )
227+ ctx = context .WithValue (ctx , AccountNameContextKey {}, customClaims .AccountName )
228+ return ctx
212229 }
230+ }
213231
214- if scope != nil {
215- newClaims .Scope = * scope
232+ // Bypasses the scope check, meaning that `HasScopes()` and `HasAllScopes` will
233+ // always return true. This is useful for testing.
234+ func WithBypassScopeCheck () OverrideAuthOptionFunc {
235+ return func (ctx context.Context ) context.Context {
236+ return context .WithValue (ctx , ScopeCheckBypassedContextKey {}, true )
216237 }
238+ }
217239
218- if account != nil {
219- newClaims .AccountName = * account
240+ // Overrides the authentication that is currently stored in the context. This
241+ // can only be used within a single process, and doesn't mean that the overrides
242+ // set here will be passed on if you are using `NewAuthenticatedClient` to pass
243+ // through auth. It is however useful for testing, or for calling other handlers
244+ // within the same process.
245+ func OverrideAuth (ctx context.Context , opts ... OverrideAuthOptionFunc ) context.Context {
246+ for _ , opt := range opts {
247+ ctx = opt (ctx )
220248 }
249+ return ctx
250+ }
221251
222- // Store the new claims in the context
223- ctx = context .WithValue (ctx , CustomClaimsContextKey {}, & newClaims )
224- ctx = context .WithValue (ctx , AccountNameContextKey {}, newClaims .AccountName )
252+ func withCustomClaims (modify func (* CustomClaims )) OverrideAuthOptionFunc {
253+ return func (ctx context.Context ) context.Context {
254+ i := ctx .Value (CustomClaimsContextKey {})
255+ var claims * CustomClaims
256+ var newClaims CustomClaims
257+ var ok bool
225258
226- return ctx
259+ if claims , ok = i .(* CustomClaims ); ok {
260+ // clone out the values to avoid sharing
261+ newClaims = * claims
262+ }
263+
264+ modify (& newClaims )
265+
266+ // Store the new claims in the context
267+ ctx = context .WithValue (ctx , CustomClaimsContextKey {}, & newClaims )
268+ ctx = context .WithValue (ctx , AccountNameContextKey {}, newClaims .AccountName )
269+
270+ return ctx
271+ }
227272}
228273
229274// ensureValidTokenHandler is a middleware that will check the validity of our
@@ -375,7 +420,7 @@ func ensureValidTokenHandler(config AuthConfig, next http.Handler) http.Handler
375420 span .SetAttributes (attribute .Bool ("ovm.auth.bypass" , shouldBypass ))
376421
377422 if shouldBypass {
378- ctx = AddBypassAuthConfig (ctx )
423+ ctx = OverrideAuth (ctx , WithBypassScopeCheck () )
379424
380425 r = r .Clone (ctx )
381426
0 commit comments