Skip to content

Commit 9596f45

Browse files
Dylanactions-user
authored andcommitted
Move all auth code to a common utility package (#897)
Video description: https://drive.google.com/file/d/1C-jso8ZOhHZJ5Qw7C1NbRJTx__7BO9RY/view?usp=sharing GitOrigin-RevId: 1f6bf0cd0537d00a147e47a7c9098bd0f33f103b
1 parent 7cb21b7 commit 9596f45

19 files changed

Lines changed: 259 additions & 466 deletions
Lines changed: 49 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,14 @@ func WithImpersonateAccount(account string) TokenSourceOptionsFunc {
9090
// Cache this between invocations to avoid additional charges by Auth0 for M2M
9191
// tokens. The oAuthTokenURL looks like this:
9292
// https://somedomain.auth0.com/oauth/token
93-
func (flowConfig ClientCredentialsConfig) TokenSource(oAuthTokenURL, oAuthAudience string, opts ...TokenSourceOptionsFunc) oauth2.TokenSource {
94-
ctx := context.Background()
93+
//
94+
// The context that is passed to this function is used when getting new tokens,
95+
// which will happen initially, and then subsequently when the token expires.
96+
// This means that if this token source is going to be stored and used for many
97+
// requests, it should not use the context of the request that created it, as
98+
// this will be cancelled. Instead it should probably use `context.Background()`
99+
// or similar.
100+
func (flowConfig ClientCredentialsConfig) TokenSource(ctx context.Context, oAuthTokenURL, oAuthAudience string, opts ...TokenSourceOptionsFunc) oauth2.TokenSource {
95101
// inject otel into oauth2
96102
ctx = context.WithValue(ctx, oauth2.HTTPClient, otelhttp.DefaultClient)
97103

@@ -116,47 +122,6 @@ func (flowConfig ClientCredentialsConfig) TokenSource(oAuthTokenURL, oAuthAudien
116122
return conf.TokenSource(ctx)
117123
}
118124

119-
// NewOAuthTokenClient creates a token client that uses the provided TokenSource
120-
// to get a NATS token. `overmindAPIURL` is the root URL of the NATS token
121-
// exchange API that will be used e.g. https://api.server.test/v1
122-
//
123-
// Tokens will be minted under the specified account as long as the client has
124-
// admin permissions, if not, the account that is attached to the client via
125-
// Auth0 metadata will be used
126-
func NewOAuthTokenClient(overmindAPIURL string, account string, ts oauth2.TokenSource) *natsTokenClient {
127-
return NewOAuthTokenClientWithContext(context.Background(), overmindAPIURL, account, ts)
128-
}
129-
130-
// NewOAuthTokenClientWithContext creates a token client that uses the provided
131-
// TokenSource to get a NATS token. `overmindAPIURL` is the root URL of the NATS
132-
// token exchange API that will be used e.g. https://api.server.test/v1
133-
//
134-
// Tokens will be minted under the specified account as long as the client has
135-
// admin permissions, if not, the account that is attached to the client via
136-
// Auth0 metadata will be used
137-
//
138-
// The provided context is used for cancellation and to lookup the HTTP client
139-
// used by oauth2. See the oauth2.HTTPClient variable.
140-
//
141-
// Provide an account name and an admin token to create a token client for a
142-
// foreign account.
143-
func NewOAuthTokenClientWithContext(ctx context.Context, overmindAPIURL string, account string, ts oauth2.TokenSource) *natsTokenClient {
144-
authenticatedClient := oauth2.NewClient(ctx, ts)
145-
146-
// backwards compatibility: remove previously existing "/api" suffix from URL for connect
147-
apiUrl, err := url.Parse(overmindAPIURL)
148-
if err == nil {
149-
apiUrl.Path = ""
150-
overmindAPIURL = apiUrl.String()
151-
}
152-
153-
return &natsTokenClient{
154-
Account: account,
155-
adminClient: sdpconnect.NewAdminServiceClient(authenticatedClient, overmindAPIURL),
156-
mgmtClient: sdpconnect.NewManagementServiceClient(authenticatedClient, overmindAPIURL),
157-
}
158-
}
159-
160125
// natsTokenClient A client that is capable of getting NATS JWTs and signing the
161126
// required nonce to prove ownership of the NKeys. Satisfies the `TokenClient`
162127
// interface
@@ -414,3 +379,44 @@ func NewStaticTokenClient(overmindAPIURL, token, tokenType string) (*natsTokenCl
414379
mgmtClient: sdpconnect.NewManagementServiceClient(&httpClient, overmindAPIURL),
415380
}, nil
416381
}
382+
383+
// NewOAuthTokenClient creates a token client that uses the provided TokenSource
384+
// to get a NATS token. `overmindAPIURL` is the root URL of the NATS token
385+
// exchange API that will be used e.g. https://api.server.test/v1
386+
//
387+
// Tokens will be minted under the specified account as long as the client has
388+
// admin permissions, if not, the account that is attached to the client via
389+
// Auth0 metadata will be used
390+
func NewOAuthTokenClient(overmindAPIURL string, account string, ts oauth2.TokenSource) *natsTokenClient {
391+
return NewOAuthTokenClientWithContext(context.Background(), overmindAPIURL, account, ts)
392+
}
393+
394+
// NewOAuthTokenClientWithContext creates a token client that uses the provided
395+
// TokenSource to get a NATS token. `overmindAPIURL` is the root URL of the NATS
396+
// token exchange API that will be used e.g. https://api.server.test/v1
397+
//
398+
// Tokens will be minted under the specified account as long as the client has
399+
// admin permissions, if not, the account that is attached to the client via
400+
// Auth0 metadata will be used
401+
//
402+
// The provided context is used for cancellation and to lookup the HTTP client
403+
// used by oauth2. See the oauth2.HTTPClient variable.
404+
//
405+
// Provide an account name and an admin token to create a token client for a
406+
// foreign account.
407+
func NewOAuthTokenClientWithContext(ctx context.Context, overmindAPIURL string, account string, ts oauth2.TokenSource) *natsTokenClient {
408+
authenticatedClient := oauth2.NewClient(ctx, ts)
409+
410+
// backwards compatibility: remove previously existing "/api" suffix from URL for connect
411+
apiUrl, err := url.Parse(overmindAPIURL)
412+
if err == nil {
413+
apiUrl.Path = ""
414+
overmindAPIURL = apiUrl.String()
415+
}
416+
417+
return &natsTokenClient{
418+
Account: account,
419+
adminClient: sdpconnect.NewAdminServiceClient(authenticatedClient, overmindAPIURL),
420+
mgmtClient: sdpconnect.NewManagementServiceClient(authenticatedClient, overmindAPIURL),
421+
}
422+
}
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"fmt"
66
"net/http"
77

8-
"github.com/overmindtech/cli/sdp-go"
98
"github.com/overmindtech/cli/sdp-go/sdpconnect"
109
log "github.com/sirupsen/logrus"
1110
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
@@ -37,7 +36,7 @@ func (y *AuthenticatedTransport) RoundTrip(req *http.Request) (*http.Response, e
3736
// NewAuthenticatedClient creates a new AuthenticatedClient from the given
3837
// context and http.Client.
3938
func NewAuthenticatedClient(ctx context.Context, from *http.Client) *http.Client {
40-
token, ok := ctx.Value(sdp.UserTokenContextKey{}).(string)
39+
token, ok := ctx.Value(UserTokenContextKey{}).(string)
4140
if !ok {
4241
token = ""
4342
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ func GetTestOAuthTokenClient(t *testing.T) *natsTokenClient {
9797
return NewOAuthTokenClient(
9898
exchangeURL,
9999
"overmind-development",
100-
flowConfig.TokenSource(fmt.Sprintf("https://%v/oauth/token", domain), os.Getenv("API_SERVER_AUDIENCE")),
100+
flowConfig.TokenSource(t.Context(), fmt.Sprintf("https://%v/oauth/token", domain), os.Getenv("API_SERVER_AUDIENCE")),
101101
)
102102
}
103103

Lines changed: 89 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package sdp
1+
package auth
22

33
import (
44
"context"
@@ -18,11 +18,11 @@ import (
1818
"go.opentelemetry.io/otel/trace"
1919
)
2020

21-
// AuthBypassedContextKey is a key that is stored in the request context when auth is
22-
// actively being bypassed, e.g. in development. When this is set the
23-
// `HasScopes()` function will always return true, and can be set using the
24-
// `BypassAuth()` middleware.
25-
type AuthBypassedContextKey struct{}
21+
// ScopeCheckBypassedContextKey is a key that is stored in the request context
22+
// when scope checking is actively being bypassed, e.g. in development. When
23+
// this is set the `HasScopes()` function will always return true, and can be
24+
// set using the `WithBypassScopeCheck()` middleware.
25+
type ScopeCheckBypassedContextKey struct{}
2626

2727
// CustomClaimsContextKey is the key that is used to store the custom claims
2828
// from the JWT
@@ -84,7 +84,7 @@ func HasAllScopes(ctx context.Context, requiredScopes ...string) bool {
8484
attribute.StringSlice("ovm.auth.requiredScopes.all", requiredScopes),
8585
)
8686

87-
if ctx.Value(AuthBypassedContextKey{}) == true {
87+
if ctx.Value(ScopeCheckBypassedContextKey{}) == true {
8888
// this is always set when auth is bypassed
8989
// set it here again to capture non-standard auth configs
9090
span.SetAttributes(attribute.Bool("ovm.auth.bypass", true))
@@ -116,7 +116,7 @@ func HasAnyScopes(ctx context.Context, requiredScopes ...string) bool {
116116
attribute.StringSlice("ovm.auth.requiredScopes.any", requiredScopes),
117117
)
118118

119-
if ctx.Value(AuthBypassedContextKey{}) == true {
119+
if ctx.Value(ScopeCheckBypassedContextKey{}) == true {
120120
// this is always set when auth is bypassed
121121
// set it here again to capture non-standard auth configs
122122
span.SetAttributes(attribute.Bool("ovm.auth.bypass", true))
@@ -169,7 +169,20 @@ func ExtractAccount(ctx context.Context) (string, error) {
169169
// must also be set.
170170
func NewAuthMiddleware(config AuthConfig, next http.Handler) http.Handler {
171171
processOverrides := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
172-
ctx := OverrideCustomClaims(r.Context(), config.ScopeOverride, config.AccountOverride)
172+
options := []OverrideAuthOptionFunc{}
173+
174+
if config.ScopeOverride != nil {
175+
options = append(options, WithScope(*config.ScopeOverride))
176+
}
177+
178+
if config.AccountOverride != nil {
179+
options = append(options, WithAccount(*config.AccountOverride))
180+
}
181+
182+
ctx := r.Context()
183+
if len(options) > 0 {
184+
ctx = OverrideAuth(r.Context(), options...)
185+
}
173186

174187
r = r.Clone(ctx)
175188

@@ -179,51 +192,83 @@ func NewAuthMiddleware(config AuthConfig, next http.Handler) http.Handler {
179192
return ensureValidTokenHandler(config, processOverrides)
180193
}
181194

182-
// AddBypassAuthConfig Adds the requires keys to the context so that
183-
// authentication is bypassed. This is intended to be used in tests
184-
func AddBypassAuthConfig(ctx context.Context) context.Context {
185-
return context.WithValue(ctx, AuthBypassedContextKey{}, true)
186-
}
195+
type OverrideAuthOptionFunc func(ctx context.Context) context.Context
187196

188-
// OverrideAuthContext overrides the authentication data and token stored in the context.
189-
// This is mostly useful for testing or delegating access locally into a protected API.
190-
func OverrideAuthContext(ctx context.Context, claims *validator.ValidatedClaims) context.Context {
191-
customClaims := claims.CustomClaims.(*CustomClaims)
192-
ctx = context.WithValue(ctx, jwtmiddleware.ContextKey{}, claims)
193-
ctx = context.WithValue(ctx, CustomClaimsContextKey{}, customClaims)
194-
ctx = context.WithValue(ctx, CurrentSubjectContextKey{}, claims.RegisteredClaims.Subject)
195-
ctx = context.WithValue(ctx, AccountNameContextKey{}, customClaims.AccountName)
196-
return ctx
197+
// Sets the scope in the context to the given value. This should be the value
198+
// that would be embedded directly in the token, with each scope being separated
199+
// by a space.
200+
func WithScope(scope string) OverrideAuthOptionFunc {
201+
return withCustomClaims(func(claims *CustomClaims) {
202+
claims.Scope = scope
203+
})
197204
}
198205

199-
// OverrideCustomClaims Overrides the custom claims in the context that have
200-
// been set at CustomClaimsContextKey
201-
func OverrideCustomClaims(ctx context.Context, scope *string, account *string) context.Context {
202-
// Read existing claims from the context
203-
i := ctx.Value(CustomClaimsContextKey{})
204-
205-
var claims *CustomClaims
206-
var newClaims CustomClaims
207-
var ok bool
206+
// Sets the account in the context to the given value.
207+
func WithAccount(account string) OverrideAuthOptionFunc {
208+
return withCustomClaims(func(claims *CustomClaims) {
209+
claims.AccountName = account
210+
})
211+
}
208212

209-
if claims, ok = i.(*CustomClaims); ok {
210-
// clone out the values to avoid false sharing
211-
newClaims = *claims
213+
// Sets the auth info in the context directly from the validated claims produced
214+
// by the `github.com/auth0/go-jwt-middleware/v2/validator` package. This is
215+
// essentially what the middleware already does when receiving a request, and
216+
// therefore should only be used in exceptional circumstances, like testing, when the
217+
// middleware is not being used.
218+
//
219+
// If this is being used, there is no need to use the `WithScope` or `WithAccount`
220+
// options as the claims will be extracted directly from the validated claims.
221+
func WithValidatedClaims(claims *validator.ValidatedClaims) OverrideAuthOptionFunc {
222+
return func(ctx context.Context) context.Context {
223+
customClaims := claims.CustomClaims.(*CustomClaims)
224+
ctx = context.WithValue(ctx, jwtmiddleware.ContextKey{}, claims)
225+
ctx = context.WithValue(ctx, CustomClaimsContextKey{}, customClaims)
226+
ctx = context.WithValue(ctx, CurrentSubjectContextKey{}, claims.RegisteredClaims.Subject)
227+
ctx = context.WithValue(ctx, AccountNameContextKey{}, customClaims.AccountName)
228+
return ctx
212229
}
230+
}
213231

214-
if scope != nil {
215-
newClaims.Scope = *scope
232+
// Bypasses the scope check, meaning that `HasScopes()` and `HasAllScopes` will
233+
// always return true. This is useful for testing.
234+
func WithBypassScopeCheck() OverrideAuthOptionFunc {
235+
return func(ctx context.Context) context.Context {
236+
return context.WithValue(ctx, ScopeCheckBypassedContextKey{}, true)
216237
}
238+
}
217239

218-
if account != nil {
219-
newClaims.AccountName = *account
240+
// Overrides the authentication that is currently stored in the context. This
241+
// can only be used within a single process, and doesn't mean that the overrides
242+
// set here will be passed on if you are using `NewAuthenticatedClient` to pass
243+
// through auth. It is however useful for testing, or for calling other handlers
244+
// within the same process.
245+
func OverrideAuth(ctx context.Context, opts ...OverrideAuthOptionFunc) context.Context {
246+
for _, opt := range opts {
247+
ctx = opt(ctx)
220248
}
249+
return ctx
250+
}
221251

222-
// Store the new claims in the context
223-
ctx = context.WithValue(ctx, CustomClaimsContextKey{}, &newClaims)
224-
ctx = context.WithValue(ctx, AccountNameContextKey{}, newClaims.AccountName)
252+
func withCustomClaims(modify func(*CustomClaims)) OverrideAuthOptionFunc {
253+
return func(ctx context.Context) context.Context {
254+
i := ctx.Value(CustomClaimsContextKey{})
255+
var claims *CustomClaims
256+
var newClaims CustomClaims
257+
var ok bool
225258

226-
return ctx
259+
if claims, ok = i.(*CustomClaims); ok {
260+
// clone out the values to avoid sharing
261+
newClaims = *claims
262+
}
263+
264+
modify(&newClaims)
265+
266+
// Store the new claims in the context
267+
ctx = context.WithValue(ctx, CustomClaimsContextKey{}, &newClaims)
268+
ctx = context.WithValue(ctx, AccountNameContextKey{}, newClaims.AccountName)
269+
270+
return ctx
271+
}
227272
}
228273

229274
// ensureValidTokenHandler is a middleware that will check the validity of our
@@ -375,7 +420,7 @@ func ensureValidTokenHandler(config AuthConfig, next http.Handler) http.Handler
375420
span.SetAttributes(attribute.Bool("ovm.auth.bypass", shouldBypass))
376421

377422
if shouldBypass {
378-
ctx = AddBypassAuthConfig(ctx)
423+
ctx = OverrideAuth(ctx, WithBypassScopeCheck())
379424

380425
r = r.Clone(ctx)
381426

0 commit comments

Comments
 (0)