diff --git a/server/auth/store.go b/server/auth/store.go index cfacfb001c5..2da1b7e3ec4 100644 --- a/server/auth/store.go +++ b/server/auth/store.go @@ -72,6 +72,8 @@ const ( tokenTypeSimple = "simple" tokenTypeJWT = "jwt" + + bearerPrefix = "Bearer " ) type AuthInfo struct { @@ -1072,6 +1074,12 @@ func (as *authStore) AuthInfoFromCtx(ctx context.Context) (*AuthInfo, error) { } token := ts[0] + + // support authorization headers with bearer prefix. + if strings.HasPrefix(token, bearerPrefix) { + token = strings.Split(token, bearerPrefix)[1] + } + authInfo, uok := as.authInfoFromToken(ctx, token) if !uok { as.lg.Warn("invalid auth token", zap.String("token", token)) diff --git a/server/auth/store_test.go b/server/auth/store_test.go index df13fbc297d..644edd615b3 100644 --- a/server/auth/store_test.go +++ b/server/auth/store_test.go @@ -836,12 +836,30 @@ func TestAuthInfoFromCtx(t *testing.T) { t.Errorf("expected %v, got %v", ErrInvalidAuthToken, err) } + ctx = metadata.NewIncomingContext(t.Context(), metadata.New(map[string]string{rpctypes.TokenFieldNameGRPC: "Bearer"})) + _, err = as.AuthInfoFromCtx(ctx) + if !errors.Is(err, ErrInvalidAuthToken) { + t.Errorf("expected %v, got %v", ErrInvalidAuthToken, err) + } + ctx = metadata.NewIncomingContext(t.Context(), metadata.New(map[string]string{rpctypes.TokenFieldNameGRPC: "Invalid.Token"})) _, err = as.AuthInfoFromCtx(ctx) if !errors.Is(err, ErrInvalidAuthToken) { t.Errorf("expected %v, got %v", ErrInvalidAuthToken, err) } + ctx = metadata.NewIncomingContext(t.Context(), metadata.New(map[string]string{rpctypes.TokenFieldNameGRPC: bearerPrefix + "Invalid.Token"})) + _, err = as.AuthInfoFromCtx(ctx) + if !errors.Is(err, ErrInvalidAuthToken) { + t.Errorf("expected %v, got %v", ErrInvalidAuthToken, err) + } + + ctx = metadata.NewIncomingContext(t.Context(), metadata.New(map[string]string{rpctypes.TokenFieldNameGRPC: bearerPrefix})) + _, err = as.AuthInfoFromCtx(ctx) + if !errors.Is(err, ErrInvalidAuthToken) { + t.Errorf("expected %v, got %v", ErrInvalidAuthToken, err) + } + ctx = metadata.NewIncomingContext(t.Context(), metadata.New(map[string]string{rpctypes.TokenFieldNameGRPC: resp.Token})) ai, err = as.AuthInfoFromCtx(ctx) if err != nil { @@ -850,6 +868,15 @@ func TestAuthInfoFromCtx(t *testing.T) { if ai.Username != "foo" { t.Errorf("expected %v, got %v", "foo", ai.Username) } + + ctx = metadata.NewIncomingContext(t.Context(), metadata.New(map[string]string{rpctypes.TokenFieldNameGRPC: bearerPrefix + resp.Token})) + ai, err = as.AuthInfoFromCtx(ctx) + if err != nil { + t.Error(err) + } + if ai.Username != "foo" { + t.Errorf("expected %v, got %v", "foo", ai.Username) + } } func TestAuthDisable(t *testing.T) {