Skip to content

Commit f288ad2

Browse files
committed
Add bearer support for authorization headers
1 parent 638b410 commit f288ad2

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

server/auth/store.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ const (
7272

7373
tokenTypeSimple = "simple"
7474
tokenTypeJWT = "jwt"
75+
76+
bearerPrefix = "Bearer "
7577
)
7678

7779
type AuthInfo struct {
@@ -1072,6 +1074,12 @@ func (as *authStore) AuthInfoFromCtx(ctx context.Context) (*AuthInfo, error) {
10721074
}
10731075

10741076
token := ts[0]
1077+
1078+
// support authorization headers with bearer prefix.
1079+
if strings.HasPrefix(token, bearerPrefix) {
1080+
token = strings.Split(token, bearerPrefix)[1]
1081+
}
1082+
10751083
authInfo, uok := as.authInfoFromToken(ctx, token)
10761084
if !uok {
10771085
as.lg.Warn("invalid auth token", zap.String("token", token))

server/auth/store_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -836,12 +836,30 @@ func TestAuthInfoFromCtx(t *testing.T) {
836836
t.Errorf("expected %v, got %v", ErrInvalidAuthToken, err)
837837
}
838838

839+
ctx = metadata.NewIncomingContext(t.Context(), metadata.New(map[string]string{rpctypes.TokenFieldNameGRPC: "Bearer"}))
840+
_, err = as.AuthInfoFromCtx(ctx)
841+
if !errors.Is(err, ErrInvalidAuthToken) {
842+
t.Errorf("expected %v, got %v", ErrInvalidAuthToken, err)
843+
}
844+
839845
ctx = metadata.NewIncomingContext(t.Context(), metadata.New(map[string]string{rpctypes.TokenFieldNameGRPC: "Invalid.Token"}))
840846
_, err = as.AuthInfoFromCtx(ctx)
841847
if !errors.Is(err, ErrInvalidAuthToken) {
842848
t.Errorf("expected %v, got %v", ErrInvalidAuthToken, err)
843849
}
844850

851+
ctx = metadata.NewIncomingContext(t.Context(), metadata.New(map[string]string{rpctypes.TokenFieldNameGRPC: bearerPrefix + "Invalid.Token"}))
852+
_, err = as.AuthInfoFromCtx(ctx)
853+
if !errors.Is(err, ErrInvalidAuthToken) {
854+
t.Errorf("expected %v, got %v", ErrInvalidAuthToken, err)
855+
}
856+
857+
ctx = metadata.NewIncomingContext(t.Context(), metadata.New(map[string]string{rpctypes.TokenFieldNameGRPC: bearerPrefix}))
858+
_, err = as.AuthInfoFromCtx(ctx)
859+
if !errors.Is(err, ErrInvalidAuthToken) {
860+
t.Errorf("expected %v, got %v", ErrInvalidAuthToken, err)
861+
}
862+
845863
ctx = metadata.NewIncomingContext(t.Context(), metadata.New(map[string]string{rpctypes.TokenFieldNameGRPC: resp.Token}))
846864
ai, err = as.AuthInfoFromCtx(ctx)
847865
if err != nil {
@@ -850,6 +868,15 @@ func TestAuthInfoFromCtx(t *testing.T) {
850868
if ai.Username != "foo" {
851869
t.Errorf("expected %v, got %v", "foo", ai.Username)
852870
}
871+
872+
ctx = metadata.NewIncomingContext(t.Context(), metadata.New(map[string]string{rpctypes.TokenFieldNameGRPC: bearerPrefix + resp.Token}))
873+
ai, err = as.AuthInfoFromCtx(ctx)
874+
if err != nil {
875+
t.Error(err)
876+
}
877+
if ai.Username != "foo" {
878+
t.Errorf("expected %v, got %v", "foo", ai.Username)
879+
}
853880
}
854881

855882
func TestAuthDisable(t *testing.T) {

0 commit comments

Comments
 (0)