Skip to content

Commit ee6587c

Browse files
authored
✨ grpc server options and authenticator (#109)
* grpc server options and authenticator Signed-off-by: Jian Qiu <[email protected]> * Add ut Signed-off-by: Jian Qiu <[email protected]> --------- Signed-off-by: Jian Qiu <[email protected]>
1 parent d637c06 commit ee6587c

File tree

9 files changed

+547
-1
lines changed

9 files changed

+547
-1
lines changed

go.mod

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ require (
2222
github.com/openshift/library-go v0.0.0-20240621150525-4bb4238aef81
2323
github.com/prometheus/client_golang v1.19.1
2424
github.com/prometheus/client_model v0.6.1
25+
github.com/spf13/pflag v1.0.5
2526
github.com/stretchr/testify v1.9.0
2627
golang.org/x/oauth2 v0.23.0
2728
google.golang.org/grpc v1.65.0
@@ -72,7 +73,6 @@ require (
7273
github.com/prometheus/procfs v0.15.1 // indirect
7374
github.com/rs/xid v1.4.0 // indirect
7475
github.com/spf13/cobra v1.8.1 // indirect
75-
github.com/spf13/pflag v1.0.5 // indirect
7676
github.com/stoewer/go-strcase v1.3.0 // indirect
7777
github.com/x448/float16 v0.8.4 // indirect
7878
go.opencensus.io v0.24.0 // indirect
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package authn
2+
3+
import "context"
4+
5+
// Context key type defined to avoid collisions in other pkgs using context
6+
// See https://golang.org/pkg/context/#WithValue
7+
type contextKey string
8+
9+
const (
10+
contextUserKey contextKey = "user"
11+
contextGroupsKey contextKey = "groups"
12+
)
13+
14+
// Authenticator is the interface to authenticat for grpc server
15+
type Authenticator interface {
16+
Authenticate(ctx context.Context) (context.Context, error)
17+
}
18+
19+
func newContextWithIdentity(ctx context.Context, user string, groups []string) context.Context {
20+
ctx = context.WithValue(ctx, contextUserKey, user)
21+
return context.WithValue(ctx, contextGroupsKey, groups)
22+
}
+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package authn
2+
3+
import (
4+
"context"
5+
6+
"google.golang.org/grpc/codes"
7+
"google.golang.org/grpc/credentials"
8+
"google.golang.org/grpc/peer"
9+
"google.golang.org/grpc/status"
10+
)
11+
12+
type MtlsAuthenticator struct {
13+
}
14+
15+
func NewMtlsAuthenticator() *MtlsAuthenticator {
16+
return &MtlsAuthenticator{}
17+
}
18+
19+
func (a *MtlsAuthenticator) Authenticate(ctx context.Context) (context.Context, error) {
20+
p, ok := peer.FromContext(ctx)
21+
if !ok {
22+
return ctx, status.Error(codes.Unauthenticated, "no peer found")
23+
}
24+
25+
tlsAuth, ok := p.AuthInfo.(credentials.TLSInfo)
26+
if !ok {
27+
return ctx, status.Error(codes.Unauthenticated, "unexpected peer transport credentials")
28+
}
29+
30+
if len(tlsAuth.State.VerifiedChains) == 0 || len(tlsAuth.State.VerifiedChains[0]) == 0 {
31+
return ctx, status.Error(codes.Unauthenticated, "could not verify peer certificate")
32+
}
33+
34+
if tlsAuth.State.VerifiedChains[0][0] == nil {
35+
return ctx, status.Error(codes.Unauthenticated, "could not verify peer certificate")
36+
}
37+
38+
user := tlsAuth.State.VerifiedChains[0][0].Subject.CommonName
39+
groups := tlsAuth.State.VerifiedChains[0][0].Subject.Organization
40+
newCtx := newContextWithIdentity(ctx, user, groups)
41+
return newCtx, nil
42+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package authn
2+
3+
import (
4+
"context"
5+
"crypto/tls"
6+
"crypto/x509"
7+
"crypto/x509/pkix"
8+
"google.golang.org/grpc/credentials"
9+
"google.golang.org/grpc/peer"
10+
"testing"
11+
)
12+
13+
func TestMtlsAuthenticator(t *testing.T) {
14+
tests := []struct {
15+
name string
16+
authInfo credentials.TLSInfo
17+
valid bool
18+
}{
19+
{
20+
name: "no info",
21+
authInfo: credentials.TLSInfo{},
22+
valid: false,
23+
},
24+
{
25+
name: "nil chain",
26+
authInfo: credentials.TLSInfo{
27+
State: tls.ConnectionState{
28+
VerifiedChains: [][]*x509.Certificate{nil},
29+
},
30+
},
31+
valid: false,
32+
},
33+
{
34+
name: "valid chain",
35+
authInfo: credentials.TLSInfo{
36+
State: tls.ConnectionState{
37+
VerifiedChains: [][]*x509.Certificate{
38+
{
39+
{
40+
Subject: pkix.Name{},
41+
},
42+
},
43+
},
44+
},
45+
},
46+
valid: true,
47+
},
48+
}
49+
50+
for _, test := range tests {
51+
t.Run(test.name, func(t *testing.T) {
52+
p := &peer.Peer{
53+
AuthInfo: test.authInfo,
54+
}
55+
ctx := peer.NewContext(context.Background(), p)
56+
authenticator := MtlsAuthenticator{}
57+
_, err := authenticator.Authenticate(ctx)
58+
if test.valid && err != nil {
59+
t.Errorf("authenticator.Authenticate() = %v", err)
60+
} else if !test.valid && err == nil {
61+
t.Errorf("authenticator.Authenticate() = %v, wanted error", err)
62+
}
63+
})
64+
}
65+
}
+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package authn
2+
3+
import (
4+
"context"
5+
"strings"
6+
7+
"google.golang.org/grpc/codes"
8+
"google.golang.org/grpc/metadata"
9+
"google.golang.org/grpc/status"
10+
authenticationv1 "k8s.io/api/authentication/v1"
11+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
12+
"k8s.io/client-go/kubernetes"
13+
)
14+
15+
type TokenAuthenticator struct {
16+
client kubernetes.Interface
17+
}
18+
19+
var _ Authenticator = &TokenAuthenticator{}
20+
21+
func NewTokenAuthenticator(client kubernetes.Interface) *TokenAuthenticator {
22+
return &TokenAuthenticator{client: client}
23+
}
24+
25+
func (t *TokenAuthenticator) Authenticate(ctx context.Context) (context.Context, error) {
26+
// Extract the metadata from the context
27+
md, ok := metadata.FromIncomingContext(ctx)
28+
if !ok {
29+
return ctx, status.Error(codes.InvalidArgument, "missing metadata")
30+
}
31+
32+
// Extract the access token from the metadata
33+
authorization, ok := md["authorization"]
34+
if !ok || len(authorization) == 0 {
35+
return ctx, status.Error(codes.Unauthenticated, "invalid token")
36+
}
37+
38+
token := strings.TrimPrefix(authorization[0], "Bearer ")
39+
tr, err := t.client.AuthenticationV1().TokenReviews().Create(ctx, &authenticationv1.TokenReview{
40+
Spec: authenticationv1.TokenReviewSpec{
41+
Token: token,
42+
},
43+
}, metav1.CreateOptions{})
44+
if err != nil {
45+
return ctx, err
46+
}
47+
48+
if !tr.Status.Authenticated {
49+
return ctx, status.Error(codes.Unauthenticated, "token not authenticated")
50+
}
51+
52+
newCtx := newContextWithIdentity(ctx, tr.Status.User.Username, tr.Status.User.Groups)
53+
return newCtx, nil
54+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package authn
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"google.golang.org/grpc/metadata"
7+
authenticationv1 "k8s.io/api/authentication/v1"
8+
"k8s.io/apimachinery/pkg/runtime"
9+
"k8s.io/client-go/kubernetes/fake"
10+
clienttesting "k8s.io/client-go/testing"
11+
"testing"
12+
)
13+
14+
func TestTokenAuthenticator(t *testing.T) {
15+
tests := []struct {
16+
name string
17+
metadata metadata.MD
18+
token string
19+
valid bool
20+
}{
21+
{
22+
name: "no authorization field",
23+
metadata: metadata.MD{},
24+
valid: false,
25+
},
26+
{
27+
name: "token is not correct",
28+
metadata: metadata.MD{
29+
"Authorization": []string{"Bearer foo"},
30+
},
31+
token: "bar",
32+
valid: false,
33+
},
34+
{
35+
name: "authorization header is set",
36+
metadata: metadata.MD{
37+
"Authorization": []string{"Bearer foo"},
38+
},
39+
token: "foo",
40+
valid: true,
41+
},
42+
}
43+
44+
for _, test := range tests {
45+
t.Run(test.name, func(t *testing.T) {
46+
ctx := metadata.NewIncomingContext(context.Background(), test.metadata)
47+
client := fake.NewClientset()
48+
client.PrependReactor("create", "tokenreviews", func(action clienttesting.Action) (handled bool, ret runtime.Object, err error) {
49+
createAction := action.(clienttesting.CreateAction)
50+
tr, ok := createAction.GetObject().(*authenticationv1.TokenReview)
51+
if !ok {
52+
return false, nil, fmt.Errorf("not a TokenReview")
53+
}
54+
if tr.Spec.Token != test.token {
55+
return false, nil, fmt.Errorf("invalid token")
56+
}
57+
tr.Status = authenticationv1.TokenReviewStatus{Authenticated: true}
58+
return true, tr, nil
59+
})
60+
authenticator := NewTokenAuthenticator(client)
61+
_, err := authenticator.Authenticate(ctx)
62+
if test.valid {
63+
if err != nil {
64+
t.Errorf("authenticator.Authenticate() = %v", err)
65+
}
66+
67+
}
68+
if !test.valid && err == nil {
69+
t.Errorf("authenticator.Authenticate() = %v, wanted error", err)
70+
}
71+
})
72+
}
73+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package options
2+
3+
import (
4+
"github.com/spf13/pflag"
5+
"math"
6+
"time"
7+
)
8+
9+
type GRPCServerOptions struct {
10+
TLSCertFile string
11+
TLSKeyFile string
12+
ClientCAFile string
13+
ServerBindPort string
14+
MaxConcurrentStreams uint32
15+
MaxReceiveMessageSize int
16+
MaxSendMessageSize int
17+
ConnectionTimeout time.Duration
18+
WriteBufferSize int
19+
ReadBufferSize int
20+
MaxConnectionAge time.Duration
21+
ClientMinPingInterval time.Duration
22+
ServerPingInterval time.Duration
23+
ServerPingTimeout time.Duration
24+
PermitPingWithoutStream bool
25+
}
26+
27+
func NewGRPCServerOptions() *GRPCServerOptions {
28+
return &GRPCServerOptions{
29+
ServerBindPort: "8090",
30+
MaxConcurrentStreams: math.MaxUint32,
31+
MaxReceiveMessageSize: 1024 * 1024 * 4,
32+
MaxSendMessageSize: math.MaxInt32,
33+
ConnectionTimeout: 120 * time.Second,
34+
MaxConnectionAge: time.Duration(math.MaxInt64),
35+
ClientMinPingInterval: 5 * time.Second,
36+
ServerPingInterval: 30 * time.Second,
37+
ServerPingTimeout: 10 * time.Second,
38+
WriteBufferSize: 32 * 1024,
39+
ReadBufferSize: 32 * 1024,
40+
}
41+
}
42+
43+
func (o *GRPCServerOptions) AddFlags(flags *pflag.FlagSet) {
44+
flags.StringVar(&o.ServerBindPort, "grpc-server-bindport", o.ServerBindPort, "gPRC server bind port")
45+
flags.Uint32Var(&o.MaxConcurrentStreams, "grpc-max-concurrent-streams", o.MaxConcurrentStreams, "gPRC max concurrent streams")
46+
flags.IntVar(&o.MaxReceiveMessageSize, "grpc-max-receive-message-size", o.MaxReceiveMessageSize, "gPRC max receive message size")
47+
flags.IntVar(&o.MaxSendMessageSize, "grpc-max-send-message-size", o.MaxSendMessageSize, "gPRC max send message size")
48+
flags.DurationVar(&o.ConnectionTimeout, "grpc-connection-timeout", o.ConnectionTimeout, "gPRC connection timeout")
49+
flags.DurationVar(&o.MaxConnectionAge, "grpc-max-connection-age", o.MaxConnectionAge, "A duration for the maximum amount of time connection may exist before closing")
50+
flags.DurationVar(&o.ClientMinPingInterval, "grpc-client-min-ping-interval", o.ClientMinPingInterval, "Server will terminate the connection if the client pings more than once within this duration")
51+
flags.DurationVar(&o.ServerPingInterval, "grpc-server-ping-interval", o.ServerPingInterval, "Duration after which the server pings the client if no activity is detected")
52+
flags.DurationVar(&o.ServerPingTimeout, "grpc-server-ping-timeout", o.ServerPingTimeout, "Duration the client waits for a response after sending a keepalive ping")
53+
flags.BoolVar(&o.PermitPingWithoutStream, "permit-ping-without-stream", o.PermitPingWithoutStream, "Allow keepalive pings even when there are no active streams")
54+
flags.IntVar(&o.WriteBufferSize, "grpc-write-buffer-size", o.WriteBufferSize, "gPRC write buffer size")
55+
flags.IntVar(&o.ReadBufferSize, "grpc-read-buffer-size", o.ReadBufferSize, "gPRC read buffer size")
56+
flags.StringVar(&o.TLSCertFile, "grpc-tls-cert-file", "", "The path to the tls.crt file")
57+
flags.StringVar(&o.TLSKeyFile, "grpc-tls-key-file", "", "The path to the tls.key file")
58+
flags.StringVar(&o.ClientCAFile, "grpc-client-ca-file", "", "The path to the client ca file, must specify if using mtls authentication type")
59+
}

0 commit comments

Comments
 (0)