Skip to content

Commit f3558ac

Browse files
committed
feat: Implement connection closing logic and enhance connection tracking
1 parent b352e4d commit f3558ac

2 files changed

Lines changed: 146 additions & 26 deletions

File tree

app/connectiontracker/tracker.go

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ type ConnEntry struct {
2929
lastActivity int64 // atomic, Unix nanosecond timestamp
3030
uplink int64 // atomic, bytes received from client
3131
downlink int64 // atomic, bytes sent to client
32+
closing atomic.Bool
3233
}
3334

3435
// ConnectionInfo is a read-only snapshot of an active connection's state.
@@ -295,6 +296,14 @@ func disconnectInfo(id uint32, entry *ConnEntry) ConnectionInfo {
295296
}
296297
}
297298

299+
func (e *ConnEntry) isClosing() bool {
300+
return e != nil && e.closing.Load()
301+
}
302+
303+
func (e *ConnEntry) beginClose() bool {
304+
return e != nil && e.closing.CompareAndSwap(false, true)
305+
}
306+
298307
// ListAllConnections returns a snapshot of every active connection across all
299308
// Tracker instances that were created by NewTracker.
300309
func (m *Manager) ListAllConnections() []ConnectionInfo {
@@ -315,7 +324,9 @@ func (m *Manager) GetUserStats(email string) (uplink, downlink int64, connCount
315324
for _, e := range t.conns[email] {
316325
uplink += atomic.LoadInt64(&e.uplink)
317326
downlink += atomic.LoadInt64(&e.downlink)
318-
connCount++
327+
if !e.isClosing() {
328+
connCount++
329+
}
319330
}
320331
t.mu.Unlock()
321332
}
@@ -393,19 +404,16 @@ func (t *Tracker) Unregister(email string, id uint32) {
393404
// CancelAll cancels every active connection belonging to email.
394405
func (t *Tracker) CancelAll(email string) {
395406
t.mu.Lock()
396-
entries := t.conns[email]
397-
delete(t.conns, email)
398-
for id := range entries {
399-
delete(t.byID, id)
407+
entries := make(map[uint32]*ConnEntry)
408+
for id, entry := range t.conns[email] {
409+
entries[id] = entry
400410
}
401411
t.mu.Unlock()
402412

403-
for id, entry := range entries {
404-
t.manager.emit(WatchEvent{
405-
Connected: false,
406-
Info: disconnectInfo(id, entry),
407-
})
408-
entry.Cancel()
413+
for _, entry := range entries {
414+
if entry.beginClose() && entry.Cancel != nil {
415+
entry.Cancel()
416+
}
409417
}
410418
}
411419

@@ -414,22 +422,9 @@ func (t *Tracker) CancelAll(email string) {
414422
func (t *Tracker) CloseConn(id uint32) bool {
415423
t.mu.Lock()
416424
entry, ok := t.byID[id]
417-
if ok {
418-
delete(t.byID, id)
419-
if m := t.conns[entry.Email]; m != nil {
420-
delete(m, id)
421-
if len(m) == 0 {
422-
delete(t.conns, entry.Email)
423-
}
424-
}
425-
}
426425
t.mu.Unlock()
427426

428-
if ok {
429-
t.manager.emit(WatchEvent{
430-
Connected: false,
431-
Info: disconnectInfo(id, entry),
432-
})
427+
if ok && entry.beginClose() && entry.Cancel != nil {
433428
entry.Cancel()
434429
}
435430
return ok
@@ -438,7 +433,12 @@ func (t *Tracker) CloseConn(id uint32) bool {
438433
// GetConnCount returns the number of active connections for email.
439434
func (t *Tracker) GetConnCount(email string) int {
440435
t.mu.Lock()
441-
n := len(t.conns[email])
436+
n := 0
437+
for _, entry := range t.conns[email] {
438+
if !entry.isClosing() {
439+
n++
440+
}
441+
}
442442
t.mu.Unlock()
443443
return n
444444
}
@@ -448,6 +448,9 @@ func (t *Tracker) ListConnections() []ConnectionInfo {
448448
t.mu.Lock()
449449
result := make([]ConnectionInfo, 0, len(t.byID))
450450
for id, entry := range t.byID {
451+
if entry.isClosing() {
452+
continue
453+
}
451454
info := disconnectInfo(id, entry)
452455
info.Uplink = atomic.LoadInt64(&entry.uplink)
453456
info.Downlink = atomic.LoadInt64(&entry.downlink)

app/connectiontracker/tracker_test.go

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package connectiontracker_test
22

33
import (
4+
"errors"
45
"net"
56
"sync"
67
"sync/atomic"
@@ -268,6 +269,10 @@ func TestCloseConnCancelsAndRemoves(t *testing.T) {
268269
if atomic.LoadInt32(&cancelled) != 1 {
269270
t.Error("CloseConn: cancel function was not called")
270271
}
272+
if len(tracker.ListConnections()) != 0 {
273+
t.Error("closing connection should not appear in active list")
274+
}
275+
tracker.Unregister("user@example.com", id)
271276
if len(tracker.ListConnections()) != 0 {
272277
t.Error("connection still present after CloseConn")
273278
}
@@ -327,6 +332,92 @@ func TestGetConnCountDecreasesAfterUnregister(t *testing.T) {
327332
}
328333
}
329334

335+
func TestCloseConnKeepsStatsUntilUnregister(t *testing.T) {
336+
manager := connectiontracker.NewManager()
337+
tracker := manager.NewTracker()
338+
339+
id, entry := tracker.RegisterWithMeta("user@example.com", func() {}, "", "")
340+
conn := connectiontracker.WrapConn(&fakeConn{readData: make([]byte, 10)}, entry)
341+
if _, err := conn.Read(make([]byte, 10)); err != nil {
342+
t.Fatalf("read before close failed: %v", err)
343+
}
344+
345+
if ok := tracker.CloseConn(id); !ok {
346+
t.Fatal("CloseConn: expected true for existing connection")
347+
}
348+
349+
if _, err := conn.Write(make([]byte, 5)); err != nil {
350+
t.Fatalf("write after close failed: %v", err)
351+
}
352+
353+
uplink, downlink, connCount := manager.GetUserStats("user@example.com")
354+
if uplink != 10 || downlink != 5 {
355+
t.Fatalf("GetUserStats after close: got uplink=%d downlink=%d, want 10/5", uplink, downlink)
356+
}
357+
if connCount != 0 {
358+
t.Fatalf("GetUserStats connCount after close: got %d, want 0", connCount)
359+
}
360+
361+
tracker.Unregister("user@example.com", id)
362+
}
363+
364+
func TestDisconnectEventUsesFinalCountersAfterForcedClose(t *testing.T) {
365+
manager := connectiontracker.NewManager()
366+
tracker := manager.NewTracker()
367+
ch := manager.Subscribe()
368+
defer manager.Unsubscribe(ch)
369+
370+
id, entry := tracker.RegisterWithMeta("user@example.com", func() {}, "", "")
371+
<-ch // connected event
372+
373+
conn := connectiontracker.WrapConn(&fakeConn{readData: make([]byte, 7)}, entry)
374+
if _, err := conn.Read(make([]byte, 7)); err != nil {
375+
t.Fatalf("read before close failed: %v", err)
376+
}
377+
378+
if ok := tracker.CloseConn(id); !ok {
379+
t.Fatal("CloseConn: expected true for existing connection")
380+
}
381+
if _, err := conn.Write(make([]byte, 9)); err != nil {
382+
t.Fatalf("write after close failed: %v", err)
383+
}
384+
385+
tracker.Unregister("user@example.com", id)
386+
387+
select {
388+
case ev := <-ch:
389+
if ev.Connected {
390+
t.Fatal("expected disconnect event")
391+
}
392+
if ev.Info.Uplink != 7 || ev.Info.Downlink != 9 {
393+
t.Fatalf("disconnect counters: got uplink=%d downlink=%d, want 7/9", ev.Info.Uplink, ev.Info.Downlink)
394+
}
395+
case <-time.After(time.Second):
396+
t.Fatal("timed out waiting for disconnect event")
397+
}
398+
}
399+
400+
func TestRepeatedForcedCloseIsIdempotent(t *testing.T) {
401+
tracker := connectiontracker.New()
402+
403+
var cancelled int32
404+
id, _ := tracker.RegisterWithMeta("user@example.com", func() {
405+
atomic.AddInt32(&cancelled, 1)
406+
}, "", "")
407+
408+
if !tracker.CloseConn(id) {
409+
t.Fatal("first CloseConn should find connection")
410+
}
411+
if !tracker.CloseConn(id) {
412+
t.Fatal("second CloseConn should still find tracked connection before Unregister")
413+
}
414+
tracker.CancelAll("user@example.com")
415+
416+
if got := atomic.LoadInt32(&cancelled); got != 1 {
417+
t.Fatalf("cancel count: got %d, want 1", got)
418+
}
419+
}
420+
330421
func TestListConnectionsMetadataFields(t *testing.T) {
331422
tracker := connectiontracker.New()
332423

@@ -558,3 +649,29 @@ func TestWrapPacketConnUpdatesLastActivity(t *testing.T) {
558649
t.Errorf("LastActivity not updated: before=%v after=%v", before, after)
559650
}
560651
}
652+
653+
func TestWrapConnCountsBytesBeforeErrors(t *testing.T) {
654+
tracker := connectiontracker.New()
655+
_, entry := tracker.RegisterWithMeta("user@example.com", func() {}, "", "")
656+
657+
wrapped := connectiontracker.WrapConn(&fakeConn{
658+
readData: []byte("hello"),
659+
readErr: errors.New("read stopped"),
660+
writeErr: errors.New("write stopped"),
661+
}, entry)
662+
663+
if _, err := wrapped.Read(make([]byte, 5)); err == nil {
664+
t.Fatal("expected read error")
665+
}
666+
if _, err := wrapped.Write([]byte("bye")); err == nil {
667+
t.Fatal("expected write error")
668+
}
669+
670+
conns := tracker.ListConnections()
671+
if len(conns) != 1 {
672+
t.Fatalf("expected 1 connection")
673+
}
674+
if conns[0].Uplink != 5 || conns[0].Downlink != 3 {
675+
t.Fatalf("counters with errors: got uplink=%d downlink=%d, want 5/3", conns[0].Uplink, conns[0].Downlink)
676+
}
677+
}

0 commit comments

Comments
 (0)