|
1 | 1 | package connectiontracker_test |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "errors" |
4 | 5 | "net" |
5 | 6 | "sync" |
6 | 7 | "sync/atomic" |
@@ -268,6 +269,10 @@ func TestCloseConnCancelsAndRemoves(t *testing.T) { |
268 | 269 | if atomic.LoadInt32(&cancelled) != 1 { |
269 | 270 | t.Error("CloseConn: cancel function was not called") |
270 | 271 | } |
| 272 | + if len(tracker.ListConnections()) != 0 { |
| 273 | + t.Error("closing connection should not appear in active list") |
| 274 | + } |
| 275 | + tracker.Unregister("user@example.com", id) |
271 | 276 | if len(tracker.ListConnections()) != 0 { |
272 | 277 | t.Error("connection still present after CloseConn") |
273 | 278 | } |
@@ -327,6 +332,92 @@ func TestGetConnCountDecreasesAfterUnregister(t *testing.T) { |
327 | 332 | } |
328 | 333 | } |
329 | 334 |
|
| 335 | +func TestCloseConnKeepsStatsUntilUnregister(t *testing.T) { |
| 336 | + manager := connectiontracker.NewManager() |
| 337 | + tracker := manager.NewTracker() |
| 338 | + |
| 339 | + id, entry := tracker.RegisterWithMeta("user@example.com", func() {}, "", "") |
| 340 | + conn := connectiontracker.WrapConn(&fakeConn{readData: make([]byte, 10)}, entry) |
| 341 | + if _, err := conn.Read(make([]byte, 10)); err != nil { |
| 342 | + t.Fatalf("read before close failed: %v", err) |
| 343 | + } |
| 344 | + |
| 345 | + if ok := tracker.CloseConn(id); !ok { |
| 346 | + t.Fatal("CloseConn: expected true for existing connection") |
| 347 | + } |
| 348 | + |
| 349 | + if _, err := conn.Write(make([]byte, 5)); err != nil { |
| 350 | + t.Fatalf("write after close failed: %v", err) |
| 351 | + } |
| 352 | + |
| 353 | + uplink, downlink, connCount := manager.GetUserStats("user@example.com") |
| 354 | + if uplink != 10 || downlink != 5 { |
| 355 | + t.Fatalf("GetUserStats after close: got uplink=%d downlink=%d, want 10/5", uplink, downlink) |
| 356 | + } |
| 357 | + if connCount != 0 { |
| 358 | + t.Fatalf("GetUserStats connCount after close: got %d, want 0", connCount) |
| 359 | + } |
| 360 | + |
| 361 | + tracker.Unregister("user@example.com", id) |
| 362 | +} |
| 363 | + |
| 364 | +func TestDisconnectEventUsesFinalCountersAfterForcedClose(t *testing.T) { |
| 365 | + manager := connectiontracker.NewManager() |
| 366 | + tracker := manager.NewTracker() |
| 367 | + ch := manager.Subscribe() |
| 368 | + defer manager.Unsubscribe(ch) |
| 369 | + |
| 370 | + id, entry := tracker.RegisterWithMeta("user@example.com", func() {}, "", "") |
| 371 | + <-ch // connected event |
| 372 | + |
| 373 | + conn := connectiontracker.WrapConn(&fakeConn{readData: make([]byte, 7)}, entry) |
| 374 | + if _, err := conn.Read(make([]byte, 7)); err != nil { |
| 375 | + t.Fatalf("read before close failed: %v", err) |
| 376 | + } |
| 377 | + |
| 378 | + if ok := tracker.CloseConn(id); !ok { |
| 379 | + t.Fatal("CloseConn: expected true for existing connection") |
| 380 | + } |
| 381 | + if _, err := conn.Write(make([]byte, 9)); err != nil { |
| 382 | + t.Fatalf("write after close failed: %v", err) |
| 383 | + } |
| 384 | + |
| 385 | + tracker.Unregister("user@example.com", id) |
| 386 | + |
| 387 | + select { |
| 388 | + case ev := <-ch: |
| 389 | + if ev.Connected { |
| 390 | + t.Fatal("expected disconnect event") |
| 391 | + } |
| 392 | + if ev.Info.Uplink != 7 || ev.Info.Downlink != 9 { |
| 393 | + t.Fatalf("disconnect counters: got uplink=%d downlink=%d, want 7/9", ev.Info.Uplink, ev.Info.Downlink) |
| 394 | + } |
| 395 | + case <-time.After(time.Second): |
| 396 | + t.Fatal("timed out waiting for disconnect event") |
| 397 | + } |
| 398 | +} |
| 399 | + |
| 400 | +func TestRepeatedForcedCloseIsIdempotent(t *testing.T) { |
| 401 | + tracker := connectiontracker.New() |
| 402 | + |
| 403 | + var cancelled int32 |
| 404 | + id, _ := tracker.RegisterWithMeta("user@example.com", func() { |
| 405 | + atomic.AddInt32(&cancelled, 1) |
| 406 | + }, "", "") |
| 407 | + |
| 408 | + if !tracker.CloseConn(id) { |
| 409 | + t.Fatal("first CloseConn should find connection") |
| 410 | + } |
| 411 | + if !tracker.CloseConn(id) { |
| 412 | + t.Fatal("second CloseConn should still find tracked connection before Unregister") |
| 413 | + } |
| 414 | + tracker.CancelAll("user@example.com") |
| 415 | + |
| 416 | + if got := atomic.LoadInt32(&cancelled); got != 1 { |
| 417 | + t.Fatalf("cancel count: got %d, want 1", got) |
| 418 | + } |
| 419 | +} |
| 420 | + |
330 | 421 | func TestListConnectionsMetadataFields(t *testing.T) { |
331 | 422 | tracker := connectiontracker.New() |
332 | 423 |
|
@@ -558,3 +649,29 @@ func TestWrapPacketConnUpdatesLastActivity(t *testing.T) { |
558 | 649 | t.Errorf("LastActivity not updated: before=%v after=%v", before, after) |
559 | 650 | } |
560 | 651 | } |
| 652 | + |
| 653 | +func TestWrapConnCountsBytesBeforeErrors(t *testing.T) { |
| 654 | + tracker := connectiontracker.New() |
| 655 | + _, entry := tracker.RegisterWithMeta("user@example.com", func() {}, "", "") |
| 656 | + |
| 657 | + wrapped := connectiontracker.WrapConn(&fakeConn{ |
| 658 | + readData: []byte("hello"), |
| 659 | + readErr: errors.New("read stopped"), |
| 660 | + writeErr: errors.New("write stopped"), |
| 661 | + }, entry) |
| 662 | + |
| 663 | + if _, err := wrapped.Read(make([]byte, 5)); err == nil { |
| 664 | + t.Fatal("expected read error") |
| 665 | + } |
| 666 | + if _, err := wrapped.Write([]byte("bye")); err == nil { |
| 667 | + t.Fatal("expected write error") |
| 668 | + } |
| 669 | + |
| 670 | + conns := tracker.ListConnections() |
| 671 | + if len(conns) != 1 { |
| 672 | + t.Fatalf("expected 1 connection") |
| 673 | + } |
| 674 | + if conns[0].Uplink != 5 || conns[0].Downlink != 3 { |
| 675 | + t.Fatalf("counters with errors: got uplink=%d downlink=%d, want 5/3", conns[0].Uplink, conns[0].Downlink) |
| 676 | + } |
| 677 | +} |
0 commit comments