Skip to content

Commit c717377

Browse files
committed
feat: graceful server shutdown
Signed-off-by: Artsiom Koltun <artsiom.koltun@intel.com>
1 parent 325cfd9 commit c717377

File tree

3 files changed

+348
-3
lines changed

3 files changed

+348
-3
lines changed

cmd/main.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,17 @@ import (
88
"fmt"
99
"log"
1010
"net"
11+
"os"
1112
"strings"
13+
"time"
1214

1315
"github.com/opiproject/gospdk/spdk"
1416

1517
"github.com/opiproject/opi-spdk-bridge/pkg/backend"
1618
"github.com/opiproject/opi-spdk-bridge/pkg/frontend"
1719
"github.com/opiproject/opi-spdk-bridge/pkg/kvm"
1820
"github.com/opiproject/opi-spdk-bridge/pkg/middleend"
21+
"github.com/opiproject/opi-spdk-bridge/pkg/server"
1922

2023
pb "github.com/opiproject/opi-api/storage/v1alpha1/gen/go"
2124
"google.golang.org/grpc"
@@ -89,8 +92,12 @@ func main() {
8992

9093
reflection.Register(s)
9194

92-
log.Printf("Server listening at %v", lis.Addr())
93-
if err := s.Serve(lis); err != nil {
94-
log.Fatalf("failed to serve: %v", err)
95+
wrapper := server.NewGRPCServerWrapper(2*time.Second, s, lis)
96+
97+
wrapper.RunAsync()
98+
if err := wrapper.Wait(); err != nil {
99+
log.Printf("Server error: %v", err)
100+
os.Exit(-1)
95101
}
102+
log.Print("Server successfully stopped")
96103
}

pkg/server/grpcserver.go

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// Copyright (C) 2023 Intel Corporation
3+
4+
// Package server implements the server
5+
package server
6+
7+
import (
8+
"context"
9+
"errors"
10+
"log"
11+
"net"
12+
"os"
13+
"os/signal"
14+
"syscall"
15+
"time"
16+
17+
"google.golang.org/grpc"
18+
)
19+
20+
// GRPCServerWrapper wraps gRPC server to provide graceful shutdown capabilities
21+
type GRPCServerWrapper struct {
22+
waitSignal chan os.Signal
23+
signalsToWait []os.Signal
24+
25+
timeout time.Duration
26+
27+
server *grpc.Server
28+
listener net.Listener
29+
waitServeComplete chan error
30+
serve func(*grpc.Server, net.Listener) error
31+
}
32+
33+
func defaultServe(s *grpc.Server, l net.Listener) error { return s.Serve(l) }
34+
35+
// NewGRPCServerWrapper creates a new instance of GRPCServerWrapper
36+
func NewGRPCServerWrapper(
37+
timeout time.Duration, server *grpc.Server, listener net.Listener,
38+
) *GRPCServerWrapper {
39+
if timeout == 0 {
40+
log.Panicf("timeout cannot be zero")
41+
}
42+
43+
if server == nil {
44+
log.Panicf("grpc server cannot be nil")
45+
}
46+
47+
if listener == nil {
48+
log.Panic("listener cannot be nil")
49+
}
50+
51+
return &GRPCServerWrapper{
52+
waitSignal: make(chan os.Signal, 1),
53+
signalsToWait: []os.Signal{syscall.SIGINT, syscall.SIGTERM},
54+
timeout: timeout,
55+
server: server,
56+
listener: listener,
57+
waitServeComplete: make(chan error, 1),
58+
serve: defaultServe,
59+
}
60+
}
61+
62+
// RunAsync runs gRPC server
63+
func (s *GRPCServerWrapper) RunAsync() {
64+
go func() {
65+
log.Printf("Server listening at %v", s.listener.Addr())
66+
s.waitServeComplete <- s.serve(s.server, s.listener)
67+
}()
68+
}
69+
70+
// Wait waits for a signal and handles graceful completion
71+
func (s *GRPCServerWrapper) Wait() error {
72+
ctx, cancel := context.WithCancel(context.Background())
73+
defer cancel()
74+
go func() {
75+
signal.Notify(s.waitSignal, s.signalsToWait...)
76+
select {
77+
case sig := <-s.waitSignal:
78+
log.Printf("Got signal: %v", sig)
79+
log.Printf("Start graceful shutdown with timeout: %v", s.timeout)
80+
time.AfterFunc(s.timeout, func() { cancel() })
81+
s.stopServer(ctx)
82+
case <-ctx.Done():
83+
log.Println("Stop listening for a signal")
84+
}
85+
}()
86+
87+
select {
88+
case err := <-s.waitServeComplete:
89+
return err
90+
case <-ctx.Done():
91+
return errors.New("server stop timeout elapsed")
92+
}
93+
}
94+
95+
func (s *GRPCServerWrapper) stopServer(ctx context.Context) {
96+
log.Println("Stop server")
97+
98+
stopped := make(chan struct{}, 1)
99+
go func() {
100+
s.server.GracefulStop()
101+
close(stopped)
102+
}()
103+
104+
select {
105+
case <-ctx.Done():
106+
log.Println("Server stop context done")
107+
s.server.Stop()
108+
case <-stopped:
109+
log.Println("GracefulStop completed")
110+
}
111+
}

pkg/server/grpcserver_test.go

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// Copyright (C) 2023 Intel Corporation
3+
4+
// Package server implements the server
5+
package server
6+
7+
import (
8+
"context"
9+
"errors"
10+
"log"
11+
"net"
12+
"os"
13+
"sync"
14+
"syscall"
15+
"testing"
16+
"time"
17+
18+
"google.golang.org/grpc"
19+
"google.golang.org/grpc/credentials/insecure"
20+
"google.golang.org/grpc/test/bufconn"
21+
22+
pb "github.com/opiproject/opi-api/storage/v1alpha1/gen/go"
23+
)
24+
25+
const timeout = 50 * time.Millisecond
26+
27+
type TestServer struct {
28+
pb.MiddleendEncryptionServiceServer
29+
wait time.Duration
30+
startedHandlingCall sync.WaitGroup
31+
}
32+
33+
func (b *TestServer) CreateEncryptedVolume(_ context.Context, _ *pb.CreateEncryptedVolumeRequest) (*pb.EncryptedVolume, error) {
34+
b.startedHandlingCall.Done()
35+
time.Sleep(b.wait)
36+
return &pb.EncryptedVolume{}, nil
37+
}
38+
39+
type testEnv struct {
40+
testServer *TestServer
41+
client pb.MiddleendEncryptionServiceClient
42+
conn *grpc.ClientConn
43+
ln net.Listener
44+
grpcServer *grpc.Server
45+
}
46+
47+
func (e *testEnv) Close() {
48+
CloseGrpcConnection(e.conn)
49+
CloseListener(e.ln)
50+
}
51+
52+
func createTestEnvironment(callTime time.Duration) *testEnv {
53+
env := &testEnv{}
54+
env.testServer = &TestServer{
55+
pb.UnimplementedMiddleendEncryptionServiceServer{},
56+
callTime,
57+
sync.WaitGroup{},
58+
}
59+
env.grpcServer = grpc.NewServer()
60+
listener := bufconn.Listen(1024 * 1024)
61+
env.ln = listener
62+
pb.RegisterMiddleendEncryptionServiceServer(env.grpcServer, env.testServer)
63+
64+
ctx := context.Background()
65+
conn, err := grpc.DialContext(ctx,
66+
"",
67+
grpc.WithTransportCredentials(insecure.NewCredentials()),
68+
grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
69+
return listener.Dial()
70+
}))
71+
if err != nil {
72+
log.Fatal(err)
73+
}
74+
env.conn = conn
75+
env.client = pb.NewMiddleendEncryptionServiceClient(env.conn)
76+
77+
return env
78+
}
79+
80+
func TestGRPCWrapperWait(t *testing.T) {
81+
tests := map[string]struct {
82+
callTime time.Duration
83+
wantErr bool
84+
serve func(*grpc.Server, net.Listener) error
85+
}{
86+
"server stop timeout": {
87+
callTime: timeout * 2,
88+
wantErr: true,
89+
},
90+
"successful server stop": {
91+
callTime: timeout / 10,
92+
wantErr: false,
93+
},
94+
}
95+
for testName, tt := range tests {
96+
t.Run(testName, func(t *testing.T) {
97+
testEnv := createTestEnvironment(tt.callTime)
98+
defer testEnv.Close()
99+
testEnv.testServer.startedHandlingCall.Add(1)
100+
101+
serverWrapper := NewGRPCServerWrapper(timeout, testEnv.grpcServer, testEnv.ln)
102+
// use rare signal in order not to catch a real interrupt
103+
serverWrapper.signalsToWait = []os.Signal{syscall.SIGILL}
104+
serverWrapper.RunAsync()
105+
106+
var (
107+
clientResponse *pb.EncryptedVolume
108+
clientErr error
109+
)
110+
clientDone := sync.WaitGroup{}
111+
clientDone.Add(1)
112+
go func() {
113+
clientResponse, clientErr = testEnv.client.CreateEncryptedVolume(
114+
context.Background(), &pb.CreateEncryptedVolumeRequest{})
115+
clientDone.Done()
116+
}()
117+
testEnv.testServer.startedHandlingCall.Wait()
118+
119+
serverWrapper.waitSignal <- os.Interrupt
120+
waitErr := serverWrapper.Wait()
121+
122+
if (waitErr != nil) != tt.wantErr {
123+
t.Errorf("Expected elapsed: %v. received: %v", tt.wantErr, waitErr)
124+
}
125+
clientDone.Wait()
126+
if (clientErr != nil) != tt.wantErr {
127+
t.Errorf("Expected error %v, received: %v", tt.wantErr, clientErr)
128+
}
129+
if (clientResponse == nil) != tt.wantErr {
130+
t.Errorf("Expected not nil response %v, received: %v", tt.wantErr, clientResponse)
131+
}
132+
})
133+
}
134+
135+
t.Run("failed serve", func(t *testing.T) {
136+
testEnv := createTestEnvironment(timeout)
137+
defer testEnv.Close()
138+
serverWrapper := NewGRPCServerWrapper(timeout, testEnv.grpcServer, testEnv.ln)
139+
// use rare signal in order not to catch a real interrupt
140+
serverWrapper.signalsToWait = []os.Signal{syscall.SIGILL}
141+
stubErr := errors.New("some serve error")
142+
serverWrapper.serve = func(s *grpc.Server, l net.Listener) error { return stubErr }
143+
serverWrapper.RunAsync()
144+
145+
waitErr := serverWrapper.Wait()
146+
147+
if waitErr != stubErr {
148+
t.Errorf("Expected error: %v, received: %v", stubErr, waitErr)
149+
}
150+
})
151+
152+
t.Run("failed serve after signal received", func(t *testing.T) {
153+
testEnv := createTestEnvironment(timeout)
154+
defer testEnv.Close()
155+
serverWrapper := NewGRPCServerWrapper(timeout, testEnv.grpcServer, testEnv.ln)
156+
// use rare signal in order not to catch a real interrupt
157+
serverWrapper.signalsToWait = []os.Signal{syscall.SIGILL}
158+
stubErr := errors.New("some serve error")
159+
wg := sync.WaitGroup{}
160+
wg.Add(1)
161+
serverWrapper.serve = func(s *grpc.Server, l net.Listener) error {
162+
wg.Wait()
163+
return stubErr
164+
}
165+
serverWrapper.RunAsync()
166+
go func() {
167+
serverWrapper.waitSignal <- os.Interrupt
168+
time.Sleep(timeout / 10)
169+
wg.Done()
170+
}()
171+
172+
waitErr := serverWrapper.Wait()
173+
174+
if waitErr != stubErr {
175+
t.Errorf("Expected error: %v, received: %v", stubErr, waitErr)
176+
}
177+
})
178+
}
179+
180+
func TestNewGRPCWrapper(t *testing.T) {
181+
tests := map[string]struct {
182+
timeout time.Duration
183+
grpcServer *grpc.Server
184+
listener net.Listener
185+
wantPanic bool
186+
}{
187+
"zero timeout": {
188+
timeout: 0,
189+
grpcServer: grpc.NewServer(),
190+
listener: bufconn.Listen(32),
191+
wantPanic: true,
192+
},
193+
"nil grpc server": {
194+
timeout: timeout,
195+
grpcServer: nil,
196+
listener: bufconn.Listen(32),
197+
wantPanic: true,
198+
},
199+
"nil listener": {
200+
timeout: timeout,
201+
grpcServer: grpc.NewServer(),
202+
listener: nil,
203+
wantPanic: true,
204+
},
205+
"successful wrapper creation": {
206+
timeout: timeout,
207+
grpcServer: grpc.NewServer(),
208+
listener: bufconn.Listen(32),
209+
wantPanic: false,
210+
},
211+
}
212+
for testName, tt := range tests {
213+
t.Run(testName, func(t *testing.T) {
214+
defer func() {
215+
r := recover()
216+
if (r != nil) != tt.wantPanic {
217+
t.Errorf("GRPCServerWrapper.Run() recover = %v, wantPanic = %v", r, tt.wantPanic)
218+
}
219+
}()
220+
221+
wrapper := NewGRPCServerWrapper(tt.timeout, tt.grpcServer, tt.listener)
222+
if !tt.wantPanic && wrapper == nil {
223+
t.Error("Expect not nil wrapper")
224+
}
225+
})
226+
}
227+
}

0 commit comments

Comments
 (0)