@@ -220,6 +220,43 @@ func TestClient_Connect(t *testing.T) {
220
220
return atomic .LoadInt32 (& onClosedCalled ) == 1
221
221
}, 100 * time .Millisecond , 20 * time .Millisecond , "onClose should be called" )
222
222
})
223
+
224
+ t .Run ("when OnCloseCtx returns error, we still close the connection" , func (t * testing.T ) {
225
+ server , err := NewTestServer ()
226
+ require .NoError (t , err )
227
+ defer server .Close ()
228
+
229
+ var onClosedCalled int32
230
+ onCloseCtx := func (ctx context.Context , c * connection.Connection ) error {
231
+ // increase the counter
232
+ atomic .AddInt32 (& onClosedCalled , 1 )
233
+ return errors .New ("error from on close handler" )
234
+ }
235
+
236
+ var onErrCalled int32
237
+ errHandler := func (err error ) {
238
+ atomic .AddInt32 (& onErrCalled , 1 )
239
+ require .Contains (t , err .Error (), "error from on close handler" )
240
+ }
241
+
242
+ c , err := connection .New (
243
+ server .Addr ,
244
+ testSpec ,
245
+ readMessageLength ,
246
+ writeMessageLength ,
247
+ connection .ErrorHandler (errHandler ),
248
+ connection .OnCloseCtx (onCloseCtx ),
249
+ )
250
+ require .NoError (t , err )
251
+
252
+ err = c .CloseCtx (context .Background ())
253
+ require .NoError (t , err )
254
+
255
+ // eventually the onClosedCalled should be 1
256
+ require .Eventually (t , func () bool {
257
+ return atomic .LoadInt32 (& onClosedCalled ) == 1
258
+ }, 100 * time .Millisecond , 20 * time .Millisecond , "onClose should be called" )
259
+ })
223
260
}
224
261
225
262
func TestClient_Write (t * testing.T ) {
0 commit comments