-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathauthentication.go
111 lines (95 loc) · 2.78 KB
/
authentication.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
package authn
import (
"context"
"errors"
"net/http"
"go.uber.org/multierr"
)
// contextKey is the type of the keys under which an Authentication stores its
// result in a context.Context. It carries a blank non-zero-size field because
// the Go spec allows pointers to distinct zero-size variables to compare
// equal; with a non-zero size, separately allocated keys are always distinct.
type contextKey struct{ _ byte }

// Authentication runs a list of RequestAuthenticators against incoming
// requests and makes the first successful result of type T available to
// downstream handlers through the request context.
type Authentication[T any] struct {
	authenticators []RequestAuthenticator[T]
	// contextKey is allocated per instance by New, so several Authentication
	// values (with the same or different T) can be stacked in one handler
	// chain without reading or overwriting each other's context entries.
	contextKey *contextKey
	// unknownErrorHandler, when non-nil, renders errors that handleError
	// cannot map to a status code. See SetUnknownErrorHandler.
	unknownErrorHandler func(w http.ResponseWriter, r *http.Request, err error)
}

// New returns an Authentication that consults authenticators in the given
// order. The slice is copied, so later modifications of a slice passed as
// `authenticators...` do not affect the returned Authentication.
func New[T any](authenticators ...RequestAuthenticator[T]) *Authentication[T] {
	return &Authentication[T]{
		authenticators: append([]RequestAuthenticator[T](nil), authenticators...),
		// A fresh allocation gives this instance its own context key; a shared
		// zero-size struct value would make every instance use the same slot.
		contextKey: new(contextKey),
	}
}
// SetUnknownErrorHandler installs handler as the fallback used by this
// Authentication when an error is neither a WithResponseStatus nor matches
// ErrBadAuthentication / ErrNoAuthentication. The handler then owns the whole
// response. Passing nil restores the default plain 500 response.
func (a *Authentication[T]) SetUnknownErrorHandler(handler func(w http.ResponseWriter, r *http.Request, err error)) {
	a.unknownErrorHandler = handler
}
// NewContext derives a child of ctx that carries auth under a's context key,
// so that a later Get or Require on a can retrieve it.
func (a *Authentication[T]) NewContext(ctx context.Context, auth T) context.Context {
	key := a.contextKey
	return context.WithValue(ctx, key, auth)
}
// Get looks up the authentication value stored in ctx under a's context key.
// If ctx holds no value there, or the stored value is not a (non-nil) T, Get
// returns the zero value of T together with ErrNoAuthentication.
func (a *Authentication[T]) Get(ctx context.Context) (result T, err error) {
	auth, found := ctx.Value(a.contextKey).(T)
	if !found {
		return result, ErrNoAuthentication
	}
	return auth, nil
}
// Require is the panicking variant of Get: it returns the authentication
// value from ctx, or panics with the error Get reported when there is none.
// Use it only where an earlier Middleware is known to have populated ctx.
func (a *Authentication[T]) Require(ctx context.Context) T {
	auth, err := a.Get(ctx)
	if err != nil {
		panic(err)
	}
	return auth
}
// Middleware returns an http.Handler that tries each configured authenticator
// in order. On the first success it calls next with the result stored in the
// request context, where Get and Require can find it.
//
// An authenticator error matching ErrNoAuthentication means "not applicable",
// so the next authenticator is tried. Any other error stops the chain. If no
// authenticator succeeds, the last error seen is rendered by handleError.
// With no authenticators configured, the request is rejected with
// ErrNoAuthentication rather than handing handleError a nil error.
func (a *Authentication[T]) Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Default for an empty authenticator list: nothing identified the caller.
		var err error = ErrNoAuthentication
		for _, provider := range a.authenticators {
			result, authErr := provider.Authenticate(r.Context(), r)
			if authErr == nil {
				next.ServeHTTP(w, r.WithContext(a.NewContext(r.Context(), result)))
				return
			}
			err = authErr
			if !errors.Is(authErr, ErrNoAuthentication) {
				// A definitive failure (not merely "no credentials for me"):
				// do not let later authenticators override it.
				break
			}
		}
		a.handleError(w, r, err)
	})
}
// ValidatorMiddleware returns middleware that runs fn against the
// authentication value stored in the request context. It calls next only
// when fn returns nil. Errors from fn are rendered by handleError, so they can
// set status, headers and body through the same interfaces as authenticator
// errors.
//
// It is meant to be mounted after Middleware. If the context carries no
// authentication value (e.g. wrong middleware order), the error from Get
// (ErrNoAuthentication) is rendered through handleError instead of panicking
// inside the HTTP handler, which would abort the response.
func (a *Authentication[T]) ValidatorMiddleware(fn func(value T) error) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			auth, err := a.Get(r.Context())
			if err == nil {
				err = fn(auth)
			}
			if err != nil {
				a.handleError(w, r, err)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}
// handleError writes the HTTP response for a failed authentication or
// validation.
//
// First, every error combined in err (as split by multierr.Errors) that
// implements WithResponseHeaders contributes its headers to the response.
// The status is then chosen as follows:
//   - the status of a WithResponseStatus found in err, else
//   - 401 when err matches ErrBadAuthentication or ErrNoAuthentication, else
//   - the unknownErrorHandler, if set, which takes over the response, else
//   - 500.
//
// If err contains a ResponseBodyWriter, that writer produces the body after
// the status is sent. Otherwise a plain-text body with the status text is
// written via http.Error.
func (a *Authentication[T]) handleError(w http.ResponseWriter, r *http.Request, err error) {
	for _, part := range multierr.Errors(err) {
		var withHeaders WithResponseHeaders
		if !errors.As(part, &withHeaders) {
			continue
		}
		for name, values := range withHeaders.ResponseHeaders() {
			for _, v := range values {
				w.Header().Add(name, v)
			}
		}
	}

	var withStatus WithResponseStatus
	var status int
	switch {
	case errors.As(err, &withStatus):
		status = withStatus.ResponseStatus()
	case errors.Is(err, ErrBadAuthentication), errors.Is(err, ErrNoAuthentication):
		status = http.StatusUnauthorized
	case a.unknownErrorHandler != nil:
		a.unknownErrorHandler(w, r, err)
		return
	default:
		status = http.StatusInternalServerError
	}

	var bodyWriter ResponseBodyWriter
	if !errors.As(err, &bodyWriter) {
		http.Error(w, http.StatusText(status), status)
		return
	}
	w.WriteHeader(status)
	bodyWriter.WriteResponse(w)
}