Skip to content

Commit c5ce015

Browse files
committed
grpc server options and authenticator
Signed-off-by: Jian Qiu <[email protected]>
1 parent 3a42496 commit c5ce015

File tree

6 files changed

+352
-1
lines changed

6 files changed

+352
-1
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ require (
2222
github.com/openshift/library-go v0.0.0-20240621150525-4bb4238aef81
2323
github.com/prometheus/client_golang v1.19.1
2424
github.com/prometheus/client_model v0.6.1
25+
github.com/spf13/pflag v1.0.5
2526
github.com/stretchr/testify v1.9.0
2627
golang.org/x/oauth2 v0.23.0
2728
google.golang.org/grpc v1.65.0
@@ -72,7 +73,6 @@ require (
7273
github.com/prometheus/procfs v0.15.1 // indirect
7374
github.com/rs/xid v1.4.0 // indirect
7475
github.com/spf13/cobra v1.8.1 // indirect
75-
github.com/spf13/pflag v1.0.5 // indirect
7676
github.com/stoewer/go-strcase v1.3.0 // indirect
7777
github.com/x448/float16 v0.8.4 // indirect
7878
go.opencensus.io v0.24.0 // indirect
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package authn
2+
3+
import "context"
4+
5+
// Context key type defined to avoid collisions in other pkgs using context
6+
// See https://golang.org/pkg/context/#WithValue
7+
type contextKey string
8+
9+
const (
10+
contextUserKey contextKey = "user"
11+
contextGroupsKey contextKey = "groups"
12+
)
13+
14+
// Authenticator is the interface to authenticat for grpc server
15+
type Authenticator interface {
16+
Authenticate(ctx context.Context) (context.Context, error)
17+
}
18+
19+
func newContextWithIdentity(ctx context.Context, user string, groups []string) context.Context {
20+
ctx = context.WithValue(ctx, contextUserKey, user)
21+
return context.WithValue(ctx, contextGroupsKey, groups)
22+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package authn
2+
3+
import (
4+
"context"
5+
6+
"google.golang.org/grpc/codes"
7+
"google.golang.org/grpc/credentials"
8+
"google.golang.org/grpc/peer"
9+
"google.golang.org/grpc/status"
10+
)
11+
12+
type MtlsAuthenticator struct {
13+
}
14+
15+
func NewMtlsAuthenticator() *MtlsAuthenticator {
16+
return &MtlsAuthenticator{}
17+
}
18+
19+
func (a *MtlsAuthenticator) Authenticate(ctx context.Context) (context.Context, error) {
20+
p, ok := peer.FromContext(ctx)
21+
if !ok {
22+
return ctx, status.Error(codes.Unauthenticated, "no peer found")
23+
}
24+
25+
tlsAuth, ok := p.AuthInfo.(credentials.TLSInfo)
26+
if !ok {
27+
return ctx, status.Error(codes.Unauthenticated, "unexpected peer transport credentials")
28+
}
29+
30+
if len(tlsAuth.State.VerifiedChains) == 0 || len(tlsAuth.State.VerifiedChains[0]) == 0 {
31+
return ctx, status.Error(codes.Unauthenticated, "could not verify peer certificate")
32+
}
33+
34+
if tlsAuth.State.VerifiedChains[0][0] == nil {
35+
return ctx, status.Error(codes.Unauthenticated, "could not verify peer certificate")
36+
}
37+
38+
user := tlsAuth.State.VerifiedChains[0][0].Subject.CommonName
39+
groups := tlsAuth.State.VerifiedChains[0][0].Subject.Organization
40+
newCtx := newContextWithIdentity(ctx, user, groups)
41+
return newCtx, nil
42+
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package authn
2+
3+
import (
4+
"context"
5+
"strings"
6+
7+
"google.golang.org/grpc/codes"
8+
"google.golang.org/grpc/metadata"
9+
"google.golang.org/grpc/status"
10+
authenticationv1 "k8s.io/api/authentication/v1"
11+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
12+
"k8s.io/client-go/kubernetes"
13+
)
14+
15+
type TokenAuthenticator struct {
16+
client kubernetes.Interface
17+
}
18+
19+
var _ Authenticator = &TokenAuthenticator{}
20+
21+
func NewTokenAuthenticator(client kubernetes.Interface) *TokenAuthenticator {
22+
return &TokenAuthenticator{client: client}
23+
}
24+
25+
func (t *TokenAuthenticator) Authenticate(ctx context.Context) (context.Context, error) {
26+
// Extract the metadata from the context
27+
md, ok := metadata.FromIncomingContext(ctx)
28+
if !ok {
29+
return ctx, status.Error(codes.InvalidArgument, "missing metadata")
30+
}
31+
32+
// Extract the access token from the metadata
33+
authorization, ok := md["authorization"]
34+
if !ok || len(authorization) == 0 {
35+
return ctx, status.Error(codes.Unauthenticated, "invalid token")
36+
}
37+
38+
token := strings.TrimPrefix(authorization[0], "Bearer ")
39+
tr, err := t.client.AuthenticationV1().TokenReviews().Create(ctx, &authenticationv1.TokenReview{
40+
Spec: authenticationv1.TokenReviewSpec{
41+
Token: token,
42+
},
43+
}, metav1.CreateOptions{})
44+
if err != nil {
45+
return ctx, err
46+
}
47+
48+
if !tr.Status.Authenticated {
49+
return ctx, status.Error(codes.Unauthenticated, "token not authenticated")
50+
}
51+
52+
newCtx := newContextWithIdentity(ctx, tr.Status.User.Username, tr.Status.User.Groups)
53+
return newCtx, nil
54+
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package options
2+
3+
import (
4+
"github.com/spf13/pflag"
5+
"math"
6+
"time"
7+
)
8+
9+
type GRPCServerOptions struct {
10+
TLSCertFile string
11+
TLSKeyFile string
12+
ClientCAFile string
13+
ServerBindPort string
14+
MaxConcurrentStreams uint32
15+
MaxReceiveMessageSize int
16+
MaxSendMessageSize int
17+
ConnectionTimeout time.Duration
18+
WriteBufferSize int
19+
ReadBufferSize int
20+
MaxConnectionAge time.Duration
21+
ClientMinPingInterval time.Duration
22+
ServerPingInterval time.Duration
23+
ServerPingTimeout time.Duration
24+
PermitPingWithoutStream bool
25+
}
26+
27+
func NewGRPCServerOptions() *GRPCServerOptions {
28+
return &GRPCServerOptions{}
29+
}
30+
31+
func (o *GRPCServerOptions) AddFlags(flags *pflag.FlagSet) {
32+
flags.StringVar(&o.ServerBindPort, "grpc-server-bindport", "8090", "gPRC server bind port")
33+
flags.Uint32Var(&o.MaxConcurrentStreams, "grpc-max-concurrent-streams", math.MaxUint32, "gPRC max concurrent streams")
34+
flags.IntVar(&o.MaxReceiveMessageSize, "grpc-max-receive-message-size", 1024*1024*4, "gPRC max receive message size")
35+
flags.IntVar(&o.MaxSendMessageSize, "grpc-max-send-message-size", math.MaxInt32, "gPRC max send message size")
36+
flags.DurationVar(&o.ConnectionTimeout, "grpc-connection-timeout", 120*time.Second, "gPRC connection timeout")
37+
flags.DurationVar(&o.MaxConnectionAge, "grpc-max-connection-age", time.Duration(math.MaxInt64), "A duration for the maximum amount of time connection may exist before closing")
38+
flags.DurationVar(&o.ClientMinPingInterval, "grpc-client-min-ping-interval", 5*time.Second, "Server will terminate the connection if the client pings more than once within this duration")
39+
flags.DurationVar(&o.ServerPingInterval, "grpc-server-ping-interval", 30*time.Second, "Duration after which the server pings the client if no activity is detected")
40+
flags.DurationVar(&o.ServerPingTimeout, "grpc-server-ping-timeout", 10*time.Second, "Duration the client waits for a response after sending a keepalive ping")
41+
flags.BoolVar(&o.PermitPingWithoutStream, "permit-ping-without-stream", false, "Allow keepalive pings even when there are no active streams")
42+
flags.IntVar(&o.WriteBufferSize, "grpc-write-buffer-size", 32*1024, "gPRC write buffer size")
43+
flags.IntVar(&o.ReadBufferSize, "grpc-read-buffer-size", 32*1024, "gPRC read buffer size")
44+
flags.StringVar(&o.TLSCertFile, "grpc-tls-cert-file", "", "The path to the tls.crt file")
45+
flags.StringVar(&o.TLSKeyFile, "grpc-tls-key-file", "", "The path to the tls.key file")
46+
flags.StringVar(&o.ClientCAFile, "grpc-client-ca-file", "", "The path to the client ca file, must specify if using mtls authentication type")
47+
}
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
package options
2+
3+
import (
4+
"context"
5+
"crypto/tls"
6+
"crypto/x509"
7+
"fmt"
8+
"google.golang.org/grpc"
9+
"google.golang.org/grpc/credentials"
10+
"google.golang.org/grpc/keepalive"
11+
"open-cluster-management.io/sdk-go/pkg/cloudevents/generic/types"
12+
"open-cluster-management.io/sdk-go/pkg/cloudevents/server"
13+
grpcserver "open-cluster-management.io/sdk-go/pkg/cloudevents/server/grpc"
14+
"open-cluster-management.io/sdk-go/pkg/cloudevents/server/grpc/authn"
15+
"os"
16+
)
17+
18+
// PreStartHook is an interface to start hook before grpc server is started.
19+
type PreStartHook interface {
20+
// Start should be a non-blocking call
21+
Run(ctx context.Context)
22+
}
23+
24+
type Server struct {
25+
options *GRPCServerOptions
26+
authenticators []authn.Authenticator
27+
services map[types.CloudEventsDataType]server.Service
28+
hooks []PreStartHook
29+
}
30+
31+
func NewServer(opt *GRPCServerOptions) *Server {
32+
return &Server{options: opt, services: make(map[types.CloudEventsDataType]server.Service)}
33+
}
34+
35+
func (s *Server) WithAuthenticator(authenticator authn.Authenticator) *Server {
36+
s.authenticators = append(s.authenticators, authenticator)
37+
return s
38+
}
39+
40+
func (s *Server) WithService(t types.CloudEventsDataType, service server.Service) *Server {
41+
s.services[t] = service
42+
return s
43+
}
44+
45+
func (s *Server) WithPreStartHooks(hooks ...PreStartHook) *Server {
46+
s.hooks = append(s.hooks, hooks...)
47+
return s
48+
}
49+
50+
func (s *Server) Run(ctx context.Context) error {
51+
var grpcServerOptions []grpc.ServerOption
52+
grpcServerOptions = append(grpcServerOptions, grpc.MaxRecvMsgSize(s.options.MaxReceiveMessageSize))
53+
grpcServerOptions = append(grpcServerOptions, grpc.MaxSendMsgSize(s.options.MaxSendMessageSize))
54+
grpcServerOptions = append(grpcServerOptions, grpc.MaxConcurrentStreams(s.options.MaxConcurrentStreams))
55+
grpcServerOptions = append(grpcServerOptions, grpc.ConnectionTimeout(s.options.ConnectionTimeout))
56+
grpcServerOptions = append(grpcServerOptions, grpc.WriteBufferSize(s.options.WriteBufferSize))
57+
grpcServerOptions = append(grpcServerOptions, grpc.ReadBufferSize(s.options.ReadBufferSize))
58+
grpcServerOptions = append(grpcServerOptions, grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
59+
MinTime: s.options.ClientMinPingInterval,
60+
PermitWithoutStream: s.options.PermitPingWithoutStream,
61+
}))
62+
grpcServerOptions = append(grpcServerOptions, grpc.KeepaliveParams(keepalive.ServerParameters{
63+
MaxConnectionAge: s.options.MaxConnectionAge,
64+
Time: s.options.ServerPingInterval,
65+
Timeout: s.options.ServerPingTimeout,
66+
}))
67+
68+
// Serve with TLS
69+
serverCerts, err := tls.LoadX509KeyPair(s.options.TLSCertFile, s.options.TLSKeyFile)
70+
if err != nil {
71+
return fmt.Errorf("failed to load broker certificates: %v", err)
72+
}
73+
tlsConfig := &tls.Config{
74+
Certificates: []tls.Certificate{serverCerts},
75+
MinVersion: tls.VersionTLS13,
76+
MaxVersion: tls.VersionTLS13,
77+
}
78+
79+
if s.options.ClientCAFile != "" {
80+
certPool, err := x509.SystemCertPool()
81+
if err != nil {
82+
return fmt.Errorf("failed to load system cert pool: %v", err)
83+
}
84+
caPEM, err := os.ReadFile(s.options.ClientCAFile)
85+
if err != nil {
86+
return fmt.Errorf("failed to read broker client CA file: %v", err)
87+
}
88+
if ok := certPool.AppendCertsFromPEM(caPEM); !ok {
89+
return fmt.Errorf("failed to append broker client CA to cert pool")
90+
}
91+
tlsConfig.ClientCAs = certPool
92+
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
93+
}
94+
95+
grpcServerOptions = append(grpcServerOptions, grpc.Creds(credentials.NewTLS(tlsConfig)))
96+
97+
grpcServerOptions = append(grpcServerOptions,
98+
grpc.ChainUnaryInterceptor(newAuthUnaryInterceptor(s.authenticators...)),
99+
grpc.ChainStreamInterceptor(newAuthStreamInterceptor(s.authenticators...)))
100+
101+
grpcServer := grpc.NewServer(grpcServerOptions...)
102+
grpcEventServer := grpcserver.NewGRPCBroker(grpcServer, ":"+s.options.ServerBindPort)
103+
104+
for t, service := range s.services {
105+
grpcEventServer.RegisterService(t, service)
106+
}
107+
108+
// start hook
109+
for _, hook := range s.hooks {
110+
hook.Run(ctx)
111+
}
112+
113+
go grpcEventServer.Start(ctx)
114+
<-ctx.Done()
115+
return nil
116+
}
117+
118+
func newAuthUnaryInterceptor(authenticators ...authn.Authenticator) grpc.UnaryServerInterceptor {
119+
return func(
120+
ctx context.Context,
121+
req interface{},
122+
info *grpc.UnaryServerInfo,
123+
handler grpc.UnaryHandler,
124+
) (interface{}, error) {
125+
var err error
126+
for _, authenticator := range authenticators {
127+
ctx, err = authenticator.Authenticate(ctx)
128+
if err == nil {
129+
return handler(ctx, req)
130+
}
131+
}
132+
133+
if err != nil {
134+
return nil, err
135+
}
136+
137+
return handler(ctx, req)
138+
}
139+
}
140+
141+
// wrappedAuthStream wraps a grpc.ServerStream associated with an incoming RPC, and
142+
// a custom context containing the user and groups derived from the client certificate
143+
// specified in the incoming RPC metadata
144+
type wrappedAuthStream struct {
145+
grpc.ServerStream
146+
ctx context.Context
147+
}
148+
149+
// Context returns the context associated with the stream
150+
func (w *wrappedAuthStream) Context() context.Context {
151+
return w.ctx
152+
}
153+
154+
// newWrappedAuthStream creates a new wrappedAuthStream
155+
func newWrappedAuthStream(ctx context.Context, s grpc.ServerStream) grpc.ServerStream {
156+
return &wrappedAuthStream{s, ctx}
157+
}
158+
159+
// newAuthStreamInterceptor creates a stream interceptor that retrieves the user and groups
160+
// based on the specified authentication type. It supports retrieving from either the access
161+
// token or the client certificate depending on the provided authNType.
162+
// The interceptor then adds the retrieved identity information (user and groups) to the
163+
// context and invokes the provided handler.
164+
func newAuthStreamInterceptor(authenticators ...authn.Authenticator) grpc.StreamServerInterceptor {
165+
return func(
166+
srv interface{},
167+
ss grpc.ServerStream,
168+
info *grpc.StreamServerInfo,
169+
handler grpc.StreamHandler,
170+
) error {
171+
var err error
172+
ctx := ss.Context()
173+
for _, authenticator := range authenticators {
174+
ctx, err = authenticator.Authenticate(ctx)
175+
if err == nil {
176+
return handler(srv, newWrappedAuthStream(ctx, ss))
177+
}
178+
}
179+
180+
if err != nil {
181+
return err
182+
}
183+
184+
return handler(srv, newWrappedAuthStream(ctx, ss))
185+
}
186+
}

0 commit comments

Comments
 (0)