Skip to content

Commit 620018f

Browse files
authored
Merge pull request #65 from zalando/access-token-in-header
Access token in header
2 parents 8864d0c + a5e8c9b commit 620018f

File tree

1 file changed

+40
-14
lines changed

1 file changed

+40
-14
lines changed

ginoauth2.go

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,11 @@ type TokenContainer struct {
8787
// access.
8888
type AccessCheckFunction func(tc *TokenContainer, ctx *gin.Context) bool
8989

90+
type Options struct {
91+
Endpoint oauth2.Endpoint
92+
AccessTokenInHeader bool
93+
}
94+
9095
func extractToken(r *http.Request) (*oauth2.Token, error) {
9196
hdr := r.Header.Get("Authorization")
9297
if hdr == "" {
@@ -101,17 +106,26 @@ func extractToken(r *http.Request) (*oauth2.Token, error) {
101106
return &oauth2.Token{AccessToken: th[1], TokenType: th[0]}, nil
102107
}
103108

104-
func RequestAuthInfo(t *oauth2.Token) ([]byte, error) {
105-
var uv = make(url.Values)
106-
// uv.Set("realm", o.Realm)
107-
uv.Set("access_token", t.AccessToken)
108-
infoURL := AuthInfoURL + "?" + uv.Encode()
109+
func requestAuthInfo(o Options, t *oauth2.Token) ([]byte, error) {
110+
var infoURL string
111+
if o.AccessTokenInHeader {
112+
infoURL = AuthInfoURL
113+
} else {
114+
var uv = make(url.Values)
115+
uv.Set("access_token", t.AccessToken)
116+
infoURL = AuthInfoURL + "?" + uv.Encode()
117+
}
118+
109119
client := &http.Client{Transport: &Transport}
110120
req, err := http.NewRequest("GET", infoURL, nil)
111121
if err != nil {
112122
return nil, err
113123
}
114124

125+
if o.AccessTokenInHeader {
126+
req.Header.Set("Authorization", "Bearer " + t.AccessToken)
127+
}
128+
115129
resp, err := client.Do(req)
116130
if err != nil {
117131
return nil, err
@@ -121,6 +135,10 @@ func RequestAuthInfo(t *oauth2.Token) ([]byte, error) {
121135
return ioutil.ReadAll(resp.Body)
122136
}
123137

138+
func RequestAuthInfo(t *oauth2.Token) ([]byte, error) {
139+
return requestAuthInfo(Options{}, t)
140+
}
141+
124142
func ParseTokenContainer(t *oauth2.Token, data map[string]interface{}) (*TokenContainer, error) {
125143
tdata := make(map[string]interface{})
126144

@@ -158,8 +176,8 @@ func ParseTokenContainer(t *oauth2.Token, data map[string]interface{}) (*TokenCo
158176
}, nil
159177
}
160178

161-
func GetTokenContainer(token *oauth2.Token) (*TokenContainer, error) {
162-
body, err := RequestAuthInfo(token)
179+
func getTokenContainerForToken(o Options, token *oauth2.Token) (*TokenContainer, error) {
180+
body, err := requestAuthInfo(o, token)
163181
if err != nil {
164182
glog.Errorf("[Gin-OAuth] RequestAuthInfo failed caused by: %s", err)
165183
return nil, err
@@ -180,7 +198,11 @@ func GetTokenContainer(token *oauth2.Token) (*TokenContainer, error) {
180198
return ParseTokenContainer(token, data)
181199
}
182200

183-
func getTokenContainer(ctx *gin.Context) (*TokenContainer, bool) {
201+
func GetTokenContainer(token *oauth2.Token) (*TokenContainer, error) {
202+
return getTokenContainerForToken(Options{}, token)
203+
}
204+
205+
func getTokenContainer(o Options, ctx *gin.Context) (*TokenContainer, bool) {
184206
var oauthToken *oauth2.Token
185207
var tc *TokenContainer
186208
var err error
@@ -194,7 +216,7 @@ func getTokenContainer(ctx *gin.Context) (*TokenContainer, bool) {
194216
return nil, false
195217
}
196218

197-
if tc, err = GetTokenContainer(oauthToken); err != nil {
219+
if tc, err = getTokenContainerForToken(o, oauthToken); err != nil {
198220
glog.Errorf("[Gin-OAuth] Can not extract TokenContainer, caused by: %s", err)
199221
return nil, false
200222
}
@@ -253,27 +275,31 @@ func Auth(accessCheckFunction AccessCheckFunction, endpoints oauth2.Endpoint) gi
253275
// c.JSON(200, gin.H{"message": "Hello from private"})
254276
// })
255277
//
256-
func AuthChain(endpoints oauth2.Endpoint, accessCheckFunctions ...AccessCheckFunction) gin.HandlerFunc {
278+
func AuthChain(endpoint oauth2.Endpoint, accessCheckFunctions ...AccessCheckFunction) gin.HandlerFunc {
279+
return AuthChainOptions(Options{Endpoint: endpoint}, accessCheckFunctions...)
280+
}
281+
282+
func AuthChainOptions(o Options, accessCheckFunctions ...AccessCheckFunction) gin.HandlerFunc {
257283
// init
258-
AuthInfoURL = endpoints.TokenURL
284+
AuthInfoURL = o.Endpoint.TokenURL
259285
// middleware
260286
return func(ctx *gin.Context) {
261287
t := time.Now()
262288
varianceControl := make(chan bool, 1)
263289

264290
go func() {
265-
tokenContainer, ok := getTokenContainer(ctx)
291+
tokenContainer, ok := getTokenContainer(o, ctx)
266292
if !ok {
267293
// set LOCATION header to auth endpoint such that the user can easily get a new access-token
268-
ctx.Writer.Header().Set("Location", endpoints.AuthURL)
294+
ctx.Writer.Header().Set("Location", o.Endpoint.AuthURL)
269295
ctx.AbortWithError(http.StatusUnauthorized, errors.New("No token in context"))
270296
varianceControl <- false
271297
return
272298
}
273299

274300
if !tokenContainer.Valid() {
275301
// set LOCATION header to auth endpoint such that the user can easily get a new access-token
276-
ctx.Writer.Header().Set("Location", endpoints.AuthURL)
302+
ctx.Writer.Header().Set("Location", o.Endpoint.AuthURL)
277303
ctx.AbortWithError(http.StatusUnauthorized, errors.New("Invalid Token"))
278304
varianceControl <- false
279305
return

0 commit comments

Comments
 (0)