Skip to content

Commit ab63930

Browse files
committed
feat: added context.Context to each handler
feat: small code cleanup
1 parent 032089c commit ab63930

File tree

7 files changed

+52
-51
lines changed

7 files changed

+52
-51
lines changed

autobahn/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ func main() {
1212
r := http.NewServeMux()
1313
r.HandleFunc("/", wsServer.Handler)
1414

15-
wsServer.On("echo", func(c *websocket.Conn, msg *websocket.Message) {
15+
wsServer.On("echo", func(ctx context.Context, c *websocket.Conn, msg *websocket.Message) {
1616
_ = c.Emit("echo", msg.Data)
1717
})
1818

channel.go

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,11 @@ func newChannel(id string) *Channel {
2121
}
2222

2323
go func() {
24-
for {
25-
select {
26-
case conn := <-c.delConn:
27-
c.mu.Lock()
28-
_ = conn.Close()
29-
delete(c.connections, conn)
30-
c.mu.Unlock()
31-
}
24+
for conn := range c.delConn {
25+
c.mu.Lock()
26+
_ = conn.Close()
27+
delete(c.connections, conn)
28+
c.mu.Unlock()
3229
}
3330
}()
3431

channel_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ func TestChannel_Add(t *testing.T) {
1818

1919
ch := wsServer.NewChannel("test-channel-add")
2020

21-
wsServer.OnConnect(func(c *Conn) {
21+
wsServer.OnConnect(func(ctx context.Context, c *Conn) {
2222
ch.Add(c)
2323
require.Equal(t, 1, ch.Count(), "channel must contain only 1 connection")
2424
})
@@ -44,7 +44,7 @@ func TestChannel_Emit(t *testing.T) {
4444
messageBytes, err := json.Marshal(_message)
4545
require.NoError(t, err)
4646

47-
wsServer.OnConnect(func(c *Conn) {
47+
wsServer.OnConnect(func(ctx context.Context, c *Conn) {
4848
ch.Add(c)
4949
time.Sleep(300 * time.Millisecond)
5050
ch.Emit(_message.Name, _message.Data)
@@ -78,7 +78,7 @@ func TestChannel_Remove(t *testing.T) {
7878

7979
ch := wsServer.NewChannel("test-channel-add")
8080

81-
wsServer.OnConnect(func(c *Conn) {
81+
wsServer.OnConnect(func(ctx context.Context, c *Conn) {
8282
ch.Add(c)
8383
require.Equal(t, 1, ch.Count(), "channel must contain only 1 connection")
8484
ch.Remove(c)

conn.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ var pingHeader = ws.Header{
2626
}
2727

2828
var PingInterval = time.Second * 5
29-
var TextMessage = false
29+
30+
const TextMessage = false
3031

3132
// ID return an connection identifier (could be not unique)
3233
func (c *Conn) ID() string {
@@ -43,7 +44,10 @@ func (c *Conn) Emit(name string, data interface{}) error {
4344
Data: data,
4445
}
4546

46-
b, _ := json.Marshal(msg)
47+
b, err := json.Marshal(msg)
48+
if err != nil {
49+
return err
50+
}
4751

4852
opCode := ws.OpBinary
4953
if TextMessage {
@@ -96,8 +100,7 @@ func (c *Conn) Send(data any) error {
96100
Length: int64(len(b)),
97101
}
98102

99-
err := c.Write(h, b)
100-
return err
103+
return c.Write(h, b)
101104
}
102105

103106
// Close closing websocket connection.

conn_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ func TestConn_Send_bytes(t *testing.T) {
8080
defer shutdown()
8181

8282
msg := []byte{0, 1, 2, 3, 4, 5, 6, 7}
83-
wsServer.OnConnect(func(c *Conn) {
83+
wsServer.OnConnect(func(ctx context.Context, c *Conn) {
8484
time.Sleep(300 * time.Millisecond)
8585
err := c.Send(msg)
8686
require.NoError(t, err)
@@ -111,7 +111,7 @@ func TestConn_Send_struct(t *testing.T) {
111111
}{
112112
Value: "test",
113113
}
114-
wsServer.OnConnect(func(c *Conn) {
114+
wsServer.OnConnect(func(ctx context.Context, c *Conn) {
115115
time.Sleep(300 * time.Millisecond)
116116
err := c.Send(msg)
117117
require.NoError(t, err)

websocket.go

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Echo server:
1616
1717
r.HandleFunc("/ws", wsServer.Handler)
1818
19-
wsServer.On("echo", func(c *websocket.Conn, msg *websocket.Message) {
19+
wsServer.On("echo", func(ctx context.Context, c *websocket.Conn, msg *websocket.Message) {
2020
c.Emit("echo", msg.Data)
2121
})
2222
@@ -38,7 +38,7 @@ Websocket with group:
3838
3939
ch := wsServer.NewChannel("test")
4040
41-
wsServer.OnConnect(func(c *websocket.Conn) {
41+
wsServer.OnConnect(func(ctx context.Context, c *websocket.Conn) {
4242
ch.Add(c)
4343
ch.Emit("connection", []byte("new connection come"))
4444
})
@@ -75,9 +75,9 @@ type Server struct {
7575

7676
delChan []chan *Conn
7777

78-
onConnect func(c *Conn)
79-
onDisconnect func(c *Conn)
80-
onMessage func(c *Conn, h ws.Header, b []byte)
78+
onConnect func(ctx context.Context, c *Conn)
79+
onDisconnect func(ctx context.Context, c *Conn)
80+
onMessage func(ctx context.Context, c *Conn, h ws.Header, b []byte)
8181

8282
done bool
8383
mu sync.RWMutex
@@ -94,7 +94,7 @@ type Message struct {
9494
// HandlerFunc is a type for handle function all function which has callback have this struct
9595
// as first element returns pointer to connection
9696
// its give opportunity to close connection or emit message to exactly this connection.
97-
type HandlerFunc func(c *Conn, msg *Message)
97+
type HandlerFunc func(ctx context.Context, c *Conn, msg *Message)
9898

9999
// New websocket server handler with the provided options.
100100
func New() *Server {
@@ -104,7 +104,7 @@ func New() *Server {
104104
broadcast: make(chan Message),
105105
callbacks: make(map[string]HandlerFunc),
106106
}
107-
srv.onMessage = func(c *Conn, h ws.Header, b []byte) {
107+
srv.onMessage = func(ctx context.Context, c *Conn, h ws.Header, b []byte) {
108108
_ = c.Write(h, b)
109109
}
110110
return srv
@@ -168,6 +168,7 @@ func (s *Server) Shutdown() error {
168168

169169
// Handler get upgrade connection to RFC 6455 and starting listener for it.
170170
func (s *Server) Handler(w http.ResponseWriter, r *http.Request) {
171+
ctx := r.Context()
171172
var params url.Values = nil
172173

173174
conn, _, _, err := ws.UpgradeHTTP(r, w)
@@ -194,7 +195,7 @@ func (s *Server) Handler(w http.ResponseWriter, r *http.Request) {
194195
done: make(chan bool, 1),
195196
}
196197
connection.startPing()
197-
s.addConn(connection)
198+
s.addConn(ctx, connection)
198199

199200
textPending := false
200201

@@ -206,7 +207,7 @@ func (s *Server) Handler(w http.ResponseWriter, r *http.Request) {
206207
header, _ := ws.ReadHeader(conn)
207208
if err = ws.CheckHeader(header, state); err != nil {
208209
log.Printf("drop ws connection: %v", err)
209-
s.dropConn(connection)
210+
s.dropConn(ctx, connection)
210211
break
211212
}
212213

@@ -263,12 +264,12 @@ func (s *Server) Handler(w http.ResponseWriter, r *http.Request) {
263264
if err != nil {
264265
log.Printf("drop ws connection: OpClose (%v)", err)
265266
}
266-
s.dropConn(connection)
267+
s.dropConn(ctx, connection)
267268
break
268269
}
269270

270271
header.Masked = false
271-
if err = s.processMessage(connection, header, payload); err != nil {
272+
if err = s.processMessage(ctx, connection, header, payload); err != nil {
272273
log.Print(err)
273274
}
274275
}
@@ -316,21 +317,21 @@ func (s *Server) Channels() []string {
316317
}
317318

318319
// OnConnect function which will be called when new connections come.
319-
func (s *Server) OnConnect(f func(c *Conn)) {
320+
func (s *Server) OnConnect(f func(ctx context.Context, c *Conn)) {
320321
s.mu.Lock()
321322
s.onConnect = f
322323
s.mu.Unlock()
323324
}
324325

325326
// OnDisconnect function which will be called when new connections come.
326-
func (s *Server) OnDisconnect(f func(c *Conn)) {
327+
func (s *Server) OnDisconnect(f func(ctx context.Context, c *Conn)) {
327328
s.mu.Lock()
328329
s.onDisconnect = f
329330
s.mu.Unlock()
330331
}
331332

332333
// OnMessage handling byte message. This function works as echo by default
333-
func (s *Server) OnMessage(f func(c *Conn, h ws.Header, b []byte)) {
334+
func (s *Server) OnMessage(f func(ctx context.Context, c *Conn, h ws.Header, b []byte)) {
334335
s.mu.Lock()
335336
s.onMessage = f
336337
s.mu.Unlock()
@@ -372,14 +373,14 @@ func (s *Server) IsClosed() bool {
372373
return s.done
373374
}
374375

375-
func (s *Server) processMessage(c *Conn, h ws.Header, b []byte) error {
376+
func (s *Server) processMessage(ctx context.Context, c *Conn, h ws.Header, b []byte) error {
376377
if len(b) == 0 {
377-
s.onMessage(c, h, b)
378+
s.onMessage(ctx, c, h, b)
378379
return nil
379380
}
380381

381382
if h.OpCode != ws.OpBinary && h.OpCode != ws.OpText {
382-
s.onMessage(c, h, b)
383+
s.onMessage(ctx, c, h, b)
383384
return nil
384385
}
385386

@@ -393,30 +394,30 @@ func (s *Server) processMessage(c *Conn, h ws.Header, b []byte) error {
393394
if err != nil {
394395
return err
395396
}
396-
s.callbacks[msg.Name](c, &Message{
397+
s.callbacks[msg.Name](ctx, c, &Message{
397398
Name: msg.Name,
398399
Data: buf,
399400
})
400401
return nil
401402
}
402-
s.onMessage(c, h, b)
403+
s.onMessage(ctx, c, h, b)
403404

404405
return nil
405406
}
406407

407-
func (s *Server) addConn(conn *Conn) {
408-
if !reflect.ValueOf(s.onConnect).IsNil() {
409-
go s.onConnect(conn)
408+
func (s *Server) addConn(ctx context.Context, conn *Conn) {
409+
if s.onConnect != nil {
410+
go s.onConnect(ctx, conn)
410411
}
411412

412413
s.mu.Lock()
413414
s.connections[conn] = true
414415
s.mu.Unlock()
415416
}
416417

417-
func (s *Server) dropConn(conn *Conn) {
418+
func (s *Server) dropConn(ctx context.Context, conn *Conn) {
418419
if !reflect.ValueOf(s.onDisconnect).IsNil() {
419-
go s.onDisconnect(conn)
420+
go s.onDisconnect(ctx, conn)
420421
}
421422

422423
go func() {

websocket_test.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ func TestServer_OnConnect(t *testing.T) {
115115
messageBytes, err := json.Marshal(msg)
116116
require.NoError(t, err)
117117

118-
wsServer.OnConnect(func(c *Conn) {
118+
wsServer.OnConnect(func(ctx context.Context, c *Conn) {
119119
time.Sleep(300 * time.Millisecond)
120120
err := c.Emit(msg.Name, msg.Data)
121121
require.NoError(t, err)
@@ -155,7 +155,7 @@ func TestServer_OnConnect2(t *testing.T) {
155155
Length: int64(len(msg)),
156156
}
157157

158-
wsServer.OnConnect(func(c *Conn) {
158+
wsServer.OnConnect(func(ctx context.Context, c *Conn) {
159159
time.Sleep(300 * time.Millisecond)
160160
err := c.Write(h, msg)
161161
require.NoError(t, err)
@@ -189,7 +189,7 @@ func TestServer_OnDisconnect(t *testing.T) {
189189
Data: []byte("Hello World"),
190190
}
191191

192-
wsServer.OnDisconnect(func(c *Conn) {
192+
wsServer.OnDisconnect(func(ctx context.Context, c *Conn) {
193193
time.Sleep(300 * time.Millisecond)
194194
_ = c.Emit(msg.Name, msg.Data)
195195
done <- true
@@ -232,7 +232,7 @@ func TestServer_OnMessage(t *testing.T) {
232232
msg := []byte("Hello from byte array")
233233

234234
done := make(chan bool, 1)
235-
wsServer.OnMessage(func(c *Conn, h ws.Header, b []byte) {
235+
wsServer.OnMessage(func(ctx context.Context, c *Conn, h ws.Header, b []byte) {
236236
require.Equal(t, msg, b, "response message must be the same as send")
237237
done <- true
238238
})
@@ -283,7 +283,7 @@ func TestServer_On(t *testing.T) {
283283

284284
done := make(chan bool, 1)
285285

286-
wsServer.On("LoL", func(c *Conn, msg *Message) {
286+
wsServer.On("LoL", func(ctx context.Context, c *Conn, msg *Message) {
287287
require.Equal(t, _message.Name, msg.Name)
288288
var respData dataStruct
289289
require.NoError(t, json.Unmarshal(msg.Data, &respData))
@@ -386,7 +386,7 @@ func TestServerListen(t *testing.T) {
386386
messageBytes, err := json.Marshal(message)
387387
require.NoError(t, err)
388388

389-
wsServer.On("echo", func(c *Conn, msg *Message) {
389+
wsServer.On("echo", func(ctx context.Context, c *Conn, msg *Message) {
390390
require.Equal(t, message.Name, msg.Name)
391391
var respData string
392392
require.NoError(t, json.Unmarshal(msg.Data, &respData))
@@ -495,15 +495,15 @@ func TestServer_ConnectionClose(t *testing.T) {
495495
ticker := time.NewTicker(time.Millisecond * 1)
496496
done := make(chan bool, 1)
497497

498-
wsServer.OnConnect(func(c *Conn) {
498+
wsServer.OnConnect(func(ctx context.Context, c *Conn) {
499499
ch.Add(c)
500500
require.Equal(t, 1, ch.Count(), "channel must contain only 1 connection")
501501
log.Print("Connected")
502502
})
503-
wsServer.OnDisconnect(func(c *Conn) {
503+
wsServer.OnDisconnect(func(ctx context.Context, c *Conn) {
504504
log.Print("Disconnected")
505505
})
506-
wsServer.On("test", func(c *Conn, msg *Message) {
506+
wsServer.On("test", func(ctx context.Context, c *Conn, msg *Message) {
507507
log.Printf("message: %s", msg.Name)
508508
})
509509

0 commit comments

Comments
 (0)