@@ -7,10 +7,13 @@ import (
77 "io"
88 "os"
99 "testing"
10+ "time"
1011
1112 "github.com/stretchr/testify/assert"
1213 "github.com/stretchr/testify/require"
1314 "google.golang.org/grpc"
15+ "google.golang.org/grpc/codes"
16+ "google.golang.org/grpc/status"
1417 klog "k8s.io/klog/v2"
1518)
1619
@@ -20,9 +23,28 @@ type mockHandler struct {
2023 called bool
2124}
2225
23- func (h * mockHandler ) Handle (_ context.Context , _ interface {}) (interface {}, error ) {
26+ // Handle simulates the behavior of a gRPC handler with an optional delay.
27+ // If the delay completes before the context expires, it returns "test_response" along with predefined error.
28+ // If the context is canceled or the deadline is exceeded before the delay completes,
29+ // it returns a corresponding gRPC status error instead.
30+ func (h * mockHandler ) Handle (ctx context.Context , _ interface {}, delay time.Duration ) (interface {}, error ) {
2431 h .called = true
25- return "test_response" , h .returnErr
32+
33+ select {
34+ case <- time .After (delay ):
35+ return "test_response" , h .returnErr
36+ case <- ctx .Done ():
37+ var grpcCode codes.Code
38+ switch ctx .Err () {
39+ case context .Canceled :
40+ grpcCode = codes .Canceled
41+ case context .DeadlineExceeded :
42+ grpcCode = codes .DeadlineExceeded
43+ default :
44+ grpcCode = codes .Unknown
45+ }
46+ return nil , status .Error (grpcCode , ctx .Err ().Error ())
47+ }
2648}
2749
2850func TestAPIServerInterceptor (t * testing.T ) {
@@ -61,7 +83,7 @@ func TestAPIServerInterceptor(t *testing.T) {
6183 req ,
6284 info ,
6385 func (ctx context.Context , req interface {}) (interface {}, error ) {
64- return tt .handler .Handle (ctx , req )
86+ return tt .handler .Handle (ctx , req , 0 /*delay*/ )
6587 },
6688 )
6789
@@ -96,7 +118,7 @@ func TestAPIServerInterceptorContextPassing(t *testing.T) {
96118 func (receivedCtx context.Context , req interface {}) (interface {}, error ) {
97119 // Verify context value is passed through
98120 assert .Equal (t , "test_value" , receivedCtx .Value (testContextKey ("test_key" )))
99- return handler .Handle (receivedCtx , req )
121+ return handler .Handle (receivedCtx , req , 0 /*delay*/ )
100122 },
101123 )
102124}
@@ -153,7 +175,7 @@ func TestAPIServerInterceptorLogging(t *testing.T) {
153175 "test_request" ,
154176 info ,
155177 func (receivedCtx context.Context , req interface {}) (interface {}, error ) {
156- return handler .Handle (receivedCtx , req )
178+ return handler .Handle (receivedCtx , req , 0 /*delay*/ )
157179 },
158180 )
159181
@@ -192,3 +214,62 @@ func TestAPIServerInterceptorLogging(t *testing.T) {
192214 })
193215 }
194216}
217+
218+ func TestTimeoutInterceptor (t * testing.T ) {
219+ tests := []struct {
220+ expectedError error
221+ name string
222+ timeout time.Duration
223+ handlerDelay time.Duration
224+ expectedCalled bool
225+ }{
226+ {
227+ name : "handler completes before timeout" ,
228+ timeout : 100 * time .Millisecond ,
229+ handlerDelay : 50 * time .Millisecond ,
230+ expectedError : nil ,
231+ expectedCalled : true ,
232+ },
233+ {
234+ name : "handler exceeds timeout" ,
235+ timeout : 50 * time .Millisecond ,
236+ handlerDelay : 100 * time .Millisecond ,
237+ expectedError : status .Error (codes .DeadlineExceeded , context .DeadlineExceeded .Error ()),
238+ expectedCalled : true ,
239+ },
240+ }
241+
242+ for _ , tt := range tests {
243+ t .Run (tt .name , func (t * testing.T ) {
244+ // Create test context and request
245+ ctx := context .Background ()
246+ req := "test_request"
247+ handler := & mockHandler {}
248+
249+ // Create the interceptor with the specified timeout
250+ interceptor := TimeoutInterceptor (tt .timeout )
251+
252+ // Call the interceptor
253+ resp , err := interceptor (
254+ ctx ,
255+ req ,
256+ & grpc.UnaryServerInfo {FullMethod : "TestTimeoutMethod" },
257+ func (ctx context.Context , req interface {}) (interface {}, error ) {
258+ return handler .Handle (ctx , req , tt .handlerDelay )
259+ },
260+ )
261+
262+ // Verify response and error
263+ if tt .expectedError == nil {
264+ // Verify handler was called
265+ assert .Equal (t , tt .expectedCalled , handler .called , "handler call status should match expected" )
266+
267+ require .NoError (t , err )
268+ assert .Equal (t , "test_response" , resp , "response should match expected" )
269+ } else {
270+ require .Error (t , err )
271+ require .Equal (t , tt .expectedError , err , "A matching error is expected" )
272+ }
273+ })
274+ }
275+ }
0 commit comments