diff --git a/.golangci.yml b/.golangci.yml index 06650621..6f7105d0 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -19,12 +19,16 @@ linters-settings: recommendations: - errors forbidigo: + analyze-types: true forbid: - ^fmt.Print(f|ln)?$ - ^log.(Panic|Fatal|Print)(f|ln)?$ - ^os.Exit$ - ^panic$ - ^print(ln)?$ + - p: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$ + pkg: ^testing$ + msg: "use testify/assert instead" varnamelen: max-distance: 12 min-name-length: 2 @@ -127,9 +131,12 @@ issues: exclude-dirs-use-default: false exclude-rules: # Allow complex tests and examples, better to be self contained - - path: (examples|main\.go|_test\.go) + - path: (examples|main\.go) linters: + - gocognit - forbidigo + - path: _test\.go + linters: - gocognit # Allow forbidden identifiers in CLI commands diff --git a/client_test.go b/client_test.go index 2c17a900..0cff8edc 100644 --- a/client_test.go +++ b/client_test.go @@ -108,15 +108,11 @@ func TestClientWithSTUN(t *testing.T) { // Block until go routine is started to make two almost parallel requests <-started - if _, err = client.SendBindingRequestTo(to); err != nil { - t.Fatal(err) - } + _, err = client.SendBindingRequestTo(to) + assert.NoError(t, err) <-finished - if err1 != nil { - t.Fatal(err) - } - + assert.NoErrorf(t, err1, "should succeed: %v", err) assert.NoError(t, pc.Close()) }) @@ -244,8 +240,7 @@ func TestTCPClient(t *testing.T) { peerAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:12345") require.NoError(t, err) - err = client.CreatePermission(peerAddr) - require.NoError(t, err) + require.NoError(t, client.CreatePermission(peerAddr)) var cid proto.ConnectionID = 5 transactionID := [stun.TransactionIDSize]byte{1, 2, 3} @@ -259,8 +254,7 @@ func TestTCPClient(t *testing.T) { msg, err := stun.Build(attrs...) require.NoError(t, err) - err = client.handleSTUNMessage(msg.Raw, peerAddr) - require.NoError(t, err) + require.NoError(t, client.handleSTUNMessage(msg.Raw, peerAddr)) // Shutdown require.NoError(t, allocation.Close()) diff --git a/internal/allocation/allocation_manager_test.go b/internal/allocation/allocation_manager_test.go index 5a68cda2..1beb8cd5 100644 --- a/internal/allocation/allocation_manager_test.go +++ b/internal/allocation/allocation_manager_test.go @@ -35,9 +35,7 @@ func TestManager(t *testing.T) { network := "udp4" turnSocket, err := net.ListenPacket(network, "0.0.0.0:0") - if err != nil { - panic(err) - } + assert.NoError(t, err) for _, tc := range tt { f := tc.f @@ -54,15 +52,17 @@ func subTestCreateInvalidAllocation(t *testing.T, turnSocket net.PacketConn) { m, err := newTestManager() assert.NoError(t, err) - if a, err := m.CreateAllocation(nil, turnSocket, 0, proto.DefaultLifetime); a != nil || err == nil { - t.Errorf("Illegally created allocation with nil FiveTuple") - } - if a, err := m.CreateAllocation(randomFiveTuple(), nil, 0, proto.DefaultLifetime); a != nil || err == nil { - t.Errorf("Illegally created allocation with nil turnSocket") - } - if a, err := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, 0); a != nil || err == nil { - t.Errorf("Illegally created allocation with 0 lifetime") - } + a, err := m.CreateAllocation(nil, turnSocket, 0, proto.DefaultLifetime) + assert.Nil(t, a, "Illegally created allocation with nil FiveTuple") + assert.Error(t, err, "Illegally created allocation with nil FiveTuple") + + a, err = m.CreateAllocation(randomFiveTuple(), nil, 0, proto.DefaultLifetime) + assert.Nil(t, a, "Illegally created allocation with nil turnSocket") + assert.Error(t, err, "Illegally created allocation with nil turnSocket") + + a, err = m.CreateAllocation(randomFiveTuple(), turnSocket, 0, 0) + assert.Nil(t, a, "Illegally created allocation with 0 lifetime") + assert.Error(t, err, "Illegally created allocation with 0 lifetime") } // Test valid Allocation creations. @@ -73,13 +73,12 @@ func subTestCreateAllocation(t *testing.T, turnSocket net.PacketConn) { assert.NoError(t, err) fiveTuple := randomFiveTuple() - if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime); a == nil || err != nil { - t.Errorf("Failed to create allocation %v %v", a, err) - } + a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime) + assert.NotNil(t, a, "Failed to create allocation") + assert.NoError(t, err, "Failed to create allocation") - if a := m.GetAllocation(fiveTuple); a == nil { - t.Errorf("Failed to get allocation right after creation") - } + a = m.GetAllocation(fiveTuple) + assert.NotNil(t, a, "Failed to get allocation right after creation") } // Test that two allocations can't be created with the same FiveTuple. @@ -90,13 +89,13 @@ func subTestCreateAllocationDuplicateFiveTuple(t *testing.T, turnSocket net.Pack assert.NoError(t, err) fiveTuple := randomFiveTuple() - if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime); a == nil || err != nil { - t.Errorf("Failed to create allocation %v %v", a, err) - } + a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime) + assert.NotNil(t, a, "Failed to create allocation") + assert.NoError(t, err, "Failed to create allocation") - if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime); a != nil || err == nil { - t.Errorf("Was able to create allocation with same FiveTuple twice") - } + a, err = m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime) + assert.Nil(t, a, "Was able to create allocation with same FiveTuple twice") + assert.Error(t, err, "Was able to create allocation with same FiveTuple twice") } func subTestDeleteAllocation(t *testing.T, turnSocket net.PacketConn) { @@ -106,18 +105,16 @@ func subTestDeleteAllocation(t *testing.T, turnSocket net.PacketConn) { assert.NoError(t, err) fiveTuple := randomFiveTuple() - if a, err := manager.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime); a == nil || err != nil { - t.Errorf("Failed to create allocation %v %v", a, err) - } + a, err := manager.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime) + assert.NotNil(t, a, "Failed to create allocation") + assert.NoError(t, err, "Failed to create allocation") - if a := manager.GetAllocation(fiveTuple); a == nil { - t.Errorf("Failed to get allocation right after creation") - } + a = manager.GetAllocation(fiveTuple) + assert.NotNil(t, a, "Failed to get allocation right after creation") manager.DeleteAllocation(fiveTuple) - if a := manager.GetAllocation(fiveTuple); a != nil { - t.Errorf("Get allocation with %v should be nil after delete", fiveTuple) - } + a = manager.GetAllocation(fiveTuple) + assert.Nilf(t, a, "Failed to delete allocation %v", fiveTuple) } // Test that allocation should be closed if timeout. @@ -134,9 +131,7 @@ func subTestAllocationTimeout(t *testing.T, turnSocket net.PacketConn) { fiveTuple := randomFiveTuple() a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, lifetime) - if err != nil { - t.Errorf("Failed to create allocation with %v", fiveTuple) - } + assert.NoErrorf(t, err, "Failed to create allocation with %v", fiveTuple) allocations[index] = a } @@ -144,9 +139,7 @@ func subTestAllocationTimeout(t *testing.T, turnSocket net.PacketConn) { // Make sure all allocations timeout time.Sleep(lifetime + time.Second) for _, alloc := range allocations { - if !isClose(alloc.RelaySocket) { - t.Error("Allocation relay socket should be closed if lifetime timeout") - } + assert.True(t, isClose(alloc.RelaySocket), "Allocation relay socket should be closed if lifetime timeout") } } @@ -166,15 +159,10 @@ func subTestManagerClose(t *testing.T, turnSocket net.PacketConn) { // Make a1 timeout time.Sleep(2 * time.Second) - - if err := manager.Close(); err != nil { - t.Errorf("Manager close with error: %v", err) - } + assert.NoError(t, manager.Close()) for _, alloc := range allocations { - if !isClose(alloc.RelaySocket) { - t.Error("Manager's allocations should be closed") - } + assert.True(t, isClose(alloc.RelaySocket), "Manager's allocations should be closed") } } diff --git a/internal/allocation/allocation_test.go b/internal/allocation/allocation_test.go index b2c67717..1259faed 100644 --- a/internal/allocation/allocation_test.go +++ b/internal/allocation/allocation_test.go @@ -51,19 +51,13 @@ func subTestGetPermission(t *testing.T) { alloc := NewAllocation(nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") - if err != nil { - t.Fatalf("failed to resolve: %s", err) - } + assert.NoError(t, err) addr2, err := net.ResolveUDPAddr("udp", "127.0.0.1:3479") - if err != nil { - t.Fatalf("failed to resolve: %s", err) - } + assert.NoError(t, err) addr3, err := net.ResolveUDPAddr("udp", "127.0.0.2:3478") - if err != nil { - t.Fatalf("failed to resolve: %s", err) - } + assert.NoError(t, err) perms := &Permission{ Addr: addr, @@ -95,9 +89,7 @@ func subTestAddPermission(t *testing.T) { alloc := NewAllocation(nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") - if err != nil { - t.Fatalf("failed to resolve: %s", err) - } + assert.NoError(t, err) p := &Permission{ Addr: addr, @@ -116,9 +108,7 @@ func subTestRemovePermission(t *testing.T) { alloc := NewAllocation(nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") - if err != nil { - t.Fatalf("failed to resolve: %s", err) - } + assert.NoError(t, err) p := &Permission{ Addr: addr, @@ -141,9 +131,7 @@ func subTestAddChannelBind(t *testing.T) { alloc := NewAllocation(nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") - if err != nil { - t.Fatalf("failed to resolve: %s", err) - } + assert.NoError(t, err) c := NewChannelBind(proto.MinChannelNumber, addr, nil) @@ -167,9 +155,7 @@ func subTestGetChannelByNumber(t *testing.T) { alloc := NewAllocation(nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") - if err != nil { - t.Fatalf("failed to resolve: %s", err) - } + assert.NoError(t, err) c := NewChannelBind(proto.MinChannelNumber, addr, nil) @@ -188,9 +174,7 @@ func subTestGetChannelByAddr(t *testing.T) { alloc := NewAllocation(nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") - if err != nil { - t.Fatalf("failed to resolve: %s", err) - } + assert.NoError(t, err) c := NewChannelBind(proto.MinChannelNumber, addr, nil) @@ -210,9 +194,7 @@ func subTestRemoveChannelBind(t *testing.T) { alloc := NewAllocation(nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") - if err != nil { - t.Fatalf("failed to resolve: %s", err) - } + assert.NoError(t, err) c := NewChannelBind(proto.MinChannelNumber, addr, nil) @@ -250,9 +232,7 @@ func subTestAllocationClose(t *testing.T) { network := "udp" l, err := net.ListenPacket(network, "0.0.0.0:0") - if err != nil { - panic(err) - } + assert.NoError(t, err) alloc := NewAllocation(nil, nil, nil) alloc.RelaySocket = l @@ -261,9 +241,7 @@ func subTestAllocationClose(t *testing.T) { // Add channel addr, err := net.ResolveUDPAddr(network, "127.0.0.1:3478") - if err != nil { - t.Fatalf("failed to resolve: %s", err) - } + assert.NoError(t, err) c := NewChannelBind(proto.MinChannelNumber, addr, nil) _ = alloc.AddChannelBind(c, proto.DefaultLifetime) @@ -271,8 +249,7 @@ func subTestAllocationClose(t *testing.T) { // Add permission alloc.AddPermission(NewPermission(addr, nil)) - err = alloc.Close() - assert.Nil(t, err, "should succeed") + assert.Nil(t, alloc.Close(), "should succeed") assert.True(t, isClose(alloc.RelaySocket), "should be closed") } @@ -285,15 +262,11 @@ func subTestPacketHandler(t *testing.T) { // TURN server initialization turnSocket, err := net.ListenPacket(network, "127.0.0.1:0") - if err != nil { - panic(err) - } + assert.NoError(t, err) // Client listener initialization clientListener, err := net.ListenPacket(network, "127.0.0.1:0") - if err != nil { - panic(err) - } + assert.NoError(t, err) dataCh := make(chan []byte) // Client listener read data @@ -314,17 +287,13 @@ func subTestPacketHandler(t *testing.T) { DstAddr: turnSocket.LocalAddr(), }, turnSocket, 0, proto.DefaultLifetime) - assert.Nil(t, err, "should succeed") + assert.NoError(t, err, "should succeed") peerListener1, err := net.ListenPacket(network, "127.0.0.1:0") - if err != nil { - panic(err) - } + assert.NoError(t, err) peerListener2, err := net.ListenPacket(network, "127.0.0.1:0") - if err != nil { - panic(err) - } + assert.NoError(t, err) // Add permission with peer1 address alloc.AddPermission(NewPermission(peerListener1.LocalAddr(), manager.log)) @@ -345,12 +314,10 @@ func subTestPacketHandler(t *testing.T) { assert.True(t, stun.IsMessage(data), "should be stun message") var msg stun.Message - err = stun.Decode(data, &msg) - assert.Nil(t, err, "decode data to stun message failed") + assert.NoError(t, stun.Decode(data, &msg), "decode data to stun message failed") var msgData proto.Data - err = msgData.GetFrom(&msg) - assert.Nil(t, err, "get data from stun message failed") + assert.NoError(t, msgData.GetFrom(&msg), "get data from stun message failed") assert.Equal(t, targetText, string(msgData), "get message doesn't equal the target text") // Test for channel bind and channel data @@ -364,8 +331,7 @@ func subTestPacketHandler(t *testing.T) { channelData := proto.ChannelData{ Raw: data, } - err = channelData.Decode() - assert.Nil(t, err, fmt.Sprintf("channel data decode with error: %v", err)) + assert.NoError(t, channelData.Decode(), fmt.Sprintf("channel data decode with error: %v", err)) assert.Equal(t, channelBind.Number, channelData.Number, "get channel data's number is invalid") assert.Equal(t, targetText2, string(channelData.Data), "get data doesn't equal the target text.") diff --git a/internal/allocation/channel_bind_test.go b/internal/allocation/channel_bind_test.go index 30e3034a..253c9ab9 100644 --- a/internal/allocation/channel_bind_test.go +++ b/internal/allocation/channel_bind_test.go @@ -9,24 +9,21 @@ import ( "time" "github.com/pion/turn/v4/internal/proto" + "github.com/stretchr/testify/assert" ) func TestChannelBind(t *testing.T) { c := newChannelBind(2 * time.Second) - - if c.allocation.GetChannelByNumber(c.Number) != c { - t.Errorf("GetChannelByNumber(%d) shouldn't be nil after added to allocation", c.Number) - } + assert.Equalf(t, c, c.allocation.GetChannelByNumber(c.Number), + "GetChannelByNumber(%d) shouldn't be nil after added to allocation", c.Number) } func TestChannelBindStart(t *testing.T) { c := newChannelBind(2 * time.Second) time.Sleep(3 * time.Second) - - if c.allocation.GetChannelByNumber(c.Number) != nil { - t.Errorf("GetChannelByNumber(%d) should be nil if timeout", c.Number) - } + assert.Nil(t, c.allocation.GetChannelByNumber(c.Number), + "GetChannelByNumber(%d) should be nil after timeout", c.Number) } func TestChannelBindReset(t *testing.T) { @@ -35,10 +32,8 @@ func TestChannelBindReset(t *testing.T) { time.Sleep(2 * time.Second) c.refresh(3 * time.Second) time.Sleep(2 * time.Second) - - if c.allocation.GetChannelByNumber(c.Number) == nil { - t.Errorf("GetChannelByNumber(%d) shouldn't be nil after refresh", c.Number) - } + assert.NotNil(t, c.allocation.GetChannelByNumber(c.Number), + "GetChannelByNumber(%d) shouldn't be nil after refresh", c.Number) } func newChannelBind(lifetime time.Duration) *ChannelBind { diff --git a/internal/allocation/five_tuple_test.go b/internal/allocation/five_tuple_test.go index 432e95d0..10a080f8 100644 --- a/internal/allocation/five_tuple_test.go +++ b/internal/allocation/five_tuple_test.go @@ -6,19 +6,15 @@ package allocation import ( "net" "testing" + + "github.com/stretchr/testify/assert" ) func TestFiveTupleProtocol(t *testing.T) { udpExpect := Protocol(0) tcpExpect := Protocol(1) - - if udpExpect != UDP { - t.Errorf("Invalid UDP Protocol value, expect %d but %d", udpExpect, UDP) - } - - if tcpExpect != TCP { - t.Errorf("Invalid TCP Protocol value, expect %d but %d", tcpExpect, TCP) - } + assert.Equal(t, UDP, udpExpect) + assert.Equal(t, TCP, tcpExpect) } func TestFiveTupleEqual(t *testing.T) { @@ -67,10 +63,7 @@ func TestFiveTupleEqual(t *testing.T) { t.Run(tc.name, func(t *testing.T) { fact := a.Equal(b) - - if expect != fact { - t.Errorf("%v, %v equal check should be %t, but %t", a, b, expect, fact) - } + assert.Equalf(t, expect, fact, "%v, %v equal check should be %t, but %t", a, b, expect, fact) }) } } diff --git a/internal/proto/addr_test.go b/internal/proto/addr_test.go index 56cc53ca..d21a7828 100644 --- a/internal/proto/addr_test.go +++ b/internal/proto/addr_test.go @@ -7,6 +7,8 @@ import ( "fmt" "net" "testing" + + "github.com/stretchr/testify/assert" ) func TestAddr_FromUDPAddr(t *testing.T) { @@ -16,12 +18,10 @@ func TestAddr_FromUDPAddr(t *testing.T) { } a := new(Addr) a.FromUDPAddr(u) - if !u.IP.Equal(a.IP) || u.Port != a.Port || u.String() != a.String() { - t.Error("not equal") - } - if a.Network() != "turn" { - t.Error("unexpected network") - } + assert.True(t, u.IP.Equal(a.IP)) + assert.Equal(t, u.Port, a.Port) + assert.Equal(t, u.String(), a.String()) + assert.Equal(t, "turn", a.Network()) } func TestAddr_EqualIP(t *testing.T) { @@ -33,12 +33,8 @@ func TestAddr_EqualIP(t *testing.T) { IP: net.IPv4(127, 0, 0, 1), Port: 1338, } - if a.Equal(b) { - t.Error("a != b") - } - if !a.EqualIP(b) { - t.Error("a.IP should equal to b.IP") - } + assert.False(t, a.Equal(b)) + assert.True(t, a.EqualIP(b)) } func TestFiveTuple_Equal(t *testing.T) { @@ -74,11 +70,7 @@ func TestFiveTuple_Equal(t *testing.T) { }, }, } { - if v := tc.a.Equal(tc.b); v != tc.v { - t.Errorf("(%s) %s [%v!=%v] %s", - tc.name, tc.a, v, tc.v, tc.b, - ) - } + assert.Equal(t, tc.v, tc.a.Equal(tc.b), "(%s) %s [%v!=%v] %s", tc.name, tc.a, tc.v, tc.b, tc.b) } } @@ -94,7 +86,5 @@ func TestFiveTuple_String(t *testing.T) { IP: net.IPv4(127, 0, 0, 1), }, }) - if s != "127.0.0.1:200->127.0.0.1:100 (UDP)" { - t.Error("unexpected stringer output") - } + assert.Equal(t, "127.0.0.1:200->127.0.0.1:100 (UDP)", s, "unexpected stringer output") } diff --git a/internal/proto/chandata_test.go b/internal/proto/chandata_test.go index e3e03da5..abe0e920 100644 --- a/internal/proto/chandata_test.go +++ b/internal/proto/chandata_test.go @@ -7,9 +7,10 @@ import ( "bufio" "bytes" "encoding/hex" - "errors" "io" "testing" + + "github.com/stretchr/testify/assert" ) func TestChannelData_Encode(t *testing.T) { @@ -20,15 +21,9 @@ func TestChannelData_Encode(t *testing.T) { chanData.Encode() b := &ChannelData{} b.Raw = append(b.Raw, chanData.Raw...) - if err := b.Decode(); err != nil { - t.Error(err) - } - if !b.Equal(chanData) { - t.Error("not equal") - } - if !IsChannelData(b.Raw) || !IsChannelData(chanData.Raw) { - t.Error("unexpected IsChannelData") - } + assert.NoError(t, b.Decode()) + assert.True(t, b.Equal(chanData)) + assert.True(t, IsChannelData(b.Raw) && IsChannelData(chanData.Raw)) } func TestChannelData_Equal(t *testing.T) { @@ -91,9 +86,7 @@ func TestChannelData_Equal(t *testing.T) { }, }, } { - if v := tc.a.Equal(tc.b); v != tc.value { - t.Errorf("unexpected: (%s) %v != %v", tc.name, tc.value, v) - } + assert.Equal(t, tc.value, tc.a.Equal(tc.b)) } } @@ -131,9 +124,7 @@ func TestChannelData_Decode(t *testing.T) { m := &ChannelData{ Raw: tc.buf, } - if err := m.Decode(); !errors.Is(err, tc.err) { - t.Errorf("unexpected: (%s) %v != %v", tc.name, tc.err, err) - } + assert.ErrorIs(t, m.Decode(), tc.err) } } @@ -147,9 +138,7 @@ func TestChannelData_Reset(t *testing.T) { copy(buf, d.Raw) d.Reset() d.Raw = buf - if err := d.Decode(); err != nil { - t.Fatal(err) - } + assert.NoError(t, d.Decode()) } func TestIsChannelData(t *testing.T) { @@ -170,9 +159,7 @@ func TestIsChannelData(t *testing.T) { buf: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, }, } { - if v := IsChannelData(tc.buf); v != tc.value { - t.Errorf("unexpected: (%s) %v != %v", tc.name, tc.value, v) - } + assert.Equal(t, tc.value, IsChannelData(tc.buf)) } } @@ -210,9 +197,7 @@ func BenchmarkChannelData_Decode(b *testing.B) { for i := 0; i < b.N; i++ { d.Reset() d.Raw = buf - if err := d.Decode(); err != nil { - b.Error(err) - } + assert.NoError(b, d.Decode()) } } @@ -227,9 +212,7 @@ func TestChromeChannelData(t *testing.T) { // Decoding hex data into binary. for s.Scan() { b, err := hex.DecodeString(s.Text()) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) data = append(data, b) } // All hex streams decoded to raw binary format and stored in data slice. @@ -237,9 +220,8 @@ func TestChromeChannelData(t *testing.T) { for i, packet := range data { chanData := new(ChannelData) chanData.Raw = packet - if err := chanData.Decode(); err != nil { - t.Errorf("Packet %d: %v", i, err) - } + assert.NoError(t, chanData.Decode(), "Packet %d errored", i) + encoded := &ChannelData{ Data: chanData.Data, Number: chanData.Number, @@ -247,16 +229,10 @@ func TestChromeChannelData(t *testing.T) { encoded.Encode() decoded := new(ChannelData) decoded.Raw = encoded.Raw - if err := decoded.Decode(); err != nil { - t.Error(err) - } - if !decoded.Equal(chanData) { - t.Error("should be equal") - } + assert.NoError(t, decoded.Decode()) + assert.True(t, decoded.Equal(chanData)) messages = append(messages, chanData) } - if len(messages) != 2 { - t.Error("unexpected message slice list") - } + assert.Equal(t, 2, len(messages), "unexpected number of messages") } diff --git a/internal/proto/chann_test.go b/internal/proto/chann_test.go index e4f80423..0237fa6f 100644 --- a/internal/proto/chann_test.go +++ b/internal/proto/chann_test.go @@ -4,7 +4,6 @@ package proto import ( - "errors" "testing" "github.com/pion/stun/v3" @@ -17,9 +16,7 @@ func BenchmarkChannelNumber(b *testing.B) { m := new(stun.Message) for i := 0; i < b.N; i++ { n := ChannelNumber(12) - if err := n.AddTo(m); err != nil { - b.Fatal(err) - } + assert.NoError(b, n.AddTo(m)) m.Reset() } }) @@ -28,76 +25,64 @@ func BenchmarkChannelNumber(b *testing.B) { assert.NoError(b, ChannelNumber(12).AddTo(m)) for i := 0; i < b.N; i++ { var n ChannelNumber - if err := n.GetFrom(m); err != nil { - b.Fatal(err) - } + assert.NoError(b, n.GetFrom(m)) } }) } -func TestChannelNumber(t *testing.T) { //nolint:cyclop +func TestChannelNumber(t *testing.T) { t.Run("String", func(t *testing.T) { n := ChannelNumber(112) - if n.String() != "112" { - t.Errorf("bad string %s, expected 112", n.String()) - } + assert.Equal(t, "112", n.String()) }) t.Run("NoAlloc", func(t *testing.T) { stunMsg := &stun.Message{} - if wasAllocs(func() { + allocated := wasAllocs(func() { // Case with ChannelNumber on stack. n := ChannelNumber(6) - n.AddTo(stunMsg) //nolint + assert.NoError(t, n.AddTo(stunMsg)) stunMsg.Reset() - }) { - t.Error("Unexpected allocations") - } + }) + assert.False(t, allocated) n := ChannelNumber(12) nP := &n - if wasAllocs(func() { + allocated = wasAllocs(func() { // On heap. - nP.AddTo(stunMsg) //nolint + assert.NoError(t, nP.AddTo(stunMsg)) stunMsg.Reset() - }) { - t.Error("Unexpected allocations") - } + }) + assert.False(t, allocated) }) t.Run("AddTo", func(t *testing.T) { stunMsg := new(stun.Message) chanNumber := ChannelNumber(6) - if err := chanNumber.AddTo(stunMsg); err != nil { - t.Error(err) - } + assert.NoError(t, chanNumber.AddTo(stunMsg)) + stunMsg.WriteHeader() t.Run("GetFrom", func(t *testing.T) { decoded := new(stun.Message) - if _, err := decoded.Write(stunMsg.Raw); err != nil { - t.Fatal("failed to decode message:", err) - } + _, err := decoded.Write(stunMsg.Raw) + assert.NoError(t, err) + var numDecoded ChannelNumber - if err := numDecoded.GetFrom(decoded); err != nil { - t.Fatal(err) - } - if numDecoded != chanNumber { - t.Errorf("Decoded %d, expected %d", numDecoded, chanNumber) - } - if wasAllocs(func() { + err = numDecoded.GetFrom(decoded) + assert.NoError(t, err) + assert.Equal(t, chanNumber, numDecoded) + + allocated := wasAllocs(func() { var num ChannelNumber - num.GetFrom(decoded) //nolint - }) { - t.Error("Unexpected allocations") - } + assert.NoError(t, num.GetFrom(decoded)) + }) + assert.False(t, allocated) + t.Run("HandleErr", func(t *testing.T) { m := new(stun.Message) nHandle := new(ChannelNumber) - if err := nHandle.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) { - t.Errorf("%v should be not found", err) - } + assert.ErrorIs(t, nHandle.GetFrom(m), stun.ErrAttributeNotFound) + m.Add(stun.AttrChannelNumber, []byte{1, 2, 3}) - if !stun.IsAttrSizeInvalid(nHandle.GetFrom(m)) { - t.Error("IsAttrSizeInvalid should be true") - } + assert.True(t, stun.IsAttrSizeInvalid(nHandle.GetFrom(m))) }) }) }) @@ -114,8 +99,7 @@ func TestChannelNumber_Valid(t *testing.T) { {MaxChannelNumber, true}, {MaxChannelNumber + 1, false}, } { - if v := tc.n.Valid(); v != tc.value { - t.Errorf("unexpected: (%s) %v != %v", tc.n.String(), tc.value, v) - } + v := tc.n.Valid() + assert.Equalf(t, tc.value, v, "unexpected: (%s) %v != %v", tc.n.String(), tc.value, v) } } diff --git a/internal/proto/chrome_test.go b/internal/proto/chrome_test.go index 3511f78a..142e5565 100644 --- a/internal/proto/chrome_test.go +++ b/internal/proto/chrome_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/pion/stun/v3" + "github.com/stretchr/testify/assert" ) func TestChromeAllocRequest(t *testing.T) { @@ -23,21 +24,16 @@ func TestChromeAllocRequest(t *testing.T) { // Decoding hex data into binary. for s.Scan() { b, err := hex.DecodeString(s.Text()) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) data = append(data, b) } // All hex streams decoded to raw binary format and stored in data slice. // Decoding packets to messages. for i, packet := range data { m := new(stun.Message) - if _, err := m.Write(packet); err != nil { - t.Errorf("Packet %d: %v", i, err) - } + _, err := m.Write(packet) + assert.NoErrorf(t, err, "Packet %d: %v", i, err) messages = append(messages, m) } - if len(messages) != 4 { - t.Error("unexpected message slice list") - } + assert.Equal(t, 4, len(messages), "unexpected number of messages") } diff --git a/internal/proto/data_test.go b/internal/proto/data_test.go index de714ba7..73b241cc 100644 --- a/internal/proto/data_test.go +++ b/internal/proto/data_test.go @@ -4,8 +4,6 @@ package proto import ( - "bytes" - "errors" "testing" "github.com/pion/stun/v3" @@ -36,55 +34,47 @@ func TestData(t *testing.T) { t.Run("NoAlloc", func(t *testing.T) { stunMsg := new(stun.Message) v := []byte{1, 2, 3, 4} - if wasAllocs(func() { + allocated := wasAllocs(func() { // On stack. d := Data(v) - d.AddTo(stunMsg) //nolint + assert.NoError(t, d.AddTo(stunMsg)) stunMsg.Reset() - }) { - t.Error("Unexpected allocations") - } + }) + assert.False(t, allocated) d := &Data{1, 2, 3, 4} - if wasAllocs(func() { + allocated = wasAllocs(func() { // On heap. - d.AddTo(stunMsg) //nolint + assert.NoError(t, d.AddTo(stunMsg)) stunMsg.Reset() - }) { - t.Error("Unexpected allocations") - } + }) + assert.False(t, allocated) }) t.Run("AddTo", func(t *testing.T) { m := new(stun.Message) data := Data{1, 2, 33, 44, 0x13, 0xaf} - if err := data.AddTo(m); err != nil { - t.Fatal(err) - } + assert.NoError(t, data.AddTo(m)) + m.WriteHeader() t.Run("GetFrom", func(t *testing.T) { decoded := new(stun.Message) - if _, err := decoded.Write(m.Raw); err != nil { - t.Fatal("failed to decode message:", err) - } + _, err := decoded.Write(m.Raw) + assert.NoError(t, err) + var dataDecoded Data - if err := dataDecoded.GetFrom(decoded); err != nil { - t.Fatal(err) - } - if !bytes.Equal(dataDecoded, data) { - t.Error(dataDecoded, "!=", data, "(expected)") - } - if wasAllocs(func() { + assert.NoError(t, dataDecoded.GetFrom(decoded)) + assert.Equal(t, data, dataDecoded) + + allocated := wasAllocs(func() { var dataDecoded Data - dataDecoded.GetFrom(decoded) //nolint - }) { - t.Error("Unexpected allocations") - } + assert.NoError(t, dataDecoded.GetFrom(decoded)) + }) + assert.False(t, allocated) + t.Run("HandleErr", func(t *testing.T) { m := new(stun.Message) var handle Data - if err := handle.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) { - t.Errorf("%v should be not found", err) - } + assert.ErrorIs(t, handle.GetFrom(m), stun.ErrAttributeNotFound) }) }) }) diff --git a/internal/proto/dontfrag_test.go b/internal/proto/dontfrag_test.go index f8100d9a..574d3628 100644 --- a/internal/proto/dontfrag_test.go +++ b/internal/proto/dontfrag_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/pion/stun/v3" + "github.com/stretchr/testify/assert" ) func TestDontFragment(t *testing.T) { @@ -15,29 +16,23 @@ func TestDontFragment(t *testing.T) { t.Run("False", func(t *testing.T) { m := new(stun.Message) m.WriteHeader() - if dontFrag.IsSet(m) { - t.Error("should not be set") - } + assert.False(t, dontFrag.IsSet(m)) }) t.Run("AddTo", func(t *testing.T) { stunMsg := new(stun.Message) - if err := dontFrag.AddTo(stunMsg); err != nil { - t.Error(err) - } + assert.NoError(t, dontFrag.AddTo(stunMsg)) + stunMsg.WriteHeader() t.Run("IsSet", func(t *testing.T) { decoded := new(stun.Message) - if _, err := decoded.Write(stunMsg.Raw); err != nil { - t.Fatal("failed to decode message:", err) - } - if !dontFrag.IsSet(stunMsg) { - t.Error("should be set") - } - if wasAllocs(func() { + _, err := decoded.Write(stunMsg.Raw) + assert.NoError(t, err) + assert.True(t, dontFrag.IsSet(stunMsg)) + + allocated := wasAllocs(func() { dontFrag.IsSet(stunMsg) - }) { - t.Error("unexpected allocations") - } + }) + assert.False(t, allocated) }) }) } diff --git a/internal/proto/evenport_test.go b/internal/proto/evenport_test.go index fb641fe6..d5c17cf6 100644 --- a/internal/proto/evenport_test.go +++ b/internal/proto/evenport_test.go @@ -4,78 +4,64 @@ package proto import ( - "errors" "testing" "github.com/pion/stun/v3" "github.com/stretchr/testify/assert" ) -func TestEvenPort(t *testing.T) { //nolint:cyclop +func TestEvenPort(t *testing.T) { t.Run("String", func(t *testing.T) { p := EvenPort{} - if p.String() != "reserve: false" { - t.Errorf("bad value %q for reserve: false", p.String()) - } + assert.Equal(t, "reserve: false", p.String()) + p.ReservePort = true - if p.String() != "reserve: true" { - t.Errorf("bad value %q for reserve: true", p.String()) - } + assert.Equal(t, "reserve: true", p.String()) }) t.Run("False", func(t *testing.T) { m := new(stun.Message) p := EvenPort{ ReservePort: false, } - if err := p.AddTo(m); err != nil { - t.Error(err) - } + assert.NoError(t, p.AddTo(m)) + m.WriteHeader() decoded := new(stun.Message) - var port EvenPort _, err := decoded.Write(m.Raw) assert.NoError(t, err) + + var port EvenPort assert.NoError(t, port.GetFrom(m)) - if port != p { - t.Fatal("not equal") - } + assert.Equal(t, p, port) }) t.Run("AddTo", func(t *testing.T) { m := new(stun.Message) evenPortAttr := EvenPort{ ReservePort: true, } - if err := evenPortAttr.AddTo(m); err != nil { - t.Error(err) - } + assert.NoError(t, evenPortAttr.AddTo(m)) m.WriteHeader() t.Run("GetFrom", func(t *testing.T) { decoded := new(stun.Message) - if _, err := decoded.Write(m.Raw); err != nil { - t.Fatal("failed to decode message:", err) - } + _, err := decoded.Write(m.Raw) + assert.NoError(t, err) + port := EvenPort{} - if err := port.GetFrom(decoded); err != nil { - t.Fatal(err) - } - if port != evenPortAttr { - t.Errorf("Decoded %q, expected %q", port.String(), evenPortAttr.String()) - } - if wasAllocs(func() { - port.GetFrom(decoded) //nolint - }) { - t.Error("Unexpected allocations") - } + assert.NoError(t, port.GetFrom(decoded)) + assert.Equalf(t, evenPortAttr, port, "Decoded %q, expected %q", port.String(), evenPortAttr.String()) + + allocated := wasAllocs(func() { + assert.NoError(t, port.GetFrom(decoded)) + }) + assert.False(t, allocated) + t.Run("HandleErr", func(t *testing.T) { m := new(stun.Message) var handle EvenPort - if err := handle.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) { - t.Errorf("%v should be not found", err) - } + assert.ErrorIs(t, handle.GetFrom(m), stun.ErrAttributeNotFound) + m.Add(stun.AttrEvenPort, []byte{1, 2, 3}) - if !stun.IsAttrSizeInvalid(handle.GetFrom(m)) { - t.Error("IsAttrSizeInvalid should be true") - } + assert.True(t, stun.IsAttrSizeInvalid(handle.GetFrom(m))) }) }) }) diff --git a/internal/proto/lifetime_test.go b/internal/proto/lifetime_test.go index 83002090..66a46302 100644 --- a/internal/proto/lifetime_test.go +++ b/internal/proto/lifetime_test.go @@ -4,7 +4,6 @@ package proto import ( - "errors" "testing" "time" @@ -18,9 +17,7 @@ func BenchmarkLifetime(b *testing.B) { m := new(stun.Message) for i := 0; i < b.N; i++ { l := Lifetime{time.Second} - if err := l.AddTo(m); err != nil { - b.Fatal(err) - } + assert.NoError(b, l.AddTo(m)) m.Reset() } }) @@ -29,76 +26,62 @@ func BenchmarkLifetime(b *testing.B) { assert.NoError(b, Lifetime{time.Minute}.AddTo(m)) for i := 0; i < b.N; i++ { l := Lifetime{} - if err := l.GetFrom(m); err != nil { - b.Fatal(err) - } + assert.NoError(b, l.GetFrom(m)) } }) } -func TestLifetime(t *testing.T) { //nolint:cyclop +func TestLifetime(t *testing.T) { t.Run("String", func(t *testing.T) { l := Lifetime{time.Second * 10} - if l.String() != "10s" { - t.Errorf("bad string %s, expedted 10s", l) - } + assert.Equal(t, "10s", l.String()) }) t.Run("NoAlloc", func(t *testing.T) { stunMsg := &stun.Message{} - if wasAllocs(func() { + allocated := wasAllocs(func() { // On stack. l := Lifetime{ Duration: time.Minute, } - l.AddTo(stunMsg) //nolint + assert.NoError(t, l.AddTo(stunMsg)) stunMsg.Reset() - }) { - t.Error("Unexpected allocations") - } + }) + assert.False(t, allocated) l := &Lifetime{time.Second} - if wasAllocs(func() { + allocated = wasAllocs(func() { // On heap. - l.AddTo(stunMsg) //nolint + assert.NoError(t, l.AddTo(stunMsg)) stunMsg.Reset() - }) { - t.Error("Unexpected allocations") - } + }) + assert.False(t, allocated) }) t.Run("AddTo", func(t *testing.T) { m := new(stun.Message) lifetime := Lifetime{time.Second * 10} - if err := lifetime.AddTo(m); err != nil { - t.Error(err) - } + assert.NoError(t, lifetime.AddTo(m)) m.WriteHeader() t.Run("GetFrom", func(t *testing.T) { decoded := new(stun.Message) - if _, err := decoded.Write(m.Raw); err != nil { - t.Fatal("failed to decode message:", err) - } + _, err := decoded.Write(m.Raw) + assert.NoError(t, err) + life := Lifetime{} - if err := life.GetFrom(decoded); err != nil { - t.Fatal(err) - } - if life != lifetime { - t.Errorf("Decoded %q, expected %q", life, lifetime) - } - if wasAllocs(func() { - life.GetFrom(decoded) //nolint - }) { - t.Error("Unexpected allocations") - } + assert.NoError(t, life.GetFrom(decoded)) + assert.Equal(t, lifetime, life) + + allocated := wasAllocs(func() { + assert.NoError(t, life.GetFrom(decoded)) + }) + assert.False(t, allocated) + t.Run("HandleErr", func(t *testing.T) { m := new(stun.Message) nHandle := new(Lifetime) - if err := nHandle.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) { - t.Errorf("%v should be not found", err) - } + assert.ErrorIs(t, nHandle.GetFrom(m), stun.ErrAttributeNotFound) + m.Add(stun.AttrLifetime, []byte{1, 2, 3}) - if !stun.IsAttrSizeInvalid(nHandle.GetFrom(m)) { - t.Error("IsAttrSizeInvalid should be true") - } + assert.True(t, stun.IsAttrSizeInvalid(nHandle.GetFrom(m))) }) }) }) diff --git a/internal/proto/peeraddr_test.go b/internal/proto/peeraddr_test.go index 9e45e0b0..1905512b 100644 --- a/internal/proto/peeraddr_test.go +++ b/internal/proto/peeraddr_test.go @@ -18,14 +18,11 @@ func TestPeerAddress(t *testing.T) { Port: 333, } t.Run("String", func(t *testing.T) { - if a.String() != "111.11.1.2:333" { - t.Error("invalid string") - } + assert.Equal(t, "111.11.1.2:333", a.String()) }) m := new(stun.Message) - if err := a.AddTo(m); err != nil { - t.Fatal(err) - } + assert.NoError(t, a.AddTo(m)) + m.WriteHeader() decoded := new(stun.Message) diff --git a/internal/proto/relayedaddr_test.go b/internal/proto/relayedaddr_test.go index 74b4137d..9441f62e 100644 --- a/internal/proto/relayedaddr_test.go +++ b/internal/proto/relayedaddr_test.go @@ -18,14 +18,11 @@ func TestRelayedAddress(t *testing.T) { Port: 333, } t.Run("String", func(t *testing.T) { - if a.String() != "111.11.1.2:333" { - t.Error("invalid string") - } + assert.Equal(t, "111.11.1.2:333", a.String()) }) m := new(stun.Message) - if err := a.AddTo(m); err != nil { - t.Fatal(err) - } + assert.NoError(t, a.AddTo(m)) + m.WriteHeader() decoded := new(stun.Message) diff --git a/internal/proto/reqfamily_test.go b/internal/proto/reqfamily_test.go index 36737244..dd0677d7 100644 --- a/internal/proto/reqfamily_test.go +++ b/internal/proto/reqfamily_test.go @@ -4,88 +4,68 @@ package proto import ( - "errors" "testing" "github.com/pion/stun/v3" + "github.com/stretchr/testify/assert" ) -func TestRequestedAddressFamily(t *testing.T) { //nolint:cyclop +func TestRequestedAddressFamily(t *testing.T) { t.Run("String", func(t *testing.T) { - if RequestedFamilyIPv4.String() != "IPv4" { - t.Errorf("bad string %q, expected %q", RequestedFamilyIPv4, - "IPv4", - ) - } - if RequestedFamilyIPv6.String() != "IPv6" { - t.Errorf("bad string %q, expected %q", RequestedFamilyIPv6, - "IPv6", - ) - } - if RequestedAddressFamily(0x04).String() != "unknown" { - t.Error("should be unknown") - } + assert.Equal(t, "IPv4", RequestedFamilyIPv4.String()) + assert.Equal(t, "IPv6", RequestedFamilyIPv6.String()) + assert.Equal(t, "unknown", RequestedAddressFamily(0x04).String()) }) t.Run("NoAlloc", func(t *testing.T) { stunMsg := &stun.Message{} - if wasAllocs(func() { + allocated := wasAllocs(func() { // On stack. r := RequestedFamilyIPv4 - r.AddTo(stunMsg) //nolint + assert.NoError(t, r.AddTo(stunMsg)) stunMsg.Reset() - }) { - t.Error("Unexpected allocations") - } + }) + assert.False(t, allocated) requestFamilyAttr := new(RequestedAddressFamily) *requestFamilyAttr = RequestedFamilyIPv4 - if wasAllocs(func() { + allocated = wasAllocs(func() { // On heap. - requestFamilyAttr.AddTo(stunMsg) //nolint + assert.NoError(t, requestFamilyAttr.AddTo(stunMsg)) stunMsg.Reset() - }) { - t.Error("Unexpected allocations") - } + }) + assert.False(t, allocated) }) t.Run("AddTo", func(t *testing.T) { stunMsg := new(stun.Message) requestFamilyAddr := RequestedFamilyIPv4 - if err := requestFamilyAddr.AddTo(stunMsg); err != nil { - t.Error(err) - } + assert.NoError(t, requestFamilyAddr.AddTo(stunMsg)) + stunMsg.WriteHeader() t.Run("GetFrom", func(t *testing.T) { decoded := new(stun.Message) - if _, err := decoded.Write(stunMsg.Raw); err != nil { - t.Fatal("failed to decode message:", err) - } + _, err := decoded.Write(stunMsg.Raw) + assert.NoError(t, err) + var req RequestedAddressFamily - if err := req.GetFrom(decoded); err != nil { - t.Fatal(err) - } - if req != requestFamilyAddr { - t.Errorf("Decoded %q, expected %q", req, requestFamilyAddr) - } - if wasAllocs(func() { - requestFamilyAddr.GetFrom(decoded) //nolint - }) { - t.Error("Unexpected allocations") - } + assert.NoError(t, req.GetFrom(decoded)) + assert.Equal(t, requestFamilyAddr, req) + + allocated := wasAllocs(func() { + assert.NoError(t, requestFamilyAddr.GetFrom(decoded)) + }) + assert.False(t, allocated) + t.Run("HandleErr", func(t *testing.T) { m := new(stun.Message) var handle RequestedAddressFamily - if err := handle.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) { - t.Errorf("%v should be not found", err) - } + assert.ErrorIs(t, handle.GetFrom(m), stun.ErrAttributeNotFound) + m.Add(stun.AttrRequestedAddressFamily, []byte{1, 2, 3}) - if !stun.IsAttrSizeInvalid(handle.GetFrom(m)) { - t.Error("IsAttrSizeInvalid should be true") - } + assert.True(t, stun.IsAttrSizeInvalid(handle.GetFrom(m))) + m.Reset() m.Add(stun.AttrRequestedAddressFamily, []byte{5, 0, 0, 0}) - if handle.GetFrom(m) == nil { - t.Error("should error on invalid value") - } + assert.NotNil(t, handle.GetFrom(m), "should not error on unknown value") }) }) }) diff --git a/internal/proto/reqtrans_test.go b/internal/proto/reqtrans_test.go index d024706a..fa2818f4 100644 --- a/internal/proto/reqtrans_test.go +++ b/internal/proto/reqtrans_test.go @@ -4,99 +4,79 @@ package proto import ( - "errors" "testing" "github.com/pion/stun/v3" + "github.com/stretchr/testify/assert" ) -func TestRequestedTransport(t *testing.T) { //nolint:cyclop +func TestRequestedTransport(t *testing.T) { t.Run("String", func(t *testing.T) { transAttr := RequestedTransport{ Protocol: ProtoUDP, } - if transAttr.String() != "protocol: UDP" { - t.Errorf("bad string %q, expected %q", transAttr, - "protocol: UDP", - ) - } + assert.Equal(t, "protocol: UDP", transAttr.String()) + transAttr = RequestedTransport{ Protocol: ProtoTCP, } - if transAttr.String() != "protocol: TCP" { - t.Errorf("bad string %q, expected %q", transAttr, - "protocol: TCP", - ) - } + assert.Equal(t, "protocol: TCP", transAttr.String()) + transAttr.Protocol = 254 - if transAttr.String() != "protocol: 254" { - t.Errorf("bad string %q, expected %q", transAttr, - "protocol: 254", - ) - } + assert.Equal(t, "protocol: 254", transAttr.String()) }) t.Run("NoAlloc", func(t *testing.T) { stunMsg := &stun.Message{} - if wasAllocs(func() { + allocated := wasAllocs(func() { // On stack. r := RequestedTransport{ Protocol: ProtoUDP, } - r.AddTo(stunMsg) //nolint + assert.NoError(t, r.AddTo(stunMsg)) stunMsg.Reset() - }) { - t.Error("Unexpected allocations") - } + }) + assert.False(t, allocated) r := &RequestedTransport{ Protocol: ProtoUDP, } - if wasAllocs(func() { + allocated = wasAllocs(func() { // On heap. - r.AddTo(stunMsg) //nolint + assert.NoError(t, r.AddTo(stunMsg)) stunMsg.Reset() - }) { - t.Error("Unexpected allocations") - } + }) + assert.False(t, allocated) }) t.Run("AddTo", func(t *testing.T) { m := new(stun.Message) transAttr := RequestedTransport{ Protocol: ProtoUDP, } - if err := transAttr.AddTo(m); err != nil { - t.Error(err) - } + assert.NoError(t, transAttr.AddTo(m)) m.WriteHeader() t.Run("GetFrom", func(t *testing.T) { decoded := new(stun.Message) - if _, err := decoded.Write(m.Raw); err != nil { - t.Fatal("failed to decode message:", err) - } + _, err := decoded.Write(m.Raw) + assert.NoError(t, err) + req := RequestedTransport{ Protocol: ProtoUDP, } - if err := req.GetFrom(decoded); err != nil { - t.Fatal(err) - } - if req != transAttr { - t.Errorf("Decoded %q, expected %q", req, transAttr) - } - if wasAllocs(func() { - transAttr.GetFrom(decoded) //nolint - }) { - t.Error("Unexpected allocations") - } + assert.NoError(t, req.GetFrom(decoded)) + assert.Equal(t, transAttr, req) + + allocated := wasAllocs(func() { + assert.NoError(t, transAttr.GetFrom(decoded)) + }) + assert.False(t, allocated) + t.Run("HandleErr", func(t *testing.T) { m := new(stun.Message) var handle RequestedTransport - if err := handle.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) { - t.Errorf("%v should be not found", err) - } + assert.ErrorIs(t, handle.GetFrom(m), stun.ErrAttributeNotFound) + m.Add(stun.AttrRequestedTransport, []byte{1, 2, 3}) - if !stun.IsAttrSizeInvalid(handle.GetFrom(m)) { - t.Error("IsAttrSizeInvalid should be true") - } + assert.True(t, stun.IsAttrSizeInvalid(handle.GetFrom(m))) }) }) }) diff --git a/internal/proto/rsrvtoken_test.go b/internal/proto/rsrvtoken_test.go index 14dcb5ef..4db991ad 100644 --- a/internal/proto/rsrvtoken_test.go +++ b/internal/proto/rsrvtoken_test.go @@ -4,77 +4,64 @@ package proto import ( - "bytes" - "errors" "testing" "github.com/pion/stun/v3" + "github.com/stretchr/testify/assert" ) -func TestReservationToken(t *testing.T) { //nolint:cyclop +func TestReservationToken(t *testing.T) { t.Run("NoAlloc", func(t *testing.T) { stunMsg := &stun.Message{} tok := make([]byte, 8) - if wasAllocs(func() { + allocated := wasAllocs(func() { // On stack. tk := ReservationToken(tok) - tk.AddTo(stunMsg) //nolint + assert.NoError(t, tk.AddTo(stunMsg)) stunMsg.Reset() - }) { - t.Error("Unexpected allocations") - } + }) + assert.False(t, allocated) tk := make(ReservationToken, 8) - if wasAllocs(func() { + allocated = wasAllocs(func() { // On heap. - tk.AddTo(stunMsg) //nolint + assert.NoError(t, tk.AddTo(stunMsg)) stunMsg.Reset() - }) { - t.Error("Unexpected allocations") - } + }) + assert.False(t, allocated) }) t.Run("AddTo", func(t *testing.T) { stunMsg := new(stun.Message) tk := make(ReservationToken, 8) tk[2] = 33 tk[7] = 1 - if err := tk.AddTo(stunMsg); err != nil { - t.Error(err) - } + assert.NoError(t, tk.AddTo(stunMsg)) + stunMsg.WriteHeader() t.Run("HandleErr", func(t *testing.T) { badTk := ReservationToken{34, 45} - if !stun.IsAttrSizeInvalid(badTk.AddTo(stunMsg)) { - t.Error("IsAttrSizeInvalid should be true") - } + assert.True(t, stun.IsAttrSizeInvalid(badTk.AddTo(stunMsg))) }) t.Run("GetFrom", func(t *testing.T) { decoded := new(stun.Message) - if _, err := decoded.Write(stunMsg.Raw); err != nil { - t.Fatal("failed to decode message:", err) - } + _, err := decoded.Write(stunMsg.Raw) + assert.NoError(t, err) + var tok ReservationToken - if err := tok.GetFrom(decoded); err != nil { - t.Fatal(err) - } - if !bytes.Equal(tok, tk) { - t.Errorf("Decoded %v, expected %v", tok, tk) - } - if wasAllocs(func() { - tok.GetFrom(decoded) //nolint - }) { - t.Error("Unexpected allocations") - } + assert.NoError(t, tok.GetFrom(decoded)) + assert.Equal(t, tk, tok) + allocated := wasAllocs(func() { + assert.NoError(t, tok.GetFrom(decoded)) + }) + assert.False(t, allocated) + t.Run("HandleErr", func(t *testing.T) { m := new(stun.Message) var handle ReservationToken - if err := handle.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) { - t.Errorf("%v should be not found", err) - } + assert.ErrorIs(t, handle.GetFrom(m), stun.ErrAttributeNotFound) + m.Add(stun.AttrReservationToken, []byte{1, 2, 3}) - if !stun.IsAttrSizeInvalid(handle.GetFrom(m)) { - t.Error("IsAttrSizeInvalid should be true") - } + assert.True(t, stun.IsAttrSizeInvalid(handle.GetFrom(m))) }) }) }) diff --git a/internal/server/turn_test.go b/internal/server/turn_test.go index 18e0c47d..bc7b44be 100644 --- a/internal/server/turn_test.go +++ b/internal/server/turn_test.go @@ -26,17 +26,13 @@ func TestAllocationLifeTime(t *testing.T) { m := &stun.Message{} lifetimeDuration := allocationLifeTime(m) - - if lifetimeDuration != proto.DefaultLifetime { - t.Errorf("Allocation lifetime should be default time duration") - } - + assert.Equal(t, proto.DefaultLifetime, lifetimeDuration, + "Allocation lifetime should be default time duration") assert.NoError(t, lifetime.AddTo(m)) lifetimeDuration = allocationLifeTime(m) - if lifetimeDuration != lifetime.Duration { - t.Errorf("Expect lifetimeDuration is %s, but %s", lifetime.Duration, lifetimeDuration) - } + assert.Equal(t, lifetime.Duration, lifetimeDuration, + "Allocation lifetime should be equal to the one set in the message") }) // If lifetime is bigger than maximumLifetime @@ -49,9 +45,7 @@ func TestAllocationLifeTime(t *testing.T) { _ = lifetime.AddTo(m2) lifetimeDuration := allocationLifeTime(m2) - if lifetimeDuration != proto.DefaultLifetime { - t.Errorf("Expect lifetimeDuration is %s, but %s", proto.DefaultLifetime, lifetimeDuration) - } + assert.Equal(t, proto.DefaultLifetime, lifetimeDuration) }) t.Run("DeletionZeroLifetime", func(t *testing.T) { diff --git a/lt_cred_test.go b/lt_cred_test.go index e93be15d..c819adf5 100644 --- a/lt_cred_test.go +++ b/lt_cred_test.go @@ -21,9 +21,7 @@ func TestLtCredMech(t *testing.T) { expectedPassword := "Tpz/nKkyvX/vMSLKvL4sbtBt8Vs=" //nolint:gosec actualPassword, _ := longTermCredentials(username, sharedSecret) - if expectedPassword != actualPassword { - t.Errorf("Expected %q, got %q", expectedPassword, actualPassword) - } + assert.Equal(t, expectedPassword, actualPassword) } func TestNewLongTermAuthHandler(t *testing.T) {