diff --git a/.golangci.yml b/.golangci.yml index 88cb4fb..120faf2 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -19,12 +19,16 @@ linters-settings: recommendations: - errors forbidigo: + analyze-types: true forbid: - ^fmt.Print(f|ln)?$ - ^log.(Panic|Fatal|Print)(f|ln)?$ - ^os.Exit$ - ^panic$ - ^print(ln)?$ + - p: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$ + pkg: ^testing$ + msg: "use testify/assert instead" varnamelen: max-distance: 12 min-name-length: 2 @@ -127,9 +131,12 @@ issues: exclude-dirs-use-default: false exclude-rules: # Allow complex tests and examples, better to be self contained - - path: (examples|main\.go|_test\.go) + - path: (examples|main\.go) linters: + - gocognit - forbidigo + - path: _test\.go + linters: - gocognit # Allow forbidden identifiers in CLI commands diff --git a/addr_test.go b/addr_test.go index 945426f..da26624 100644 --- a/addr_test.go +++ b/addr_test.go @@ -4,10 +4,11 @@ package stun import ( - "errors" "io" "net" "testing" + + "github.com/stretchr/testify/assert" ) func TestMappedAddress(t *testing.T) { @@ -16,48 +17,32 @@ func TestMappedAddress(t *testing.T) { IP: net.ParseIP("122.12.34.5"), Port: 5412, } - if addr.String() != "122.12.34.5:5412" { - t.Error("bad string", addr) - } + assert.Equal(t, "122.12.34.5:5412", addr.String(), "bad string") t.Run("Bad length", func(t *testing.T) { badAddr := &MappedAddress{ IP: net.IP{1, 2, 3}, } - if err := badAddr.AddTo(msg); err == nil { - t.Error("should error") - } + assert.Error(t, badAddr.AddTo(msg), "should error") }) t.Run("AddTo", func(t *testing.T) { - if err := addr.AddTo(msg); err != nil { - t.Error(err) - } + assert.NoError(t, addr.AddTo(msg)) t.Run("GetFrom", func(t *testing.T) { got := new(MappedAddress) - if err := got.GetFrom(msg); err != nil { - t.Error(err) - } - if !got.IP.Equal(addr.IP) { - t.Error("got bad IP: ", got.IP) - } + assert.NoError(t, got.GetFrom(msg)) + assert.True(t, got.IP.Equal(addr.IP), "got bad IP: %v", got.IP) t.Run("Not found", func(t *testing.T) { message := new(Message) - if err := got.GetFrom(message); !errors.Is(err, ErrAttributeNotFound) { - t.Error("should be not found: ", err) - } + assert.ErrorIs(t, got.GetFrom(message), ErrAttributeNotFound, "should be not found") }) t.Run("Bad family", func(t *testing.T) { v, _ := msg.Attributes.Get(AttrMappedAddress) v.Value[0] = 32 - if err := got.GetFrom(msg); err == nil { - t.Error("should error") - } + assert.Error(t, got.GetFrom(msg), "should error") }) t.Run("Bad length", func(t *testing.T) { message := new(Message) message.Add(AttrMappedAddress, []byte{1, 2, 3}) - if err := got.GetFrom(message); !errors.Is(err, io.ErrUnexpectedEOF) { - t.Errorf("<%s> should be <%s>", err, io.ErrUnexpectedEOF) - } + assert.ErrorIs(t, got.GetFrom(message), io.ErrUnexpectedEOF) }) }) }) @@ -70,22 +55,14 @@ func TestMappedAddressV6(t *testing.T) { //nolint:dupl Port: 5412, } t.Run("AddTo", func(t *testing.T) { - if err := addr.AddTo(m); err != nil { - t.Error(err) - } + assert.NoError(t, addr.AddTo(m)) t.Run("GetFrom", func(t *testing.T) { got := new(MappedAddress) - if err := got.GetFrom(m); err != nil { - t.Error(err) - } - if !got.IP.Equal(addr.IP) { - t.Error("got bad IP: ", got.IP) - } + assert.NoError(t, got.GetFrom(m)) + assert.True(t, got.IP.Equal(addr.IP), "got bad IP: %v", got.IP) t.Run("Not found", func(t *testing.T) { message := new(Message) - if err := got.GetFrom(message); !errors.Is(err, ErrAttributeNotFound) { - t.Error("should be not found: ", err) - } + assert.ErrorIs(t, got.GetFrom(message), ErrAttributeNotFound, "should be not found") }) }) }) @@ -98,22 +75,14 @@ func TestAlternateServer(t *testing.T) { //nolint:dupl Port: 5412, } t.Run("AddTo", func(t *testing.T) { - if err := addr.AddTo(m); err != nil { - t.Error(err) - } + assert.NoError(t, addr.AddTo(m)) t.Run("GetFrom", func(t *testing.T) { got := new(AlternateServer) - if err := got.GetFrom(m); err != nil { - t.Error(err) - } - if !got.IP.Equal(addr.IP) { - t.Error("got bad IP: ", got.IP) - } + assert.NoError(t, got.GetFrom(m)) + assert.True(t, got.IP.Equal(addr.IP), "got bad IP: %v", got.IP) t.Run("Not found", func(t *testing.T) { message := new(Message) - if err := got.GetFrom(message); !errors.Is(err, ErrAttributeNotFound) { - t.Error("should be not found: ", err) - } + assert.ErrorIs(t, got.GetFrom(message), ErrAttributeNotFound, "should be not found") }) }) }) @@ -126,22 +95,14 @@ func TestOtherAddress(t *testing.T) { //nolint:dupl Port: 5412, } t.Run("AddTo", func(t *testing.T) { - if err := addr.AddTo(m); err != nil { - t.Error(err) - } + assert.NoError(t, addr.AddTo(m)) t.Run("GetFrom", func(t *testing.T) { got := new(OtherAddress) - if err := got.GetFrom(m); err != nil { - t.Error(err) - } - if !got.IP.Equal(addr.IP) { - t.Error("got bad IP: ", got.IP) - } + assert.NoError(t, got.GetFrom(m)) + assert.True(t, got.IP.Equal(addr.IP), "got bad IP: %v", got.IP) t.Run("Not found", func(t *testing.T) { message := new(Message) - if err := got.GetFrom(message); !errors.Is(err, ErrAttributeNotFound) { - t.Error("should be not found: ", err) - } + assert.ErrorIs(t, got.GetFrom(message), ErrAttributeNotFound, "should be not found") }) }) }) diff --git a/agent_test.go b/agent_test.go index 9777083..7d20b5f 100644 --- a/agent_test.go +++ b/agent_test.go @@ -7,84 +7,44 @@ import ( "errors" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestAgent_ProcessInTransaction(t *testing.T) { msg := New() agent := NewAgent(func(e Event) { - if e.Error != nil { - t.Errorf("got error: %s", e.Error) - } - if !e.Message.Equal(msg) { - t.Errorf("%s (got) != %s (expected)", e.Message, msg) - } + assert.NoError(t, e.Error, "got error") + assert.True(t, e.Message.Equal(msg), "%s (got) != %s (expected)", e.Message, msg) }) - if err := msg.NewTransactionID(); err != nil { - t.Fatal(err) - } - if err := agent.Start(msg.TransactionID, time.Time{}); err != nil { - t.Fatal(err) - } - if err := agent.Process(msg); err != nil { - t.Error(err) - } - if err := agent.Close(); err != nil { - t.Error(err) - } + assert.NoError(t, msg.NewTransactionID()) + assert.NoError(t, agent.Start(msg.TransactionID, time.Time{})) + assert.NoError(t, agent.Process(msg)) + assert.NoError(t, agent.Close()) } func TestAgent_Process(t *testing.T) { msg := New() agent := NewAgent(func(e Event) { - if e.Error != nil { - t.Errorf("got error: %s", e.Error) - } - if !e.Message.Equal(msg) { - t.Errorf("%s (got) != %s (expected)", e.Message, msg) - } + assert.NoError(t, e.Error, "got error") + assert.True(t, e.Message.Equal(msg), "%s (got) != %s (expected)", e.Message, msg) }) - if err := msg.NewTransactionID(); err != nil { - t.Fatal(err) - } - if err := agent.Process(msg); err != nil { - t.Error(err) - } - if err := agent.Close(); err != nil { - t.Error(err) - } - if err := agent.Process(msg); !errors.Is(err, ErrAgentClosed) { - t.Errorf("closed agent should return <%s>, but got <%s>", - ErrAgentClosed, err, - ) - } + assert.NoError(t, msg.NewTransactionID()) + assert.NoError(t, agent.Process(msg)) + assert.NoError(t, agent.Close()) + assert.ErrorIs(t, agent.Process(msg), ErrAgentClosed) } func TestAgent_Start(t *testing.T) { agent := NewAgent(nil) id := NewTransactionID() deadline := time.Now().AddDate(0, 0, 1) - if err := agent.Start(id, deadline); err != nil { - t.Errorf("failed to statt transaction: %s", err) - } - if err := agent.Start(id, deadline); !errors.Is(err, ErrTransactionExists) { - t.Errorf("duplicate start should return <%s>, got <%s>", - ErrTransactionExists, err, - ) - } - if err := agent.Close(); err != nil { - t.Error(err) - } + assert.NoError(t, agent.Start(id, deadline), "failed to start transaction") + assert.ErrorIs(t, agent.Start(id, deadline), ErrTransactionExists) + assert.NoError(t, agent.Close()) id = NewTransactionID() - if err := agent.Start(id, deadline); !errors.Is(err, ErrAgentClosed) { - t.Errorf("start on closed agent should return <%s>, got <%s>", - ErrAgentClosed, err, - ) - } - if err := agent.SetHandler(nil); !errors.Is(err, ErrAgentClosed) { - t.Errorf("SetHandler on closed agent should return <%s>, got <%s>", - ErrAgentClosed, err, - ) - } + assert.ErrorIs(t, agent.Start(id, deadline), ErrAgentClosed) + assert.ErrorIs(t, agent.SetHandler(nil), ErrAgentClosed) } func TestAgent_Stop(t *testing.T) { @@ -92,36 +52,20 @@ func TestAgent_Stop(t *testing.T) { agent := NewAgent(func(e Event) { called <- e }) - if err := agent.Stop(transactionID{}); !errors.Is(err, ErrTransactionNotExists) { - t.Fatalf("unexpected error: %s, should be %s", err, ErrTransactionNotExists) - } + assert.ErrorIs(t, agent.Stop(transactionID{}), ErrTransactionNotExists) id := NewTransactionID() timeout := time.Millisecond * 200 - if err := agent.Start(id, time.Now().Add(timeout)); err != nil { - t.Fatal(err) - } - if err := agent.Stop(id); err != nil { - t.Fatal(err) - } + assert.NoError(t, agent.Start(id, time.Now().Add(timeout))) + assert.NoError(t, agent.Stop(id)) select { case e := <-called: - if !errors.Is(e.Error, ErrTransactionStopped) { - t.Fatalf("unexpected error: %s, should be %s", - e.Error, ErrTransactionStopped, - ) - } + assert.ErrorIs(t, e.Error, ErrTransactionStopped) case <-time.After(timeout * 2): - t.Fatal("timed out") - } - if err := agent.Close(); err != nil { - t.Fatal(err) - } - if err := agent.Close(); !errors.Is(err, ErrAgentClosed) { - t.Fatalf("a.Close returned %s instead of %s", err, ErrAgentClosed) - } - if err := agent.Stop(transactionID{}); !errors.Is(err, ErrAgentClosed) { - t.Fatalf("unexpected error: %s, should be %s", err, ErrAgentClosed) + assert.Fail(t, "timed out") } + assert.NoError(t, agent.Close()) + assert.ErrorIs(t, agent.Close(), ErrAgentClosed) + assert.ErrorIs(t, agent.Stop(transactionID{}), ErrAgentClosed) } func TestAgent_GC(t *testing.T) { //nolint:cyclop @@ -136,60 +80,41 @@ func TestAgent_GC(t *testing.T) { //nolint:cyclop agent.SetHandler(func(e Event) { //nolint:errcheck,gosec id := e.TransactionID shouldTimeOut, found := shouldTimeOutID[id] - if !found { - t.Error("unexpected transaction ID") - } - if shouldTimeOut && !errors.Is(e.Error, ErrTransactionTimeOut) { - t.Errorf("%x should time out, but got %v", id, e.Error) - } - if !shouldTimeOut && errors.Is(e.Error, ErrTransactionTimeOut) { - t.Errorf("%x should not time out, but got %v", id, e.Error) + assert.True(t, found, "unexpected transaction ID") + if shouldTimeOut { + assert.ErrorIs(t, e.Error, ErrTransactionTimeOut, "%x should time out", id) + } else { + assert.False(t, errors.Is(e.Error, ErrTransactionTimeOut), "%x should not time out", id) } }) for i := 0; i < 5; i++ { id := NewTransactionID() shouldTimeOutID[id] = false - if err := agent.Start(id, deadline); err != nil { - t.Fatal(err) - } + assert.NoError(t, agent.Start(id, deadline)) } for i := 0; i < 5; i++ { id := NewTransactionID() shouldTimeOutID[id] = true - if err := agent.Start(id, deadlineNotGC); err != nil { - t.Fatal(err) - } - } - if err := agent.Collect(gcDeadline); err != nil { - t.Fatal(err) - } - if err := agent.Close(); err != nil { - t.Error(err) - } - if err := agent.Collect(gcDeadline); !errors.Is(err, ErrAgentClosed) { - t.Errorf("should <%s>, but got <%s>", ErrAgentClosed, err) + assert.NoError(t, agent.Start(id, deadlineNotGC)) } + assert.NoError(t, agent.Collect(gcDeadline)) + assert.NoError(t, agent.Close()) + assert.ErrorIs(t, agent.Collect(gcDeadline), ErrAgentClosed) } func BenchmarkAgent_GC(b *testing.B) { agent := NewAgent(nil) deadline := time.Now().AddDate(0, 0, 1) for i := 0; i < agentCollectCap; i++ { - if err := agent.Start(NewTransactionID(), deadline); err != nil { - b.Fatal(err) - } + assert.NoError(b, agent.Start(NewTransactionID(), deadline)) } defer func() { - if err := agent.Close(); err != nil { - b.Error(err) - } + assert.NoError(b, agent.Close()) }() b.ReportAllocs() gcDeadline := deadline.Add(-time.Second) for i := 0; i < b.N; i++ { - if err := agent.Collect(gcDeadline); err != nil { - b.Fatal(err) - } + assert.NoError(b, agent.Collect(gcDeadline)) } } @@ -197,20 +122,14 @@ func BenchmarkAgent_Process(b *testing.B) { agent := NewAgent(nil) deadline := time.Now().AddDate(0, 0, 1) for i := 0; i < 1000; i++ { - if err := agent.Start(NewTransactionID(), deadline); err != nil { - b.Fatal(err) - } + assert.NoError(b, agent.Start(NewTransactionID(), deadline)) } defer func() { - if err := agent.Close(); err != nil { - b.Error(err) - } + assert.NoError(b, agent.Close()) }() b.ReportAllocs() m := MustBuild(TransactionID) for i := 0; i < b.N; i++ { - if err := agent.Process(m); err != nil { - b.Fatal(err) - } + assert.NoError(b, agent.Process(m)) } } diff --git a/attributes_debug_test.go b/attributes_debug_test.go index e2c537a..96e1ff0 100644 --- a/attributes_debug_test.go +++ b/attributes_debug_test.go @@ -6,7 +6,11 @@ package stun -import "testing" +import ( + "testing" + + "github.com/stretchr/testify/assert" +) func TestAttrOverflowErr_Error(t *testing.T) { err := AttrOverflowErr{ @@ -14,9 +18,7 @@ func TestAttrOverflowErr_Error(t *testing.T) { Max: 50, Type: AttrLifetime, } - if err.Error() != "incorrect length of LIFETIME attribute: 100 exceeds maximum 50" { - t.Error("bad error string", err) - } + assert.Equal(t, "incorrect length of LIFETIME attribute: 100 exceeds maximum 50", err.Error()) } func TestAttrLengthErr_Error(t *testing.T) { @@ -25,7 +27,5 @@ func TestAttrLengthErr_Error(t *testing.T) { Expected: 15, Got: 99, } - if err.Error() != "incorrect length of ERROR-CODE attribute: got 99, expected 15" { - t.Errorf("bad error string: %s", err) - } + assert.Equal(t, "incorrect length of ERROR-CODE attribute: got 99, expected 15", err.Error()) } diff --git a/attributes_test.go b/attributes_test.go index a76e286..a12a01d 100644 --- a/attributes_test.go +++ b/attributes_test.go @@ -6,6 +6,8 @@ package stun import ( "bytes" "testing" + + "github.com/stretchr/testify/assert" ) func BenchmarkMessage_GetNotFound(b *testing.B) { @@ -31,16 +33,10 @@ func TestRawAttribute_AddTo(t *testing.T) { Type: AttrData, Value: v, }) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) gotV, gotErr := m.Get(AttrData) - if gotErr != nil { - t.Fatal(gotErr) - } - if !bytes.Equal(gotV, v) { - t.Error("value mismatch") - } + assert.NoError(t, gotErr) + assert.True(t, bytes.Equal(gotV, v), "value mismatch") } func TestMessage_GetNoAllocs(t *testing.T) { @@ -52,17 +48,13 @@ func TestMessage_GetNoAllocs(t *testing.T) { allocs := testing.AllocsPerRun(10, func() { msg.Get(AttrSoftware) //nolint:errcheck,gosec }) - if allocs > 0 { - t.Error("allocated memory, but should not") - } + assert.Zero(t, allocs, "allocated memory, but should not") }) t.Run("Not found", func(t *testing.T) { allocs := testing.AllocsPerRun(10, func() { msg.Get(AttrOrigin) //nolint:errcheck,gosec }) - if allocs > 0 { - t.Error("allocated memory, but should not") - } + assert.Zero(t, allocs, "allocated memory, but should not") }) } @@ -83,11 +75,8 @@ func TestPadding(t *testing.T) { {40, 40}, // 10 } for i, c := range tt { - if got := nearestPaddedValueLength(c.in); got != c.out { - t.Errorf("[%d]: padd(%d) %d (got) != %d (expected)", - i, c.in, got, c.out, - ) - } + got := nearestPaddedValueLength(c.in) + assert.Equal(t, c.out, got, "[%d]: padd(%d)", i, c.in) } } @@ -102,9 +91,8 @@ func TestAttrTypeRange(t *testing.T) { a := a t.Run(a.String(), func(t *testing.T) { a := a - if a.Optional() || !a.Required() { - t.Error("should be required") - } + assert.True(t, a.Required(), "should be required") + assert.False(t, a.Optional(), "should be required") }) } for _, a := range []AttrType{ @@ -114,9 +102,8 @@ func TestAttrTypeRange(t *testing.T) { } { a := a t.Run(a.String(), func(t *testing.T) { - if a.Required() || !a.Optional() { - t.Error("should be optional") - } + assert.False(t, a.Required(), "should be optional") + assert.True(t, a.Optional(), "should be optional") }) } } diff --git a/client_test.go b/client_test.go index da2cc63..e3bbc5e 100644 --- a/client_test.go +++ b/client_test.go @@ -18,6 +18,8 @@ import ( "sync" "testing" "time" + + "github.com/stretchr/testify/assert" ) var ( @@ -88,13 +90,9 @@ func BenchmarkClient_Do(b *testing.B) { client, err := NewClient(noopConnection{}, WithAgent(agent), ) - if err != nil { - log.Fatal(err) - } + assert.NoError(b, err) defer func() { - if closeErr := client.Close(); closeErr != nil { - panic(closeErr) - } + assert.NoError(b, client.Close()) }() noopF := func(Event) { @@ -163,9 +161,8 @@ func TestClosedOrPanic(t *testing.T) { func() { defer func() { r, ok := recover().(error) - if !ok || !errors.Is(r, io.EOF) { - t.Error(r) - } + assert.True(t, ok, "should be error") + assert.ErrorIs(t, r, io.EOF) }() closedOrPanic(io.EOF) }() @@ -203,46 +200,32 @@ func TestClient_Start(t *testing.T) { //nolint:cyclop }, } client, err := NewClient(conn) - if err != nil { - log.Fatal(err) - } + assert.NoError(t, err) defer func() { - if err := client.Close(); err != nil { - t.Error(err) - } - if err := client.Close(); err == nil { - t.Error("second close should fail") - } - if err := client.Do(MustBuild(TransactionID), nil); err == nil { - t.Error("Do after Close should fail") - } + assert.NoError(t, client.Close()) + assert.Error(t, client.Close(), "second close should fail") + assert.Error(t, client.Do(MustBuild(TransactionID), nil), "Do after Close should fail") }() msg := MustBuild(response, BindingRequest) t.Log("init") got := make(chan struct{}) write <- struct{}{} t.Log("starting the first transaction") - if err := client.Start(msg, func(event Event) { + assert.NoError(t, client.Start(msg, func(event Event) { t.Log("got first transaction callback") - if event.Error != nil { - t.Error(event.Error) - } + assert.NoError(t, event.Error) got <- struct{}{} - }); err != nil { - t.Error(err) - } + })) t.Log("starting the second transaction") - if err := client.Start(msg, func(Event) { - t.Error("should not be called") - }); !errors.Is(err, ErrTransactionExists) { - t.Errorf("unexpected error %v", err) - } + assert.ErrorIs(t, client.Start(msg, func(Event) { + assert.Fail(t, "should not be called") + }), ErrTransactionExists) read <- struct{}{} select { case <-got: // pass case <-time.After(time.Millisecond * 10): - t.Error("timed out") + assert.Fail(t, "timed out") } } @@ -256,34 +239,20 @@ func TestClient_Do(t *testing.T) { }, } client, err := NewClient(conn) - if err != nil { - log.Fatal(err) - } + assert.NoError(t, err) defer func() { - if err := client.Close(); err != nil { - t.Error(err) - } - if err := client.Close(); err == nil { - t.Error("second close should fail") - } - if err := client.Do(MustBuild(TransactionID), nil); err == nil { - t.Error("Do after Close should fail") - } + assert.NoError(t, client.Close()) + assert.Error(t, client.Close(), "second close should fail") + assert.Error(t, client.Do(MustBuild(TransactionID), nil), "Do after Close should fail") }() m := MustBuild( NewTransactionIDSetter(response.TransactionID), ) - if err := client.Do(m, func(event Event) { - if event.Error != nil { - t.Error(event.Error) - } - }); err != nil { - t.Error(err) - } + assert.NoError(t, client.Do(m, func(event Event) { + assert.NoError(t, event.Error) + })) m = MustBuild(TransactionID) - if err := client.Do(m, nil); err != nil { - t.Error(err) - } + assert.NoError(t, client.Do(m, nil)) } func TestCloseErr_Error(t *testing.T) { @@ -299,11 +268,7 @@ func TestCloseErr_Error(t *testing.T) { ConnectionErr: io.ErrUnexpectedEOF, }, "failed to close: unexpected EOF (connection), (agent)"}, } { - if out := testCase.Err.Error(); out != testCase.Out { - t.Errorf("[%d]: Error(%#v) %q (got) != %q (expected)", - id, testCase.Err, out, testCase.Out, - ) - } + assert.Equal(t, testCase.Out, testCase.Err.Error(), "[%d]: Error(%#v)", id, testCase.Err) } } @@ -320,11 +285,7 @@ func TestStopErr_Error(t *testing.T) { Cause: io.ErrUnexpectedEOF, }, "error while stopping due to unexpected EOF: "}, } { - if out := testcase.Err.Error(); out != testcase.Out { - t.Errorf("[%d]: Error(%#v) %q (got) != %q (expected)", - id, testcase.Err, out, testcase.Out, - ) - } + assert.Equal(t, testcase.Out, testcase.Err.Error(), "[%d]: Error(%#v)", id, testcase.Err) } } @@ -365,25 +326,15 @@ func TestClientAgentError(t *testing.T) { startErr: io.ErrUnexpectedEOF, }), ) - if err != nil { - log.Fatal(err) - } + assert.NoError(t, err) defer func() { - if err := client.Close(); err != nil { - t.Error(err) - } + assert.NoError(t, client.Close()) }() m := MustBuild(NewTransactionIDSetter(response.TransactionID)) - if err := client.Do(m, nil); err != nil { - t.Error(err) - } - if err := client.Do(m, func(event Event) { - if event.Error == nil { - t.Error("error expected") - } - }); !errors.Is(err, io.ErrUnexpectedEOF) { - t.Error("error expected") - } + assert.NoError(t, client.Do(m, nil)) + assert.ErrorIs(t, client.Do(m, func(event Event) { + assert.Error(t, event.Error, "error expected") + }), io.ErrUnexpectedEOF) } func TestClientConnErr(t *testing.T) { @@ -393,21 +344,13 @@ func TestClientConnErr(t *testing.T) { }, } client, err := NewClient(conn) - if err != nil { - log.Fatal(err) - } + assert.NoError(t, err) defer func() { - if err := client.Close(); err != nil { - t.Error(err) - } + assert.NoError(t, client.Close()) }() m := MustBuild(TransactionID) - if err := client.Do(m, nil); err == nil { - t.Error("error expected") - } - if err := client.Do(m, NoopHandler()); err == nil { - t.Error("error expected") - } + assert.Error(t, client.Do(m, nil), "error expected") + assert.Error(t, client.Do(m, NoopHandler()), "error expected") } func TestClientConnErrStopErr(t *testing.T) { @@ -421,26 +364,19 @@ func TestClientConnErrStopErr(t *testing.T) { stopErr: io.ErrUnexpectedEOF, }), ) - if err != nil { - log.Fatal(err) - } + assert.NoError(t, err) defer func() { - if err := client.Close(); err != nil { - t.Error(err) - } + assert.NoError(t, client.Close()) }() m := MustBuild(TransactionID) - if err := client.Do(m, NoopHandler()); err == nil { - t.Error("error expected") - } + assert.Error(t, client.Do(m, NoopHandler()), "error expected") } func TestCallbackWaitHandler_setCallback(t *testing.T) { c := callbackWaitHandler{} defer func() { - if err := recover(); err == nil { - t.Error("should panic") - } + err := recover() + assert.NotNil(t, err, "should panic") }() c.setCallback(nil) } @@ -450,56 +386,39 @@ func TestCallbackWaitHandler_HandleEvent(t *testing.T) { cond: sync.NewCond(new(sync.Mutex)), } defer func() { - if err := recover(); err == nil { - t.Error("should panic") - } + err := recover() + assert.NotNil(t, err, "should panic") }() c.HandleEvent(Event{}) } func TestNewClientNoConnection(t *testing.T) { c, err := NewClient(nil) - if c != nil { - t.Error("c should be nil") - } - if !errors.Is(err, ErrNoConnection) { - t.Error("bad error") - } + assert.Nil(t, c, "c should be nil") + assert.ErrorIs(t, err, ErrNoConnection, "bad error") } func TestDial(t *testing.T) { c, err := Dial("udp4", "localhost:3458") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) defer func() { - if err = c.Close(); err != nil { - t.Error(err) - } + assert.NoError(t, c.Close()) }() } func TestDialURI(t *testing.T) { u, err := ParseURI("stun:localhost") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) c, err := DialURI(u, &DialConfig{}) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) defer func() { - if err = c.Close(); err != nil { - t.Error(err) - } + assert.NoError(t, c.Close()) }() } func TestDialError(t *testing.T) { _, err := Dial("bad?network", "?????") - if err == nil { - t.Fatal("error expected") - } + assert.Error(t, err, "error expected") } func TestClientCloseErr(t *testing.T) { @@ -516,13 +435,11 @@ func TestClientCloseErr(t *testing.T) { closeErr: io.ErrUnexpectedEOF, }), ) - if err != nil { - log.Fatal(err) - } + assert.NoError(t, err) defer func() { - if err, ok := c.Close().(CloseErr); !ok || !errors.Is(err.AgentErr, io.ErrUnexpectedEOF) { //nolint:errorlint - t.Error("unexpected close err") - } + err, ok := c.Close().(CloseErr) //nolint:errorlint + assert.True(t, ok, "should be CloseErr") + assert.ErrorIs(t, err.AgentErr, io.ErrUnexpectedEOF, "unexpected close err") }() } @@ -542,12 +459,8 @@ func TestWithNoConnClose(t *testing.T) { }), WithNoConnClose(), ) - if err != nil { - log.Fatal(err) - } - if err := c.Close(); err != nil { - t.Error("unexpected non-nil error") - } + assert.NoError(t, err) + assert.NoError(t, c.Close(), "unexpected non-nil error") } type gcWaitAgent struct { @@ -598,28 +511,20 @@ func TestClientGC(t *testing.T) { WithAgent(agent), WithTimeoutRate(time.Millisecond), ) - if err != nil { - log.Fatal(err) - } + assert.NoError(t, err) defer func() { - if err = c.Close(); err != nil { - t.Error(err) - } + assert.NoError(t, c.Close()) }() select { case <-agent.gc: case <-time.After(time.Millisecond * 200): - t.Error("timed out") + assert.Fail(t, "timed out") } } func TestClientCheckInit(t *testing.T) { - if err := (&Client{}).Indicate(nil); !errors.Is(err, ErrClientNotInitialized) { - t.Error("unexpected error") - } - if err := (&Client{}).Do(nil, nil); !errors.Is(err, ErrClientNotInitialized) { - t.Error("unexpected error") - } + assert.ErrorIs(t, (&Client{}).Indicate(nil), ErrClientNotInitialized) + assert.ErrorIs(t, (&Client{}).Do(nil, nil), ErrClientNotInitialized) } func captureLog() (*bytes.Buffer, func()) { @@ -645,9 +550,7 @@ func TestClientFinalizer(t *testing.T) { }, } client, err := NewClient(conn) - if err != nil { - log.Panic(err) - } + assert.NoError(t, err) clientFinalizer(client) clientFinalizer(client) response := MustBuild(TransactionID, BindingSuccess) @@ -663,9 +566,7 @@ func TestClientFinalizer(t *testing.T) { closeErr: io.ErrUnexpectedEOF, }), ) - if err != nil { - log.Panic(err) - } + assert.NoError(t, err) clientFinalizer(client) reader := bufio.NewScanner(buf) var lines int @@ -676,17 +577,11 @@ func TestClientFinalizer(t *testing.T) { " (connection), unexpected EOF (agent)", } for reader.Scan() { - if reader.Text() != expectedLines[lines] { - t.Error(reader.Text(), "!=", expectedLines[lines]) - } + assert.Equal(t, expectedLines[lines], reader.Text()) lines++ } - if reader.Err() != nil { - t.Error(err) - } - if lines != 3 { - t.Error("incorrect count of log lines:", lines) - } + assert.NoError(t, reader.Err()) + assert.Equal(t, 3, lines, "incorrect count of log lines") } func TestCallbackWaitHandler(*testing.T) { @@ -784,9 +679,7 @@ func TestClientRetransmission(t *testing.T) { response.Encode() connL, connR := net.Pipe() defer func() { - if closeErr := connL.Close(); closeErr != nil { - panic(closeErr) - } + assert.NoError(t, connR.Close()) }() collector := new(manualCollector) clock := &manualClock{current: time.Now()} @@ -814,36 +707,22 @@ func TestClientRetransmission(t *testing.T) { WithCollector(collector), WithRTO(time.Millisecond), ) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) client.SetRTO(time.Second) gotReads := make(chan struct{}) go func() { buf := make([]byte, 1500) readN, readErr := connL.Read(buf) - if readErr != nil { - t.Error(readErr) - } - if !IsMessage(buf[:readN]) { - t.Error("should be STUN") - } + assert.NoError(t, readErr) + assert.True(t, IsMessage(buf[:readN]), "should be STUN") readN, readErr = connL.Read(buf) - if readErr != nil { - t.Error(readErr) - } - if !IsMessage(buf[:readN]) { - t.Error("should be STUN") - } + assert.NoError(t, readErr) + assert.True(t, IsMessage(buf[:readN]), "should be STUN") gotReads <- struct{}{} }() - if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) { - if event.Error != nil { - t.Error("failed") - } - }); doErr != nil { - t.Fatal(doErr) - } + assert.NoError(t, client.Do(MustBuild(response, BindingRequest), func(event Event) { + assert.NoError(t, event.Error, "failed") + })) <-gotReads } @@ -854,9 +733,7 @@ func testClientDoConcurrent(t *testing.T, concurrency int) { //nolint:cyclop response.Encode() connL, connR := net.Pipe() defer func() { - if closeErr := connL.Close(); closeErr != nil { - panic(closeErr) - } + assert.NoError(t, connR.Close()) }() collector := new(manualCollector) clock := &manualClock{current: time.Now()} @@ -874,9 +751,7 @@ func testClientDoConcurrent(t *testing.T, concurrency int) { //nolint:cyclop WithClock(clock), WithCollector(collector), ) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) client.SetRTO(time.Second) conns := new(sync.WaitGroup) wg := new(sync.WaitGroup) @@ -891,29 +766,21 @@ func testClientDoConcurrent(t *testing.T, concurrency int) { //nolint:cyclop if errors.Is(readErr, io.EOF) { break } - t.Error(readErr) - } - if !IsMessage(buf[:readN]) { - t.Error("should be STUN") + assert.NoError(t, readErr) } + assert.True(t, IsMessage(buf[:readN]), "should be STUN") } }() wg.Add(1) go func() { defer wg.Done() - if doErr := client.Do(MustBuild(TransactionID, BindingRequest), func(event Event) { - if event.Error != nil { - t.Error("failed") - } - }); doErr != nil { - t.Error(doErr) - } + assert.NoError(t, client.Do(MustBuild(TransactionID, BindingRequest), func(event Event) { + assert.NoError(t, event.Error, "failed") + })) }() } wg.Wait() - if connErr := connR.Close(); connErr != nil { - t.Error(connErr) - } + assert.NoError(t, connR.Close()) conns.Wait() } @@ -942,24 +809,22 @@ func (c errorCollector) Close() error { return c.closeError } func TestNewClient(t *testing.T) { t.Run("SetCallbackError", func(t *testing.T) { setHandlerError := errClientSetHandler - if _, createErr := NewClient(noopConnection{}, + _, createErr := NewClient(noopConnection{}, WithAgent(&errorAgent{ setHandlerError: setHandlerError, }), - ); !errors.Is(createErr, setHandlerError) { - t.Errorf("unexpected error returned: %v", createErr) - } + ) + assert.ErrorIs(t, createErr, setHandlerError, "unexpected error returned") }) t.Run("CollectorStartError", func(t *testing.T) { startError := errClientStart - if _, createErr := NewClient(noopConnection{}, + _, createErr := NewClient(noopConnection{}, WithAgent(&TestAgent{}), WithCollector(errorCollector{ startError: startError, }), - ); !errors.Is(createErr, startError) { - t.Errorf("unexpected error returned: %v", createErr) - } + ) + assert.ErrorIs(t, createErr, startError, "unexpected error returned") }) } @@ -972,13 +837,9 @@ func TestClient_Close(t *testing.T) { }), WithAgent(&TestAgent{}), ) - if createErr != nil { - t.Errorf("unexpected create error returned: %v", createErr) - } + assert.NoError(t, createErr, "unexpected create error returned") gotCloseErr := c.Close() - if !errors.Is(gotCloseErr, closeErr) { - t.Errorf("unexpected close error returned: %v", gotCloseErr) - } + assert.ErrorIs(t, gotCloseErr, closeErr, "unexpected close error returned") }) } @@ -992,19 +853,13 @@ func TestClientDefaultHandler(t *testing.T) { client, createErr := NewClient(noopConnection{}, WithAgent(agent), WithHandler(func(e Event) { - if called { - t.Error("should not be called twice") - } + assert.False(t, called, "should not be called twice") called = true - if e.TransactionID != id { - t.Error("wrong transaction ID") - } + assert.Equal(t, id, e.TransactionID, "wrong transaction ID") handlerCalled <- struct{}{} }), ) - if createErr != nil { - t.Fatal(createErr) - } + assert.NoError(t, createErr) go func() { agent.h(Event{ TransactionID: id, @@ -1014,11 +869,9 @@ func TestClientDefaultHandler(t *testing.T) { case <-handlerCalled: // pass case <-time.After(time.Millisecond * 100): - t.Fatal("timed out") - } - if closeErr := client.Close(); closeErr != nil { - t.Error(closeErr) + assert.Fail(t, "timed out") } + assert.NoError(t, client.Close()) // Handler call should be ignored. agent.h(Event{}) } @@ -1030,15 +883,9 @@ func TestClientClosedStart(t *testing.T) { c, createErr := NewClient(noopConnection{}, WithAgent(a), ) - if createErr != nil { - t.Fatal(createErr) - } - if closeErr := c.Close(); closeErr != nil { - t.Error(closeErr) - } - if startErr := c.start(&clientTransaction{}); !errors.Is(startErr, ErrClientClosed) { - t.Error("should error") - } + assert.NoError(t, createErr) + assert.NoError(t, c.Close()) + assert.ErrorIs(t, c.start(&clientTransaction{}), ErrClientClosed) } func TestWithNoRetransmit(t *testing.T) { @@ -1046,9 +893,7 @@ func TestWithNoRetransmit(t *testing.T) { response.Encode() connL, connR := net.Pipe() defer func() { - if closeErr := connL.Close(); closeErr != nil { - panic(closeErr) - } + assert.NoError(t, connL.Close()) }() collector := new(manualCollector) clock := &manualClock{current: time.Now()} @@ -1062,7 +907,7 @@ func TestWithNoRetransmit(t *testing.T) { Error: ErrTransactionTimeOut, }) } else { - t.Error("there should be no second attempt") + assert.Fail(t, "there should be no second attempt") go agent.h(Event{ TransactionID: id, Error: ErrTransactionTimeOut, @@ -1078,28 +923,18 @@ func TestWithNoRetransmit(t *testing.T) { WithRTO(0), WithNoRetransmit, ) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) gotReads := make(chan struct{}) go func() { buf := make([]byte, 1500) readN, readErr := connL.Read(buf) - if readErr != nil { - t.Error(readErr) - } - if !IsMessage(buf[:readN]) { - t.Error("should be STUN") - } + assert.NoError(t, readErr) + assert.True(t, IsMessage(buf[:readN]), "should be STUN") gotReads <- struct{}{} }() - if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) { - if !errors.Is(event.Error, ErrTransactionTimeOut) { - t.Error("unexpected error") - } - }); doErr != nil { - t.Fatal(err) - } + assert.NoError(t, client.Do(MustBuild(response, BindingRequest), func(event Event) { + assert.ErrorIs(t, event.Error, ErrTransactionTimeOut, "unexpected error") + })) <-gotReads } @@ -1114,9 +949,7 @@ func TestClientRTOStartErr(t *testing.T) { //nolint:cyclop response.Encode() connL, connR := net.Pipe() defer func() { - if closeErr := connL.Close(); closeErr != nil { - panic(closeErr) - } + assert.NoError(t, connL.Close()) }() collector := new(manualCollector) shouldWait := false @@ -1169,9 +1002,7 @@ func TestClientRTOStartErr(t *testing.T) { //nolint:cyclop t.Log("clock locked") <-clockLocked t.Log("closing client") - if closeErr := client.Close(); closeErr != nil { - t.Error(closeErr) - } + assert.NoError(t, client.Close()) t.Log("client closed, unlocking clock") clockWait <- struct{}{} t.Log("clock unlocked") @@ -1186,44 +1017,30 @@ func TestClientRTOStartErr(t *testing.T) { //nolint:cyclop WithCollector(collector), WithRTO(time.Millisecond), ) - if startClientErr != nil { - t.Fatal(startClientErr) - } + assert.NoError(t, startClientErr) go func() { buf := make([]byte, 1500) readN, readErr := connL.Read(buf) - if readErr != nil { - t.Error(readErr) - } - if !IsMessage(buf[:readN]) { - t.Error("should be STUN") - } + assert.NoError(t, readErr) + assert.True(t, IsMessage(buf[:readN]), "should be STUN") readN, readErr = connL.Read(buf) - if readErr != nil { - t.Error(readErr) - } - if !IsMessage(buf[:readN]) { - t.Error("should be STUN") - } + assert.NoError(t, readErr) + assert.True(t, IsMessage(buf[:readN]), "should be STUN") gotReads <- struct{}{} }() t.Log("starting") done := make(chan struct{}) go func() { - if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) { - if !errors.Is(event.Error, ErrClientClosed) { - t.Error(event.Error) - } - }); doErr != nil { - t.Error(doErr) - } + assert.NoError(t, client.Do(MustBuild(response, BindingRequest), func(event Event) { + assert.ErrorIs(t, event.Error, ErrClientClosed) + })) done <- struct{}{} }() select { case <-done: // ok case <-time.After(time.Second * 5): - t.Error("timeout") + assert.Fail(t, "timeout") } } @@ -1232,9 +1049,7 @@ func TestClientRTOWriteErr(t *testing.T) { //nolint:cyclop response.Encode() connL, connR := net.Pipe() defer func() { - if closeErr := connL.Close(); closeErr != nil { - panic(closeErr) - } + assert.NoError(t, connL.Close()) }() collector := new(manualCollector) shouldWait := false @@ -1291,9 +1106,7 @@ func TestClientRTOWriteErr(t *testing.T) { //nolint:cyclop t.Log("clock locked") <-clockLocked t.Log("closing connection") - if closeErr := connL.Close(); closeErr != nil { - panic(closeErr) - } + assert.NoError(t, connL.Close()) t.Log("connection closed, unlocking clock") clockWait <- struct{}{} t.Log("clock unlocked") @@ -1308,52 +1121,33 @@ func TestClientRTOWriteErr(t *testing.T) { //nolint:cyclop WithCollector(collector), WithRTO(time.Millisecond), ) - if startClientErr != nil { - t.Fatal(startClientErr) - } + assert.NoError(t, startClientErr) go func() { buf := make([]byte, 1500) readN, readErr := connL.Read(buf) - if readErr != nil { - t.Error(readErr) - } - if !IsMessage(buf[:readN]) { - t.Error("should be STUN") - } + assert.NoError(t, readErr) + assert.True(t, IsMessage(buf[:readN]), "should be STUN") readN, readErr = connL.Read(buf) - if readErr != nil { - t.Error(readErr) - } - if !IsMessage(buf[:readN]) { - t.Error("should be STUN") - } + assert.NoError(t, readErr) + assert.True(t, IsMessage(buf[:readN]), "should be STUN") gotReads <- struct{}{} }() t.Log("starting") done := make(chan struct{}) go func() { - if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) { + assert.NoError(t, client.Do(MustBuild(response, BindingRequest), func(event Event) { var e StopErr - if !errors.As(event.Error, &e) { - t.Error(event.Error) - } else { - if !errors.Is(e.Err, agentStopErr) { - t.Error("incorrect agent error") - } - if !errors.Is(e.Cause, io.ErrClosedPipe) { - t.Error("incorrect connection error") - } - } - }); doErr != nil { - t.Error(doErr) - } + assert.ErrorAs(t, event.Error, &e) + assert.ErrorIs(t, e.Err, agentStopErr, "incorrect agent error") + assert.ErrorIs(t, e.Cause, io.ErrClosedPipe, "incorrect connection error") + })) done <- struct{}{} }() select { case <-done: // ok case <-time.After(time.Second * 5): - t.Error("timeout") + assert.Fail(t, "timeout") } } @@ -1362,9 +1156,7 @@ func TestClientRTOAgentErr(t *testing.T) { response.Encode() connL, connR := net.Pipe() defer func() { - if closeErr := connL.Close(); closeErr != nil { - panic(closeErr) - } + assert.NoError(t, connL.Close()) }() collector := new(manualCollector) clock := callbackClock(time.Now) @@ -1396,33 +1188,23 @@ func TestClientRTOAgentErr(t *testing.T) { WithCollector(collector), WithRTO(time.Millisecond), ) - if startClientErr != nil { - t.Fatal(startClientErr) - } + assert.NoError(t, startClientErr) go func() { buf := make([]byte, 1500) readN, readErr := connL.Read(buf) - if readErr != nil { - t.Error(readErr) - } - if !IsMessage(buf[:readN]) { - t.Error("should be STUN") - } + assert.NoError(t, readErr) + assert.True(t, IsMessage(buf[:readN]), "should be STUN") gotReads <- struct{}{} }() t.Log("starting") - if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) { - if !errors.Is(event.Error, agentStartErr) { - t.Error(event.Error) - } - }); doErr != nil { - t.Error(doErr) - } + assert.NoError(t, client.Do(MustBuild(response, BindingRequest), func(event Event) { + assert.ErrorIs(t, event.Error, agentStartErr) + })) select { case <-gotReads: // ok case <-time.After(time.Second * 5): - t.Error("reads timeout") + assert.Fail(t, "reads timeout") } } @@ -1431,9 +1213,7 @@ func TestClient_HandleProcessError(t *testing.T) { response.Encode() connL, connR := net.Pipe() defer func() { - if closeErr := connL.Close(); closeErr != nil { - panic(closeErr) - } + assert.NoError(t, connL.Close()) }() collector := new(manualCollector) clock := callbackClock(time.Now) @@ -1451,14 +1231,10 @@ func TestClient_HandleProcessError(t *testing.T) { WithCollector(collector), WithRTO(time.Millisecond), ) - if startClientErr != nil { - t.Fatal(startClientErr) - } + assert.NoError(t, startClientErr) go func() { _, readErr := connL.Write(response.Raw) - if readErr != nil { - t.Error(readErr) - } + assert.NoError(t, readErr) gotWrites <- struct{}{} }() t.Log("starting") @@ -1466,20 +1242,16 @@ func TestClient_HandleProcessError(t *testing.T) { case <-gotWrites: // ok case <-time.After(time.Second * 5): - t.Error("reads timeout") - } - if closeErr := client.Close(); closeErr != nil { - t.Error(closeErr) + assert.Fail(t, "reads timeout") } + assert.NoError(t, client.Close()) } func TestClientImmediateTimeout(t *testing.T) { response := MustBuild(TransactionID, BindingSuccess) connL, connR := net.Pipe() defer func() { - if closeErr := connL.Close(); closeErr != nil { - panic(closeErr) - } + assert.NoError(t, connL.Close()) }() collector := new(manualCollector) clock := &manualClock{current: time.Now()} @@ -1488,16 +1260,14 @@ func TestClientImmediateTimeout(t *testing.T) { attempt := 0 agent.start = func(id [TransactionIDSize]byte, deadline time.Time) error { if attempt == 0 { - if deadline.Before(clock.current.Add(rto / 2)) { - t.Error("deadline too fast") - } + assert.False(t, deadline.Before(clock.current.Add(rto/2)), "deadline too fast") attempt++ go agent.h(Event{ TransactionID: id, Message: response, }) } else { - t.Error("there should be no second attempt") + assert.Fail(t, "there should be no second attempt") go agent.h(Event{ TransactionID: id, Error: ErrTransactionTimeOut, @@ -1512,25 +1282,17 @@ func TestClientImmediateTimeout(t *testing.T) { WithCollector(collector), WithRTO(rto), ) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) gotReads := make(chan struct{}) go func() { buf := make([]byte, 1500) readN, readErr := connL.Read(buf) - if readErr != nil { - t.Error(readErr) - } - if !IsMessage(buf[:readN]) { - t.Error("should be STUN") - } + assert.NoError(t, readErr) + assert.True(t, IsMessage(buf[:readN]), "should be STUN") gotReads <- struct{}{} }() client.Start(MustBuild(response, BindingRequest), func(e Event) { //nolint:errcheck,gosec - if errors.Is(e.Error, ErrTransactionTimeOut) { - t.Error("unexpected error") - } + assert.NoError(t, e.Error, "unexpected error") }) <-gotReads } diff --git a/errorcode_test.go b/errorcode_test.go index d7a0481..c5083f4 100644 --- a/errorcode_test.go +++ b/errorcode_test.go @@ -8,9 +8,10 @@ package stun import ( "encoding/base64" - "errors" "io" "testing" + + "github.com/stretchr/testify/assert" ) func BenchmarkErrorCode_AddTo(b *testing.B) { @@ -52,19 +53,13 @@ func TestErrorCodeAttribute_GetFrom(t *testing.T) { m := New() m.Add(AttrErrorCode, []byte{1}) c := new(ErrorCodeAttribute) - if err := c.GetFrom(m); !errors.Is(err, io.ErrUnexpectedEOF) { - t.Errorf("GetFrom should return <%s>, but got <%s>", - io.ErrUnexpectedEOF, err, - ) - } + assert.ErrorIs(t, c.GetFrom(m), io.ErrUnexpectedEOF) } func TestMessage_AddErrorCode(t *testing.T) { m := New() transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er") - if err != nil { - t.Error(err) - } + assert.NoError(t, err) copy(m.TransactionID[:], transactionID) expectedCode := ErrorCode(438) expectedReason := "Stale Nonce" @@ -72,23 +67,13 @@ func TestMessage_AddErrorCode(t *testing.T) { m.WriteHeader() mRes := New() - if _, err = mRes.ReadFrom(m.reader()); err != nil { - t.Fatal(err) - } + _, err = mRes.ReadFrom(m.reader()) + assert.NoError(t, err) errCodeAttr := new(ErrorCodeAttribute) - if err = errCodeAttr.GetFrom(mRes); err != nil { - t.Error(err) - } + assert.NoError(t, errCodeAttr.GetFrom(mRes)) code := errCodeAttr.Code - if err != nil { - t.Error(err) - } - if code != expectedCode { - t.Error("bad code", code) - } - if string(errCodeAttr.Reason) != expectedReason { - t.Error("bad reason", string(errCodeAttr.Reason)) - } + assert.Equal(t, expectedCode, code, "bad code") + assert.Equal(t, expectedReason, string(errCodeAttr.Reason), "bad reason") } func TestErrorCode(t *testing.T) { @@ -96,19 +81,11 @@ func TestErrorCode(t *testing.T) { Code: 404, Reason: []byte("not found!"), } - if attr.String() != "404: not found!" { - t.Error("bad string", attr) - } + assert.Equal(t, "404: not found!", attr.String(), "bad string") m := New() cod := ErrorCode(666) - if err := cod.AddTo(m); !errors.Is(err, ErrNoDefaultReason) { - t.Error("should be ErrNoDefaultReason", err) - } - if err := attr.GetFrom(m); err == nil { - t.Error("attr should not be in message") - } + assert.ErrorIs(t, cod.AddTo(m), ErrNoDefaultReason, "should be ErrNoDefaultReason") + assert.Error(t, attr.GetFrom(m), "attr should not be in message") attr.Reason = make([]byte, 2048) - if err := attr.AddTo(m); err == nil { - t.Error("should error") - } + assert.Error(t, attr.AddTo(m), "should error") } diff --git a/errors_test.go b/errors_test.go index 1b5b462..08d45e8 100644 --- a/errors_test.go +++ b/errors_test.go @@ -6,6 +6,8 @@ package stun import ( "errors" "testing" + + "github.com/stretchr/testify/assert" ) func TestDecodeErr_IsInvalidCookie(t *testing.T) { @@ -14,25 +16,13 @@ func TestDecodeErr_IsInvalidCookie(t *testing.T) { decoded := new(Message) m.Raw[4] = 55 _, err := decoded.Write(m.Raw) - if err == nil { - t.Fatal("should error") - } + assert.Error(t, err, "should error") expected := "BadFormat for message/cookie: " + "3712a442 is invalid magic cookie (should be 2112a442)" - if err.Error() != expected { - t.Error(err, "!=", expected) - } + assert.Equal(t, expected, err.Error(), "error message mismatch") var dErr *DecodeErr - if !errors.As(err, &dErr) { - t.Error("not decode error") - } - if !dErr.IsInvalidCookie() { - t.Error("IsInvalidCookie = false, should be true") - } - if !dErr.IsPlaceChildren("cookie") { - t.Error("bad children") - } - if !dErr.IsPlaceParent("message") { - t.Error("bad parent") - } + assert.True(t, errors.As(err, &dErr), "not decode error") + assert.True(t, dErr.IsInvalidCookie(), "IsInvalidCookie = false, should be true") + assert.True(t, dErr.IsPlaceChildren("cookie"), "bad children") + assert.True(t, dErr.IsPlaceParent("message"), "bad parent") } diff --git a/fingerprint_test.go b/fingerprint_test.go index 137e1d1..b44dc58 100644 --- a/fingerprint_test.go +++ b/fingerprint_test.go @@ -9,6 +9,8 @@ package stun import ( "net" "testing" + + "github.com/stretchr/testify/assert" ) func BenchmarkFingerprint_AddTo(b *testing.B) { @@ -36,26 +38,18 @@ func TestFingerprint_Check(t *testing.T) { m.WriteHeader() Fingerprint.AddTo(m) //nolint:errcheck,gosec m.WriteHeader() - if err := Fingerprint.Check(m); err != nil { - t.Error(err) - } + assert.NoError(t, Fingerprint.Check(m)) m.Raw[3]++ - if err := Fingerprint.Check(m); err == nil { - t.Error("should error") - } + assert.Error(t, Fingerprint.Check(m)) } func TestFingerprint_CheckBad(t *testing.T) { m := new(Message) addAttr(t, m, NewSoftware("software")) m.WriteHeader() - if err := Fingerprint.Check(m); err == nil { - t.Error("should error") - } + assert.Error(t, Fingerprint.Check(m)) m.Add(AttrFingerprint, []byte{1, 2, 3}) - if !IsAttrSizeInvalid(Fingerprint.Check(m)) { - t.Error("IsAttrSizeInvalid should be true") - } + assert.True(t, IsAttrSizeInvalid(Fingerprint.Check(m))) } func BenchmarkFingerprint_Check(b *testing.B) { diff --git a/fuzz_test.go b/fuzz_test.go index e60353e..e0d3375 100644 --- a/fuzz_test.go +++ b/fuzz_test.go @@ -7,6 +7,8 @@ import ( "encoding/binary" "errors" "testing" + + "github.com/stretchr/testify/assert" ) func FuzzMessage(f *testing.F) { @@ -26,21 +28,12 @@ func FuzzMessage(f *testing.F) { } msg2 := New() - if _, err := msg2.Write(msg1.Raw); err != nil { - t.Fatalf("Failed to write: %s", err) - } - - if msg2.TransactionID != msg1.TransactionID { - t.Fatal("Transaction ID mismatch") - } + _, err := msg2.Write(msg1.Raw) + assert.NoError(t, err, "Failed to write") - if msg2.Type != msg1.Type { - t.Fatal("Type mismatch") - } - - if len(msg2.Attributes) != len(msg1.Attributes) { - t.Fatal("Attributes length mismatch") - } + assert.Equal(t, msg1.TransactionID, msg2.TransactionID, "Transaction ID mismatch") + assert.Equal(t, msg1.Type, msg2.Type, "Type mismatch") + assert.Equal(t, len(msg1.Attributes), len(msg2.Attributes), "Attributes length mismatch") }) } @@ -51,15 +44,11 @@ func FuzzType(f *testing.F) { t1 := MessageType{} t1.ReadValue(v) v2 := t1.Value() - if v != v2 { - t.Fatal("v != v2") - } + assert.Equal(t, v, v2, "v != v2") t2 := MessageType{} t2.ReadValue(v2) - if t2 != t1 { - t.Fatal("t2 != t1") - } + assert.Equal(t, t1, t2, "t2 != t1") }) } @@ -94,20 +83,19 @@ func FuzzSetters(f *testing.F) { m1.WriteHeader() m1.Add(attr.t, value) err := attr.g.GetFrom(m1) - if errors.Is(err, ErrAttributeNotFound) { - t.Fatalf("Unexpected 404: %s", err) - } + assert.False(t, errors.Is(err, ErrAttributeNotFound)) if err != nil { return } m2.WriteHeader() - if err = attr.g.AddTo(m2); err != nil { + err = attr.g.AddTo(m2) + if err != nil { // We allow decoding some text attributes // when their length is too big, but // not encoding. if !IsAttrSizeOverflow(err) { - t.Fatal(err) + assert.NoError(t, err) } return @@ -115,14 +103,10 @@ func FuzzSetters(f *testing.F) { m3.WriteHeader() v, err := m2.Get(attr.t) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) m3.Add(attr.t, v) - if !m2.Equal(m3) { - t.Fatalf("Not equal: %s != %s", m2, m3) - } + assert.True(t, m2.Equal(m3), "Not equal: %s != %s", m2, m3) }) } diff --git a/helpers_test.go b/helpers_test.go index f112ade..44a4f74 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/pion/stun/v3/internal/testutil" + "github.com/stretchr/testify/assert" ) func BenchmarkBuildOverhead(b *testing.B) { @@ -59,21 +60,12 @@ func TestMessage_Apply(t *testing.T) { integrity, Fingerprint, ) - if err != nil { - t.Fatal("failed to build:", err) - } - if err = msg.Check(Fingerprint, integrity); err != nil { - t.Fatal(err) - } - if _, err := decoded.Write(msg.Raw); err != nil { - t.Fatal(err) - } - if !decoded.Equal(msg) { - t.Error("not equal") - } - if err := integrity.Check(decoded); err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "failed to build") + assert.NoError(t, msg.Check(Fingerprint, integrity)) + _, err = decoded.Write(msg.Raw) + assert.NoError(t, err) + assert.True(t, decoded.Equal(msg)) + assert.NoError(t, integrity.Check(decoded)) } type errReturner struct { @@ -97,25 +89,17 @@ func (e errReturner) GetFrom(*Message) error { func TestHelpersErrorHandling(t *testing.T) { m := New() errReturn := errReturner{Err: errTError} - if err := m.Build(errReturn); !errors.Is(err, errReturn.Err) { - t.Error(err, "!=", errReturn.Err) - } - if err := m.Check(errReturn); !errors.Is(err, errReturn.Err) { - t.Error(err, "!=", errReturn.Err) - } - if err := m.Parse(errReturn); !errors.Is(err, errReturn.Err) { - t.Error(err, "!=", errReturn.Err) - } + assert.ErrorIs(t, m.Build(errReturn), errReturn.Err) + assert.ErrorIs(t, m.Check(errReturn), errReturn.Err) + assert.ErrorIs(t, m.Parse(errReturn), errReturn.Err) t.Run("MustBuild", func(t *testing.T) { t.Run("Positive", func(*testing.T) { MustBuild(NewTransactionIDSetter(transactionID{})) }) defer func() { - if p, ok := recover().(error); !ok || !errors.Is(p, errReturn.Err) { - t.Errorf("%s != %s", - p, errReturn.Err, - ) - } + p, ok := recover().(error) + assert.True(t, ok) + assert.ErrorIs(t, p, errReturn.Err) }() MustBuild(errReturn) }) @@ -123,91 +107,62 @@ func TestHelpersErrorHandling(t *testing.T) { func TestMessage_ForEach(t *testing.T) { //nolint:cyclop initial := New() - if err := initial.Build( + assert.NoError(t, initial.Build( NewRealm("realm1"), NewRealm("realm2"), - ); err != nil { - t.Fatal(err) - } + )) newMessage := func() *Message { m := New() - if err := m.Build( + assert.NoError(t, m.Build( NewRealm("realm1"), NewRealm("realm2"), - ); err != nil { - t.Fatal(err) - } + )) return m } t.Run("NoResults", func(t *testing.T) { m := newMessage() - if !m.Equal(initial) { - t.Error("m should be equal to initial") - } - if err := m.ForEach(AttrUsername, func(*Message) error { - t.Error("should not be called") + assert.True(t, m.Equal(initial), "m should be equal to initial") + assert.NoError(t, m.ForEach(AttrUsername, func(*Message) error { + assert.Fail(t, "should not be called") return nil - }); err != nil { - t.Fatal(err) - } - if !m.Equal(initial) { - t.Error("m should be equal to initial") - } + })) + assert.True(t, m.Equal(initial), "m should be equal to initial") }) t.Run("ReturnOnError", func(t *testing.T) { m := newMessage() var calls int - if err := m.ForEach(AttrRealm, func(*Message) error { + err := m.ForEach(AttrRealm, func(*Message) error { if calls > 0 { - t.Error("called multiple times") + assert.Fail(t, "called multiple times") } calls++ return ErrAttributeNotFound - }); !errors.Is(err, ErrAttributeNotFound) { - t.Fatal(err) - } - if !m.Equal(initial) { - t.Error("m should be equal to initial") - } + }) + assert.ErrorIs(t, err, ErrAttributeNotFound) + assert.True(t, m.Equal(initial), "m should be equal to initial") }) t.Run("Positive", func(t *testing.T) { msg := newMessage() var realms []string - if err := msg.ForEach(AttrRealm, func(m *Message) error { + assert.NoError(t, msg.ForEach(AttrRealm, func(m *Message) error { var realm Realm - if err := realm.GetFrom(m); err != nil { - return err - } + assert.NoError(t, realm.GetFrom(m)) realms = append(realms, realm.String()) return nil - }); err != nil { - t.Fatal(err) - } - if len(realms) != 2 { - t.Fatal("expected 2 realms") - } - if realms[0] != "realm1" { - t.Error("bad value for 1 realm") - } - if realms[1] != "realm2" { - t.Error("bad value for 2 realm") - } - if !msg.Equal(initial) { - t.Error("m should be equal to initial") - } + })) + assert.Len(t, realms, 2) + assert.Equal(t, "realm1", realms[0], "bad value for 1 realm") + assert.Equal(t, "realm2", realms[1], "bad value for 2 realm") + assert.True(t, msg.Equal(initial), "m should be equal to initial") t.Run("ZeroAlloc", func(t *testing.T) { msg = newMessage() var realm Realm testutil.ShouldNotAllocate(t, func() { - if err := msg.ForEach(AttrRealm, realm.GetFrom); err != nil { - t.Fatal(err) - } + assert.NoError(t, msg.ForEach(AttrRealm, realm.GetFrom)) }) - if !msg.Equal(initial) { - t.Error("m should be equal to initial") - } + assert.True(t, msg.Equal(initial), "m should be equal to initial") }) }) } diff --git a/iana_test.go b/iana_test.go index c129372..8af423f 100644 --- a/iana_test.go +++ b/iana_test.go @@ -12,6 +12,8 @@ import ( "strconv" "strings" "testing" + + "github.com/stretchr/testify/assert" ) func loadCSV(tb testing.TB, name string) [][]string { @@ -21,9 +23,7 @@ func loadCSV(tb testing.TB, name string) [][]string { r := csv.NewReader(bytes.NewReader(data)) r.Comment = '#' records, err := r.ReadAll() - if err != nil { - tb.Fatal(err) - } + assert.NoError(tb, err) return records } @@ -41,20 +41,14 @@ func TestIANA(t *testing.T) { //nolint:cyclop continue } val, parseErr := strconv.ParseInt(v[2:], 16, 64) - if parseErr != nil { - t.Fatal(parseErr) - } + assert.NoError(t, parseErr) t.Logf("value: 0x%x, name: %s", val, name) methods[name] = Method(val) //nolint:gosec // G115 } for val, name := range methodName() { mapped, ok := methods[name] - if !ok { - t.Errorf("failed to find method %s in IANA", name) - } - if mapped != val { - t.Errorf("%s: IANA %d != actual %d", name, mapped, val) - } + assert.True(t, ok, "failed to find method %s in IANA", name) + assert.Equal(t, mapped, val, "%s: IANA %d != actual %d", name, mapped, val) } }) t.Run("Attributes", func(t *testing.T) { @@ -69,9 +63,7 @@ func TestIANA(t *testing.T) { //nolint:cyclop continue } val, parseErr := strconv.ParseInt(v[2:], 16, 64) - if parseErr != nil { - t.Fatal(parseErr) - } + assert.NoError(t, parseErr) t.Logf("value: 0x%x, name: %s", val, name) attrTypes[name] = AttrType(val) //nolint:gosec // G115 } @@ -83,12 +75,8 @@ func TestIANA(t *testing.T) { //nolint:cyclop } for val, name := range attrNames() { mapped, ok := attrTypes[name] - if !ok { - t.Errorf("failed to find attribute %s in IANA", name) - } - if mapped != val { - t.Errorf("%s: IANA %d != actual %d", name, mapped, val) - } + assert.True(t, ok, "failed to find attribute %s in IANA", name) + assert.Equal(t, mapped, val, "%s: IANA %d != actual %d", name, mapped, val) } }) t.Run("ErrorCodes", func(t *testing.T) { @@ -103,21 +91,15 @@ func TestIANA(t *testing.T) { //nolint:cyclop continue } val, parseErr := strconv.ParseInt(v, 10, 64) - if parseErr != nil { - t.Fatal(parseErr) - } + assert.NoError(t, parseErr) t.Logf("value: 0x%x, name: %s", val, name) errorCodes[name] = ErrorCode(val) } for val, nameB := range errorReasons { name := string(nameB) mapped, ok := errorCodes[name] - if !ok { - t.Errorf("failed to find error code %s in IANA", name) - } - if mapped != val { - t.Errorf("%s: IANA %d != actual %d", name, mapped, val) - } + assert.True(t, ok, "failed to find error code %s in IANA", name) + assert.Equal(t, mapped, val, "%s: IANA %d != actual %d", name, mapped, val) } }) } diff --git a/integrity_test.go b/integrity_test.go index 1ae7f9c..8ee746f 100644 --- a/integrity_test.go +++ b/integrity_test.go @@ -4,40 +4,29 @@ package stun import ( - "bytes" "encoding/hex" "testing" + + "github.com/stretchr/testify/assert" ) func TestMessageIntegrity_AddTo_Simple(t *testing.T) { integrity := NewLongTermIntegrity("user", "realm", "pass") expected, err := hex.DecodeString("8493fbc53ba582fb4c044c456bdc40eb") - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(expected, integrity) { - t.Error(ErrIntegrityMismatch) - } + assert.NoError(t, err) + assert.EqualValues(t, expected, integrity) t.Run("Check", func(t *testing.T) { m := new(Message) m.WriteHeader() - if err := integrity.AddTo(m); err != nil { - t.Error(err) - } + assert.NoError(t, integrity.AddTo(m)) NewSoftware("software").AddTo(m) //nolint:errcheck,gosec m.WriteHeader() dM := new(Message) dM.Raw = m.Raw - if err := dM.Decode(); err != nil { - t.Error(err) - } - if err := integrity.Check(dM); err != nil { - t.Error(err) - } + assert.NoError(t, dM.Decode()) + assert.NoError(t, integrity.Check(dM)) dM.Raw[24] += 12 // HMAC now invalid - if integrity.Check(dM) == nil { - t.Error("should be invalid") - } + assert.Error(t, integrity.Check(dM)) }) } @@ -47,38 +36,23 @@ func TestMessageIntegrityWithFingerprint(t *testing.T) { msg.WriteHeader() NewSoftware("software").AddTo(msg) //nolint:errcheck,gosec integrity := NewShortTermIntegrity("pwd") - if integrity.String() != "KEY: 0x707764" { - t.Error("bad string", integrity) - } - if err := integrity.Check(msg); err == nil { - t.Error("should error") - } - if err := integrity.AddTo(msg); err != nil { - t.Fatal(err) - } - if err := Fingerprint.AddTo(msg); err != nil { - t.Fatal(err) - } - if err := integrity.Check(msg); err != nil { - t.Fatal(err) - } + assert.Equal(t, "KEY: 0x707764", integrity.String()) + assert.NoError(t, integrity.AddTo(msg)) + assert.NoError(t, integrity.AddTo(msg)) + assert.NoError(t, integrity.Check(msg)) + assert.NoError(t, Fingerprint.AddTo(msg)) + assert.NoError(t, integrity.Check(msg)) msg.Raw[24] = 33 - if err := integrity.Check(msg); err == nil { - t.Fatal("mismatch expected") - } + assert.Error(t, integrity.Check(msg)) } func TestMessageIntegrity(t *testing.T) { m := new(Message) i := NewShortTermIntegrity("password") m.WriteHeader() - if err := i.AddTo(m); err != nil { - t.Error(err) - } + assert.NoError(t, i.AddTo(m)) _, err := m.Get(AttrMessageIntegrity) - if err != nil { - t.Error(err) - } + assert.NoError(t, err) } func TestMessageIntegrityBeforeFingerprint(t *testing.T) { @@ -86,9 +60,7 @@ func TestMessageIntegrityBeforeFingerprint(t *testing.T) { m.WriteHeader() Fingerprint.AddTo(m) //nolint:errcheck,gosec i := NewShortTermIntegrity("password") - if err := i.AddTo(m); err == nil { - t.Error("should error") - } + assert.Error(t, i.AddTo(m)) } func BenchmarkMessageIntegrity_AddTo(b *testing.B) { @@ -99,9 +71,7 @@ func BenchmarkMessageIntegrity_AddTo(b *testing.B) { b.SetBytes(int64(len(m.Raw))) for i := 0; i < b.N; i++ { m.WriteHeader() - if err := integrity.AddTo(m); err != nil { - b.Error(err) - } + assert.NoError(b, integrity.AddTo(m)) m.Reset() } } @@ -114,13 +84,9 @@ func BenchmarkMessageIntegrity_Check(b *testing.B) { b.ReportAllocs() m.WriteHeader() b.SetBytes(int64(len(m.Raw))) - if err := integrity.AddTo(m); err != nil { - b.Error(err) - } + assert.NoError(b, integrity.AddTo(m)) m.WriteLength() for i := 0; i < b.N; i++ { - if err := integrity.Check(m); err != nil { - b.Fatal(err) - } + assert.NoError(b, integrity.Check(m)) } } diff --git a/internal/hmac/hmac_test.go b/internal/hmac/hmac_test.go index db02f39..d51903d 100644 --- a/internal/hmac/hmac_test.go +++ b/internal/hmac/hmac_test.go @@ -11,6 +11,8 @@ import ( "fmt" "hash" "testing" + + "github.com/stretchr/testify/assert" ) type hmacTest struct { @@ -524,26 +526,17 @@ func hmacTests() []hmacTest { //nolint:maintidx func TestHMAC(t *testing.T) { for i, tt := range hmacTests() { hsh := New(tt.hash, tt.key) - if s := hsh.Size(); s != tt.size { - t.Errorf("Size: got %v, want %v", s, tt.size) - } - if b := hsh.BlockSize(); b != tt.blocksize { - t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize) - } + assert.Equal(t, tt.size, hsh.Size(), "Size mismatch") + assert.Equal(t, tt.blocksize, hsh.BlockSize(), "BlockSize mismatch") for j := 0; j < 4; j++ { //nolint:varnamelen n, err := hsh.Write(tt.in) - if n != len(tt.in) || err != nil { - t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err) - - continue - } + assert.Equal(t, len(tt.in), n, "test %d.%d: Write(%d) = %d", i, j, len(tt.in), n) + assert.NoError(t, err, "test %d.%d: Write error", i, j) // Repetitive Sum() calls should return the same value for k := 0; k < 2; k++ { sum := fmt.Sprintf("%x", hsh.Sum(nil)) - if sum != tt.out { - t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out) - } + assert.Equal(t, tt.out, sum, "test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out) } // Second iteration: make sure reset works. @@ -568,18 +561,10 @@ func TestEqual(t *testing.T) { b := []byte("test1") c := []byte("test2") - if !Equal(b, b) { - t.Error("Equal failed with equal arguments") - } - if Equal(a, b) { - t.Error("Equal accepted a prefix of the second argument") - } - if Equal(b, a) { - t.Error("Equal accepted a prefix of the first argument") - } - if Equal(b, c) { - t.Error("Equal accepted unequal slices") - } + assert.True(t, Equal(b, b), "Equal failed with equal arguments") + assert.False(t, Equal(a, b), "Equal accepted a prefix of the second argument") + assert.False(t, Equal(b, a), "Equal accepted a prefix of the first argument") + assert.False(t, Equal(b, c), "Equal accepted unequal slices") } func BenchmarkHMACSHA256_1K(b *testing.B) { diff --git a/internal/hmac/pool_test.go b/internal/hmac/pool_test.go index ac3e2b7..c4b980e 100644 --- a/internal/hmac/pool_test.go +++ b/internal/hmac/pool_test.go @@ -8,6 +8,8 @@ import ( "crypto/sha256" "fmt" "testing" + + "github.com/stretchr/testify/assert" ) func BenchmarkHMACSHA1_512(b *testing.B) { @@ -44,26 +46,17 @@ func TestHMACReset(t *testing.T) { for i, tt := range hmacTests() { hsh := New(tt.hash, tt.key) hsh.(*hmac).resetTo(tt.key) //nolint:forcetypeassert - if s := hsh.Size(); s != tt.size { - t.Errorf("Size: got %v, want %v", s, tt.size) - } - if b := hsh.BlockSize(); b != tt.blocksize { - t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize) - } + assert.Equal(t, tt.size, hsh.Size(), "Size mismatch") + assert.Equal(t, tt.blocksize, hsh.BlockSize(), "BlockSize mismatch") for j := 0; j < 2; j++ { n, err := hsh.Write(tt.in) - if n != len(tt.in) || err != nil { - t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err) - - continue - } + assert.Equal(t, len(tt.in), n, "test %d.%d: Write(%d) = %d", i, j, len(tt.in), n) + assert.NoError(t, err, "test %d.%d: Write error", i, j) // Repetitive Sum() calls should return the same value for k := 0; k < 2; k++ { sum := fmt.Sprintf("%x", hsh.Sum(nil)) - if sum != tt.out { - t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out) - } + assert.Equal(t, tt.out, sum, "test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out) } // Second iteration: make sure reset works. @@ -78,26 +71,17 @@ func TestHMACPool_SHA1(t *testing.T) { //nolint:dupl,cyclop continue } hsh := AcquireSHA1(tt.key) - if s := hsh.Size(); s != tt.size { - t.Errorf("Size: got %v, want %v", s, tt.size) - } - if b := hsh.BlockSize(); b != tt.blocksize { - t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize) - } + assert.Equal(t, tt.size, hsh.Size(), "Size mismatch") + assert.Equal(t, tt.blocksize, hsh.BlockSize(), "BlockSize mismatch") for j := 0; j < 2; j++ { n, err := hsh.Write(tt.in) - if n != len(tt.in) || err != nil { - t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err) - - continue - } + assert.Equal(t, len(tt.in), n, "test %d.%d: Write(%d) = %d", i, j, len(tt.in), n) + assert.NoError(t, err, "test %d.%d: Write error", i, j) // Repetitive Sum() calls should return the same value for k := 0; k < 2; k++ { sum := fmt.Sprintf("%x", hsh.Sum(nil)) - if sum != tt.out { - t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out) - } + assert.Equal(t, tt.out, sum, "test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out) } // Second iteration: make sure reset works. @@ -113,26 +97,17 @@ func TestHMACPool_SHA256(t *testing.T) { //nolint:dupl,cyclop continue } hsh := AcquireSHA256(tt.key) - if s := hsh.Size(); s != tt.size { - t.Errorf("Size: got %v, want %v", s, tt.size) - } - if b := hsh.BlockSize(); b != tt.blocksize { - t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize) - } + assert.Equal(t, tt.size, hsh.Size(), "Size mismatch") + assert.Equal(t, tt.blocksize, hsh.BlockSize(), "BlockSize mismatch") for j := 0; j < 2; j++ { n, err := hsh.Write(tt.in) - if n != len(tt.in) || err != nil { - t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err) - - continue - } + assert.Equal(t, len(tt.in), n, "test %d.%d: Write(%d) = %d", i, j, len(tt.in), n) + assert.NoError(t, err, "test %d.%d: Write error", i, j) // Repetitive Sum() calls should return the same value for k := 0; k < 2; k++ { sum := fmt.Sprintf("%x", hsh.Sum(nil)) - if sum != tt.out { - t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out) - } + assert.Equal(t, tt.out, sum, "test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out) } // Second iteration: make sure reset works. @@ -150,7 +125,7 @@ func TestAssertBlockSize(t *testing.T) { t.Run("Negative", func(t *testing.T) { defer func() { if r := recover(); r == nil { - t.Error("should panic") + assert.Fail(t, "should panic") } }() h := AcquireSHA256(make([]byte, 0, 1024)) diff --git a/internal/testutil/allocs.go b/internal/testutil/allocs.go index 27ac5d5..34d52b5 100644 --- a/internal/testutil/allocs.go +++ b/internal/testutil/allocs.go @@ -6,6 +6,8 @@ package testutil import ( "testing" + + "github.com/stretchr/testify/assert" ) // ShouldNotAllocate fails if f allocates. @@ -17,7 +19,5 @@ func ShouldNotAllocate(t *testing.T, f func()) { return } - if a := testing.AllocsPerRun(10, f); a > 0 { - t.Errorf("Allocations detected: %f", a) - } + assert.Zero(t, testing.AllocsPerRun(10, f)) } diff --git a/message_test.go b/message_test.go index 2ccea4e..f8862da 100644 --- a/message_test.go +++ b/message_test.go @@ -21,6 +21,8 @@ import ( "strconv" "strings" "testing" + + "github.com/stretchr/testify/assert" ) type attributeEncoder interface { @@ -50,12 +52,9 @@ func TestMessageBuffer(t *testing.T) { m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa}) m.WriteHeader() mDecoded := New() - if _, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw)); err != nil { - t.Error(err) - } - if !mDecoded.Equal(m) { - t.Error(mDecoded, "!", m) - } + _, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw)) + assert.NoError(t, err) + assert.True(t, mDecoded.Equal(m), "mDecoded != m") } func BenchmarkMessage_Write(b *testing.B) { @@ -86,9 +85,7 @@ func TestMessageType_Value(t *testing.T) { } for _, tt := range tests { b := tt.in.Value() - if b != tt.out { - t.Errorf("Value(%s) -> %s, want %s", tt.in, bUint16(b), bUint16(tt.out)) - } + assert.Equal(t, tt.out, b, "Value(%s) -> %s, want %s", tt.in, bUint16(b), bUint16(tt.out)) } } @@ -104,9 +101,7 @@ func TestMessageType_ReadValue(t *testing.T) { for _, tt := range tests { m := MessageType{} m.ReadValue(tt.in) - if m != tt.out { - t.Errorf("ReadValue(%s) -> %s, want %s", bUint16(tt.in), m, tt.out) - } + assert.Equal(t, tt.out, m, "ReadValue(%s) -> %s, want %s", bUint16(tt.in), m, tt.out) } } @@ -121,12 +116,8 @@ func TestMessageType_ReadWriteValue(t *testing.T) { m := MessageType{} v := tt.Value() m.ReadValue(v) - if m != tt { - t.Errorf("ReadValue(%s -> %s) = %s, should be %s", tt, bUint16(v), m, tt) - if m.Method != tt.Method { - t.Errorf("%s != %s", bUint16(uint16(m.Method)), bUint16(uint16(tt.Method))) - } - } + assert.Equal(t, tt, m, "ReadValue(%s -> %s) = %s, should be %s", tt, bUint16(v), m, tt) + assert.Equal(t, tt.Method, m.Method, "%s != %s", bUint16(uint16(m.Method)), bUint16(uint16(tt.Method))) } } @@ -137,32 +128,26 @@ func TestMessage_WriteTo(t *testing.T) { msg.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa}) msg.WriteHeader() buf := new(bytes.Buffer) - if _, err := msg.WriteTo(buf); err != nil { - t.Fatal(err) - } + _, err := msg.WriteTo(buf) + assert.NoError(t, err) mDecoded := New() - if _, err := mDecoded.ReadFrom(buf); err != nil { - t.Error(err) - } - if !mDecoded.Equal(msg) { - t.Error(mDecoded, "!", msg) - } + _, err = mDecoded.ReadFrom(buf) + assert.NoError(t, err) + assert.True(t, mDecoded.Equal(msg), "mDecoded != msg") } func TestMessage_Cookie(t *testing.T) { buf := make([]byte, 20) mDecoded := New() - if _, err := mDecoded.ReadFrom(bytes.NewReader(buf)); err == nil { - t.Error("should error") - } + _, err := mDecoded.ReadFrom(bytes.NewReader(buf)) + assert.Error(t, err, "should error") } func TestMessage_LengthLessHeaderSize(t *testing.T) { buf := make([]byte, 8) mDecoded := New() - if _, err := mDecoded.ReadFrom(bytes.NewReader(buf)); err == nil { - t.Error("should error") - } + _, err := mDecoded.ReadFrom(bytes.NewReader(buf)) + assert.Error(t, err, "should error") } func TestMessage_BadLength(t *testing.T) { @@ -176,9 +161,8 @@ func TestMessage_BadLength(t *testing.T) { m.WriteHeader() m.Raw[20+3] = 10 // set attr length = 10 mDecoded := New() - if _, err := mDecoded.Write(m.Raw); err == nil { - t.Error("should error") - } + _, err := mDecoded.Write(m.Raw) + assert.Error(t, err, "should error") } func TestMessage_AttrLengthLessThanHeader(t *testing.T) { @@ -197,13 +181,8 @@ func TestMessage_AttrLengthLessThanHeader(t *testing.T) { binary.BigEndian.PutUint16(m.Raw[2:4], 2) // rewrite to bad length _, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw[:20+2])) var e *DecodeErr - if errors.As(err, &e) { - if !e.IsPlace(DecodeErrPlace{"attribute", "header"}) { - t.Error(e, "bad place") - } - } else { - t.Error(err, "should be bad format") - } + assert.ErrorAs(t, err, &e) + assert.True(t, e.IsPlace(DecodeErrPlace{"attribute", "header"}), "bad place") } func TestMessage_AttrSizeLessThanLength(t *testing.T) { @@ -226,13 +205,8 @@ func TestMessage_AttrSizeLessThanLength(t *testing.T) { mDecoded := New() _, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw[:20+5])) var e *DecodeErr - if errors.As(err, &e) { - if !e.IsPlace(DecodeErrPlace{"attribute", "value"}) { - t.Error(e, "bad place") - } - } else { - t.Error(err, "should be bad format") - } + assert.ErrorAs(t, err, &e) + assert.True(t, e.IsPlace(DecodeErrPlace{"attribute", "value"}), "bad place") } type unexpectedEOFReader struct{} @@ -244,9 +218,7 @@ func (r unexpectedEOFReader) Read([]byte) (int, error) { func TestMessage_ReadFromError(t *testing.T) { mDecoded := New() _, err := mDecoded.ReadFrom(unexpectedEOFReader{}) - if !errors.Is(err, io.ErrUnexpectedEOF) { - t.Error(err, "should be", io.ErrUnexpectedEOF) - } + assert.ErrorIs(t, err, io.ErrUnexpectedEOF, "should be", io.ErrUnexpectedEOF) } func BenchmarkMessageType_Value(b *testing.B) { @@ -321,9 +293,7 @@ func BenchmarkMessage_ReadBytes(b *testing.B) { func TestMessageClass_String(t *testing.T) { defer func() { - if err := recover(); err == nil { - t.Error(err, "should be not nil") - } + assert.NotNil(t, recover()) }() v := [...]MessageClass{ @@ -333,14 +303,12 @@ func TestMessageClass_String(t *testing.T) { ClassIndication, } for _, k := range v { - if k.String() == "" { - t.Error(k, "bad stringer") - } + assert.NotEmpty(t, k.String(), "%v bad stringer", k) } // should panic p := MessageClass(0x05).String() - t.Error("should panic!", p) + assert.Fail(t, "should panic", p) } func TestAttrType_String(t *testing.T) { @@ -358,46 +326,26 @@ func TestAttrType_String(t *testing.T) { AttrFingerprint, } for _, k := range attrType { - if k.String() == "" { - t.Error(k, "bad stringer") - } - if strings.HasPrefix(k.String(), "0x") { - t.Error(k, "bad stringer") - } + assert.NotEmpty(t, k.String(), "%v bad stringer", k) + assert.False(t, strings.HasPrefix(k.String(), "0x"), "%v bad stringer", k) } vNonStandard := AttrType(0x512) - if !strings.HasPrefix(vNonStandard.String(), "0x512") { - t.Error(vNonStandard, "bad prefix") - } + assert.True(t, strings.HasPrefix(vNonStandard.String(), "0x512"), "%v bad prefix", vNonStandard) } func TestMethod_String(t *testing.T) { - if MethodBinding.String() != "Binding" { - t.Error("binding is not binding!") - } - if Method(0x616).String() != "0x616" { - t.Error("Bad stringer", Method(0x616)) - } + assert.Equal(t, "Binding", MethodBinding.String(), "binding is not binding!") + assert.Equal(t, "0x616", Method(0x616).String(), "Bad stringer") } func TestAttribute_Equal(t *testing.T) { - attr1 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}} - attr2 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}} - if !attr1.Equal(attr2) { - t.Error("should equal") - } - if attr1.Equal(RawAttribute{Type: 0x2}) { - t.Error("should not equal") - } - if attr1.Equal(RawAttribute{Length: 0x2}) { - t.Error("should not equal") - } - if attr1.Equal(RawAttribute{Length: 0x3}) { - t.Error("should not equal") - } - if attr1.Equal(RawAttribute{Length: 2, Value: []byte{0x1, 0x3}}) { - t.Error("should not equal") - } + attr1 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1} + attr2 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1} + assert.True(t, attr1.Equal(attr2)) + assert.False(t, attr1.Equal(RawAttribute{Type: 0x2})) + assert.False(t, attr1.Equal(RawAttribute{Length: 0x2})) + assert.False(t, attr1.Equal(RawAttribute{Length: 0x3})) + assert.False(t, attr1.Equal(RawAttribute{Length: 2, Value: []byte{0x1, 0x3}})) } func TestMessage_Equal(t *testing.T) { //nolint:cyclop @@ -405,39 +353,23 @@ func TestMessage_Equal(t *testing.T) { //nolint:cyclop attrs := Attributes{attr} msg1 := &Message{Attributes: attrs, Length: 4 + 2} msg2 := &Message{Attributes: attrs, Length: 4 + 2} - if !msg1.Equal(msg2) { - t.Error("should equal") - } - if msg1.Equal(&Message{Type: MessageType{Class: 128}}) { - t.Error("should not equal") - } + assert.True(t, msg1.Equal(msg2)) + assert.False(t, msg1.Equal(&Message{Type: MessageType{Class: 128}})) tID := [TransactionIDSize]byte{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, } - if msg1.Equal(&Message{TransactionID: tID}) { - t.Error("should not equal") - } - if msg1.Equal(&Message{Length: 3}) { - t.Error("should not equal") - } + assert.False(t, msg1.Equal(&Message{TransactionID: tID})) + assert.False(t, msg1.Equal(&Message{Length: 3})) tAttrs := Attributes{ {Length: 1, Value: []byte{0x1}, Type: 0x1}, } - if msg1.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) { - t.Error("should not equal") - } + assert.False(t, msg1.Equal(&Message{Attributes: tAttrs, Length: 4 + 2})) tAttrs = Attributes{ {Length: 2, Value: []byte{0x1, 0x1}, Type: 0x2}, } - if msg1.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) { - t.Error("should not equal") - } - if !(*Message)(nil).Equal(nil) { - t.Error("nil should be equal to nil") - } - if msg1.Equal(nil) { - t.Error("non-nil should not be equal to nil") - } + assert.False(t, msg1.Equal(&Message{Attributes: tAttrs, Length: 4 + 2})) + assert.True(t, (*Message)(nil).Equal(nil), "nil should be equal to nil") + assert.False(t, msg1.Equal(nil), "non-nil should not be equal to nil") t.Run("Nil attributes", func(t *testing.T) { msg1 := &Message{ Attributes: nil, @@ -447,61 +379,43 @@ func TestMessage_Equal(t *testing.T) { //nolint:cyclop Attributes: attrs, Length: 4 + 2, } - if msg1.Equal(msg2) { - t.Error("should not equal") - } - if msg2.Equal(msg1) { - t.Error("should not equal") - } + assert.False(t, msg1.Equal(msg2)) + assert.False(t, msg2.Equal(msg1)) msg2.Attributes = nil - if !msg1.Equal(msg2) { - t.Error("should equal") - } + assert.True(t, msg1.Equal(msg2)) }) t.Run("Attributes length", func(t *testing.T) { attr := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1} attr1 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1} a := &Message{Attributes: Attributes{attr}, Length: 4 + 2} b := &Message{Attributes: Attributes{attr, attr1}, Length: 4 + 2} - if a.Equal(b) { - t.Error("should not equal") - } + assert.False(t, a.Equal(b)) }) t.Run("Attributes values", func(t *testing.T) { attr := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1} attr1 := RawAttribute{Length: 2, Value: []byte{0x1, 0x1}, Type: 0x1} a := &Message{Attributes: Attributes{attr, attr}, Length: 4 + 2} b := &Message{Attributes: Attributes{attr, attr1}, Length: 4 + 2} - if a.Equal(b) { - t.Error("should not equal") - } + assert.False(t, a.Equal(b)) }) } func TestMessageGrow(t *testing.T) { m := New() m.grow(512) - if len(m.Raw) < 512 { - t.Error("Bad length", len(m.Raw)) - } + assert.GreaterOrEqual(t, len(m.Raw), 512) } func TestMessageGrowSmaller(t *testing.T) { m := New() m.grow(2) - if cap(m.Raw) < 20 { - t.Error("Bad capacity", cap(m.Raw)) - } - if len(m.Raw) < 20 { - t.Error("Bad length", len(m.Raw)) - } + assert.GreaterOrEqual(t, cap(m.Raw), 20) + assert.GreaterOrEqual(t, len(m.Raw), 20) } func TestMessage_String(t *testing.T) { m := New() - if m.String() == "" { - t.Error("bad string") - } + assert.NotEmpty(t, m.String()) } func TestIsMessage(t *testing.T) { @@ -525,9 +439,7 @@ func TestIsMessage(t *testing.T) { }, true}, // 6 } for i, v := range tt { - if got := IsMessage(v.in); got != v.out { - t.Errorf("tt[%d]: IsMessage(%+v) %v != %v", i, v.in, got, v.out) - } + assert.Equal(t, v.out, IsMessage(v.in), "tt[%d]: IsMessage(%+v)", i, v.in) } } @@ -553,18 +465,12 @@ func loadData(tb testing.TB, name string) []byte { name = filepath.Join("testdata", name) f, err := os.Open(name) //nolint:gosec - if err != nil { - tb.Fatal(err) - } + assert.NoError(tb, err) defer func() { - if errClose := f.Close(); errClose != nil { - tb.Fatal(errClose) - } + assert.NoError(tb, f.Close()) }() v, err := io.ReadAll(f) - if err != nil { - tb.Fatal(err) - } + assert.NoError(tb, err) return v } @@ -573,9 +479,7 @@ func TestExampleChrome(t *testing.T) { buf := loadData(t, "ex1_chrome.stun") m := New() _, err := m.Write(buf) - if err != nil { - t.Errorf("Failed to parse ex1_chrome: %s", err) - } + assert.NoError(t, err, "Failed to parse ex1_chrome") } func TestMessageFromBrowsers(t *testing.T) { @@ -583,9 +487,7 @@ func TestMessageFromBrowsers(t *testing.T) { reader := csv.NewReader(bytes.NewReader(loadData(t, "frombrowsers.csv"))) reader.Comment = '#' _, err := reader.Read() // skipping header - if err != nil { - t.Fatal("failed to skip header of csv: ", err) - } + assert.NoError(t, err, "failed to skip header of csv") crcTable := crc64.MakeTable(crc64.ISO) msg := New() for { @@ -593,23 +495,14 @@ func TestMessageFromBrowsers(t *testing.T) { if errors.Is(err, io.EOF) { break } - if err != nil { - t.Fatal("failed to read csv line: ", err) - } + assert.NoError(t, err, "failed to read csv line") data, err := base64.StdEncoding.DecodeString(line[1]) - if err != nil { - t.Fatal("failed to decode ", line[1], " as base64: ", err) - } + assert.NoError(t, err) b, err := strconv.ParseUint(line[2], 10, 64) - if err != nil { - t.Fatal(err) - } - if b != crc64.Checksum(data, crcTable) { - t.Error("crc64 check failed for ", line[1]) - } - if _, err = msg.Write(data); err != nil { - t.Error("failed to decode ", line[1], " as message: ", err) - } + assert.NoError(t, err) + assert.Equal(t, b, crc64.Checksum(data, crcTable), "crc64 check failed for %s", line[1]) + _, err = msg.Write(data) + assert.NoError(t, err, "failed to decode %s as message: %s", line[1], err) msg.Reset() } } @@ -619,9 +512,7 @@ func BenchmarkMessage_NewTransactionID(b *testing.B) { m := new(Message) m.WriteHeader() for i := 0; i < b.N; i++ { - if err := m.NewTransactionID(); err != nil { - b.Fatal(err) - } + assert.NoError(b, m.NewTransactionID()) } } @@ -633,12 +524,8 @@ func BenchmarkMessageFull(b *testing.B) { IP: net.IPv4(213, 1, 223, 5), } for i := 0; i < b.N; i++ { - if err := addr.AddTo(msg); err != nil { - b.Fatal(err) - } - if err := s.AddTo(msg); err != nil { - b.Fatal(err) - } + assert.NoError(b, addr.AddTo(msg)) + assert.NoError(b, s.AddTo(msg)) msg.WriteAttributes() msg.WriteHeader() Fingerprint.AddTo(msg) //nolint:errcheck,gosec @@ -655,12 +542,8 @@ func BenchmarkMessageFullHardcore(b *testing.B) { IP: net.IPv4(213, 1, 223, 5), } for i := 0; i < b.N; i++ { - if err := addr.AddTo(msg); err != nil { - b.Fatal(err) - } - if err := s.AddTo(msg); err != nil { - b.Fatal(err) - } + assert.NoError(b, addr.AddTo(msg)) + assert.NoError(b, s.AddTo(msg)) msg.WriteHeader() msg.Reset() } @@ -684,12 +567,8 @@ func BenchmarkMessage_WriteHeader(b *testing.B) { func TestMessage_Contains(t *testing.T) { m := new(Message) m.Add(AttrSoftware, []byte("value")) - if !m.Contains(AttrSoftware) { - t.Error("message should contain software") - } - if m.Contains(AttrNonce) { - t.Error("message should not contain nonce") - } + assert.True(t, m.Contains(AttrSoftware), "message should contain software") + assert.False(t, m.Contains(AttrNonce), "message should not contain nonce") } func ExampleMessage() { @@ -787,13 +666,9 @@ func TestAllocations(t *testing.T) { allocs := testing.AllocsPerRun(10, func() { m.Reset() m.WriteHeader() - if err := s.AddTo(m); err != nil { - t.Errorf("[%d] failed to add", i) - } + assert.NoError(t, s.AddTo(m), "[%d] failed to add", i) }) - if allocs > 0 { - t.Errorf("[%d] allocated %.0f", i, allocs) - } + assert.Zero(t, allocs, "[%d] allocated", i) } } @@ -818,9 +693,7 @@ func TestAllocationsGetters(t *testing.T) { Fingerprint, } msg := New() - if err := msg.Build(setters...); err != nil { - t.Error("failed to build", err) - } + assert.NoError(t, msg.Build(setters...)) getters := []Getter{ new(Nonce), new(Username), @@ -832,66 +705,48 @@ func TestAllocationsGetters(t *testing.T) { g := g i := i allocs := testing.AllocsPerRun(10, func() { - if err := g.GetFrom(msg); err != nil { - t.Errorf("[%d] failed to get", i) - } + assert.NoError(t, g.GetFrom(msg), "[%d] failed to get", i) }) - if allocs > 0 { - t.Errorf("[%d] allocated %.0f", i, allocs) - } + assert.Zero(t, allocs, "[%d] allocated", i) } } func TestMessageFullSize(t *testing.T) { msg := new(Message) - if err := msg.Build(BindingRequest, + assert.NoError(t, msg.Build(BindingRequest, NewTransactionIDSetter([TransactionIDSize]byte{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, }), NewSoftware("pion/stun"), NewLongTermIntegrity("username", "realm", "password"), Fingerprint, - ); err != nil { - t.Fatal(err) - } + )) msg.Raw = msg.Raw[:len(msg.Raw)-10] decoder := new(Message) decoder.Raw = msg.Raw[:len(msg.Raw)-10] - if err := decoder.Decode(); err == nil { - t.Error("decode on truncated buffer should error") - } + assert.Error(t, decoder.Decode(), "decode on truncated buffer should error") } func TestMessage_CloneTo(t *testing.T) { msg := new(Message) - if err := msg.Build(BindingRequest, + assert.NoError(t, msg.Build(BindingRequest, NewTransactionIDSetter([TransactionIDSize]byte{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, }), NewSoftware("pion/stun"), NewLongTermIntegrity("username", "realm", "password"), Fingerprint, - ); err != nil { - t.Fatal(err) - } + )) msg.Encode() msg2 := new(Message) - if err := msg.CloneTo(msg2); err != nil { - t.Fatal(err) - } - if !msg2.Equal(msg) { - t.Fatal("not equal") - } + assert.NoError(t, msg.CloneTo(msg2)) + assert.True(t, msg2.Equal(msg), "cloned message should equal original") // Corrupting m and checking that b is not corrupted. s, ok := msg2.Attributes.Get(AttrSoftware) - if !ok { - t.Fatal("no software attribute") - } + assert.True(t, ok) s.Value[0] = 'k' - if msg2.Equal(msg) { - t.Fatal("should not be equal") - } + assert.False(t, msg2.Equal(msg), "should not be equal") } func BenchmarkMessage_CloneTo(b *testing.B) { @@ -919,29 +774,21 @@ func BenchmarkMessage_CloneTo(b *testing.B) { func TestMessage_AddTo(t *testing.T) { msg := new(Message) - if err := msg.Build(BindingRequest, + assert.NoError(t, msg.Build(BindingRequest, NewTransactionIDSetter([TransactionIDSize]byte{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, }), Fingerprint, - ); err != nil { - t.Fatal(err) - } + )) msg.Encode() b := new(Message) - if err := msg.CloneTo(b); err != nil { - t.Fatal(err) - } + assert.NoError(t, msg.CloneTo(b)) msg.TransactionID = [TransactionIDSize]byte{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 2, } - if b.Equal(msg) { - t.Fatal("should not be equal") - } + assert.False(t, b.Equal(msg), "should not be equal") msg.AddTo(b) //nolint:errcheck,gosec - if !b.Equal(msg) { - t.Fatal("should be equal") - } + assert.True(t, b.Equal(msg), "should be equal") } func BenchmarkMessage_AddTo(b *testing.B) { @@ -966,9 +813,7 @@ func BenchmarkMessage_AddTo(b *testing.B) { func TestDecode(t *testing.T) { t.Run("Nil", func(t *testing.T) { - if err := Decode(nil, nil); !errors.Is(err, ErrDecodeToNil) { - t.Errorf("unexpected error: %v", err) - } + assert.ErrorIs(t, Decode(nil, nil), ErrDecodeToNil) }) msg := New() msg.Type = MessageType{Method: MethodBinding, Class: ClassRequest} @@ -976,22 +821,14 @@ func TestDecode(t *testing.T) { msg.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa}) msg.WriteHeader() mDecoded := New() - if err := Decode(msg.Raw, mDecoded); err != nil { - t.Errorf("unexpected error: %v", err) - } - if !mDecoded.Equal(msg) { - t.Error("decoded result is not equal to encoded message") - } + assert.NoError(t, Decode(msg.Raw, mDecoded)) + assert.True(t, mDecoded.Equal(msg), "decoded result is not equal to encoded message") t.Run("ZeroAlloc", func(t *testing.T) { allocs := testing.AllocsPerRun(10, func() { mDecoded.Reset() - if err := Decode(msg.Raw, mDecoded); err != nil { - t.Error(err) - } + assert.NoError(t, Decode(msg.Raw, mDecoded)) }) - if allocs > 0 { - t.Error("unexpected allocations") - } + assert.Zero(t, allocs, "unexpected allocations") }) } @@ -1021,25 +858,19 @@ func TestMessage_MarshalBinary(t *testing.T) { }, ) data, err := msg.MarshalBinary() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) // Reset m.Raw to check retention. for i := range msg.Raw { msg.Raw[i] = 0 } - if err := msg.UnmarshalBinary(data); err != nil { - t.Fatal(err) - } + assert.NoError(t, msg.UnmarshalBinary(data)) // Reset data to check retention. for i := range data { data[i] = 0 } - if err := msg.Decode(); err != nil { - t.Fatal(err) - } + assert.NoError(t, msg.Decode()) } func TestMessage_GobDecode(t *testing.T) { @@ -1050,23 +881,17 @@ func TestMessage_GobDecode(t *testing.T) { }, ) data, err := msg.GobEncode() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) // Reset m.Raw to check retention. for i := range msg.Raw { msg.Raw[i] = 0 } - if err := msg.GobDecode(data); err != nil { - t.Fatal(err) - } + assert.NoError(t, msg.GobDecode(data)) // Reset data to check retention. for i := range data { data[i] = 0 } - if err := msg.Decode(); err != nil { - t.Fatal(err) - } + assert.NoError(t, msg.Decode()) } diff --git a/rfc5769_test.go b/rfc5769_test.go index 30ac3fb..eed76c1 100644 --- a/rfc5769_test.go +++ b/rfc5769_test.go @@ -6,6 +6,8 @@ package stun import ( "net" "testing" + + "github.com/stretchr/testify/assert" ) func TestRFC5769(t *testing.T) { //nolint:cyclop @@ -32,19 +34,11 @@ func TestRFC5769(t *testing.T) { //nolint:cyclop "\xe5\x7a\x3b\xcf", ), } - if err := m.Decode(); err != nil { - t.Error(err) - } + assert.NoError(t, m.Decode()) software := new(Software) - if err := software.GetFrom(m); err != nil { - t.Error(err) - } - if software.String() != "STUN test client" { - t.Error("bad software: ", software) - } - if err := Fingerprint.Check(m); err != nil { - t.Error("check failed: ", err) - } + assert.NoError(t, software.GetFrom(m)) + assert.Equal(t, "STUN test client", software.String()) + assert.NoError(t, Fingerprint.Check(m)) t.Run("Long-Term credentials", func(t *testing.T) { msg := &Message{ Raw: []byte("\x00\x01\x00\x60" + @@ -64,40 +58,24 @@ func TestRFC5769(t *testing.T) { //nolint:cyclop "\x2e\x85\xc9\xa2\x8c\xa8\x96\x66", ), } - if err := msg.Decode(); err != nil { - t.Error(err) - } + assert.NoError(t, msg.Decode()) u := new(Username) - if err := u.GetFrom(msg); err != nil { - t.Error(err) - } + assert.NoError(t, u.GetFrom(msg)) expectedUsername := "\u30DE\u30C8\u30EA\u30C3\u30AF\u30B9" - if u.String() != expectedUsername { - t.Errorf("username: %q (got) != %q (exp)", u, expectedUsername) - } + assert.Equal(t, expectedUsername, u.String()) n := new(Nonce) - if err := n.GetFrom(msg); err != nil { - t.Error(err) - } - if n.String() != "f//499k954d6OL34oL9FSTvy64sA" { - t.Error("bad nonce") - } + assert.NoError(t, n.GetFrom(msg)) + assert.Equal(t, "f//499k954d6OL34oL9FSTvy64sA", n.String()) r := new(Realm) - if err := r.GetFrom(msg); err != nil { - t.Error(err) - } - if r.String() != "example.org" { //nolint:goconst - t.Error("bad realm") - } + assert.NoError(t, r.GetFrom(msg)) + assert.Equal(t, "example.org", r.String()) // checking HMAC i := NewLongTermIntegrity( "\u30DE\u30C8\u30EA\u30C3\u30AF\u30B9", "example.org", "TheMatrIX", ) - if err := i.Check(msg); err != nil { - t.Error(err) - } + assert.NoError(t, i.Check(msg)) }) }) t.Run("Response", func(t *testing.T) { @@ -117,32 +95,18 @@ func TestRFC5769(t *testing.T) { //nolint:cyclop "\xc0\x7d\x4c\x96", ), } - if err := msg.Decode(); err != nil { - t.Error(err) - } + assert.NoError(t, msg.Decode()) + software := new(Software) - if err := software.GetFrom(msg); err != nil { - t.Error(err) - } - if software.String() != "test vector" { - t.Error("bad software: ", software) - } - if err := Fingerprint.Check(msg); err != nil { - t.Error("Check failed: ", err) - } + assert.NoError(t, software.GetFrom(msg)) + assert.Equal(t, "test vector", software.String()) + assert.NoError(t, Fingerprint.Check(msg)) addr := new(XORMappedAddress) - if err := addr.GetFrom(msg); err != nil { - t.Error(err) - } - if !addr.IP.Equal(net.ParseIP("192.0.2.1")) { - t.Error("bad IP") - } - if addr.Port != 32853 { - t.Error("bad Port") - } - if err := Fingerprint.Check(msg); err != nil { - t.Error("check failed: ", err) - } + assert.NoError(t, addr.GetFrom(msg)) + expected := "192.0.2.1" + assert.Equalf(t, expected, addr.IP.String(), "Expected %s, got %s", expected, addr.IP) + assert.Equal(t, 32853, addr.Port) + assert.NoError(t, Fingerprint.Check(msg)) }) t.Run("IPv6", func(t *testing.T) { msg := &Message{ @@ -162,32 +126,20 @@ func TestRFC5769(t *testing.T) { //nolint:cyclop "\xc8\xfb\x0b\x4c", ), } - if err := msg.Decode(); err != nil { - t.Error(err) - } + assert.NoError(t, msg.Decode()) software := new(Software) - if err := software.GetFrom(msg); err != nil { - t.Error(err) - } - if software.String() != "test vector" { - t.Error("bad software: ", software) - } - if err := Fingerprint.Check(msg); err != nil { - t.Error("Check failed: ", err) - } + assert.NoError(t, software.GetFrom(msg)) + assert.Equal(t, "test vector", software.String()) + assert.NoError(t, Fingerprint.Check(msg)) addr := new(XORMappedAddress) - if err := addr.GetFrom(msg); err != nil { - t.Error(err) - } - if !addr.IP.Equal(net.ParseIP("2001:db8:1234:5678:11:2233:4455:6677")) { - t.Error("bad IP") - } - if addr.Port != 32853 { - t.Error("bad Port") - } - if err := Fingerprint.Check(msg); err != nil { - t.Error("check failed: ", err) - } + assert.NoError(t, addr.GetFrom(msg)) + expectedIP := "2001:db8:1234:5678:11:2233:4455:6677" + assert.Truef( + t, addr.IP.Equal(net.ParseIP(expectedIP)), + "Expected %s, got %s", expectedIP, addr.IP, + ) + assert.Equal(t, 32853, addr.Port) + assert.NoError(t, Fingerprint.Check(msg)) }) }) } diff --git a/stun_test.go b/stun_test.go index 0348e5a..ad48577 100644 --- a/stun_test.go +++ b/stun_test.go @@ -6,6 +6,8 @@ package stun import ( "errors" "testing" + + "github.com/stretchr/testify/assert" ) type errorReader struct{} @@ -21,9 +23,7 @@ func (errorReader) Read([]byte) (int, error) { func TestReadFullHelper(t *testing.T) { defer func() { - if r := recover(); r == nil { - t.Error("should panic") - } + assert.NotNil(t, recover(), "should panic") }() readFullOrPanic(errorReader{}, make([]byte, 1)) } @@ -36,9 +36,7 @@ func (errorWriter) Write([]byte) (int, error) { func TestWriteHelper(t *testing.T) { defer func() { - if r := recover(); r == nil { - t.Error("should panic") - } + assert.NotNil(t, recover(), "should panic") }() writeOrPanic(errorWriter{}, make([]byte, 1)) } diff --git a/stuntest/udp_server.go b/stuntest/udp_server.go index 191b74d..0196742 100644 --- a/stuntest/udp_server.go +++ b/stuntest/udp_server.go @@ -9,6 +9,8 @@ import ( "fmt" "net" "testing" + + "github.com/stretchr/testify/assert" ) var errUDPServerUnsupportedNetwork = errors.New("unsupported network") @@ -37,9 +39,7 @@ func NewUDPServer( } udpConn, err := net.ListenUDP(network, &net.UDPAddr{IP: net.ParseIP(ip), Port: 0}) - if err != nil { - t.Fatal(err) //nolint:forbidigo - } + assert.NoError(t, err) // Necessary for IPv6 address := fmt.Sprintf("%s:%d", ip, udpConn.LocalAddr().(*net.UDPAddr).Port) //nolint:forcetypeassert @@ -81,18 +81,14 @@ func NewUDPServer( select { case err := <-errCh: if err != nil { - t.Fatal(err) + assert.NoError(t, err) return } default: } - err := udpConn.Close() - if err != nil { - t.Fatal(err) - } - + assert.NoError(t, udpConn.Close()) <-errCh }, nil } diff --git a/textattrs_test.go b/textattrs_test.go index 3935368..3391474 100644 --- a/textattrs_test.go +++ b/textattrs_test.go @@ -7,9 +7,10 @@ package stun import ( - "errors" "strings" "testing" + + "github.com/stretchr/testify/assert" ) func TestSoftware_GetFrom(t *testing.T) { @@ -22,44 +23,29 @@ func TestSoftware_GetFrom(t *testing.T) { Raw: make([]byte, 0, 256), } software := new(Software) - if _, err := m2.ReadFrom(msg.reader()); err != nil { - t.Error(err) - } - if err := software.GetFrom(msg); err != nil { - t.Fatal(err) - } - if software.String() != val { - t.Errorf("Expected %q, got %q.", val, software) - } + _, err := m2.ReadFrom(msg.reader()) + assert.NoError(t, err) + assert.NoError(t, software.GetFrom(msg)) + assert.Equal(t, val, software.String()) sAttr, ok := msg.Attributes.Get(AttrSoftware) - if !ok { - t.Error("software attribute should be found") - } + assert.True(t, ok, "software attribute should be found") s := sAttr.String() - if !strings.HasPrefix(s, "SOFTWARE:") { - t.Error("bad string representation", s) - } + assert.True(t, strings.HasPrefix(s, "SOFTWARE:"), "bad string representation") } func TestSoftware_AddTo_Invalid(t *testing.T) { m := New() s := make(Software, 1024) - if err := s.AddTo(m); !IsAttrSizeOverflow(err) { - t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err) - } - if err := s.GetFrom(m); !errors.Is(err, ErrAttributeNotFound) { - t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err) - } + assert.True(t, IsAttrSizeOverflow(s.AddTo(m)), "AddTo should return *AttrOverflowErr") + assert.ErrorIs(t, s.GetFrom(m), ErrAttributeNotFound) } func TestSoftware_AddTo_Regression(t *testing.T) { // s.AddTo checked len(m.Raw) instead of len(s.Raw). m := &Message{Raw: make([]byte, 2048)} s := make(Software, 100) - if err := s.AddTo(m); err != nil { - t.Errorf("AddTo should return , got: %v", err) - } + assert.NoError(t, s.AddTo(m)) } func BenchmarkUsername_AddTo(b *testing.B) { @@ -95,28 +81,18 @@ func TestUsername(t *testing.T) { msg.WriteHeader() t.Run("Bad length", func(t *testing.T) { badU := make(Username, 600) - if err := badU.AddTo(msg); !IsAttrSizeOverflow(err) { - t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err) - } + assert.True(t, IsAttrSizeOverflow(badU.AddTo(msg)), "AddTo should return *AttrOverflowErr") }) t.Run("AddTo", func(t *testing.T) { - if err := uName.AddTo(msg); err != nil { - t.Error("errored:", err) - } + assert.NoError(t, uName.AddTo(msg)) t.Run("GetFrom", func(t *testing.T) { got := new(Username) - if err := got.GetFrom(msg); err != nil { - t.Error("errored:", err) - } - if got.String() != username { - t.Errorf("expedted: %s, got: %s", username, got) - } + assert.NoError(t, got.GetFrom(msg)) + assert.Equal(t, username, got.String()) t.Run("Not found", func(t *testing.T) { m := new(Message) u := new(Username) - if err := u.GetFrom(m); !errors.Is(err, ErrAttributeNotFound) { - t.Error("Should error") - } + assert.ErrorIs(t, u.GetFrom(m), ErrAttributeNotFound) }) }) }) @@ -124,14 +100,10 @@ func TestUsername(t *testing.T) { m := new(Message) m.WriteHeader() u := NewUsername("username") - if allocs := testing.AllocsPerRun(10, func() { - if err := u.AddTo(m); err != nil { - t.Error(err) - } + assert.Empty(t, testing.AllocsPerRun(10, func() { + assert.NoError(t, u.AddTo(m)) m.Reset() - }); allocs > 0 { - t.Errorf("got %f allocations, zero expected", allocs) - } + })) }) } @@ -145,38 +117,23 @@ func TestRealm_GetFrom(t *testing.T) { Raw: make([]byte, 0, 256), } r := new(Realm) - if err := r.GetFrom(m2); !errors.Is(err, ErrAttributeNotFound) { - t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err) - } - if _, err := m2.ReadFrom(msg.reader()); err != nil { - t.Error(err) - } - if err := r.GetFrom(msg); err != nil { - t.Fatal(err) - } - if r.String() != val { - t.Errorf("Expected %q, got %q.", val, r) - } + assert.ErrorIs(t, r.GetFrom(m2), ErrAttributeNotFound) + _, err := m2.ReadFrom(msg.reader()) + assert.NoError(t, err) + assert.NoError(t, r.GetFrom(msg)) + assert.Equal(t, val, r.String()) rAttr, ok := msg.Attributes.Get(AttrRealm) - if !ok { - t.Error("realm attribute should be found") - } + assert.True(t, ok, "realm attribute should be found") s := rAttr.String() - if !strings.HasPrefix(s, "REALM:") { - t.Error("bad string representation", s) - } + assert.True(t, strings.HasPrefix(s, "REALM:"), "bad string representation") } func TestRealm_AddTo_Invalid(t *testing.T) { m := New() r := make(Realm, 1024) - if err := r.AddTo(m); !IsAttrSizeOverflow(err) { - t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err) - } - if err := r.GetFrom(m); !errors.Is(err, ErrAttributeNotFound) { - t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err) - } + assert.True(t, IsAttrSizeOverflow(r.AddTo(m)), "AddTo should return *AttrOverflowErr") + assert.ErrorIs(t, r.GetFrom(m), ErrAttributeNotFound) } func TestNonce_GetFrom(t *testing.T) { @@ -189,50 +146,31 @@ func TestNonce_GetFrom(t *testing.T) { Raw: make([]byte, 0, 256), } var nonce Nonce - if _, err := m2.ReadFrom(msg.reader()); err != nil { - t.Error(err) - } - if err := nonce.GetFrom(msg); err != nil { - t.Fatal(err) - } - if nonce.String() != val { - t.Errorf("Expected %q, got %q.", val, nonce) - } + _, err := m2.ReadFrom(msg.reader()) + assert.NoError(t, err) + assert.NoError(t, nonce.GetFrom(msg)) + assert.Equal(t, val, nonce.String()) nAttr, ok := msg.Attributes.Get(AttrNonce) - if !ok { - t.Error("nonce attribute should be found") - } + assert.True(t, ok, "nonce attribute should be found") s := nAttr.String() - if !strings.HasPrefix(s, "NONCE:") { - t.Error("bad string representation", s) - } + assert.True(t, strings.HasPrefix(s, "NONCE:"), "bad string representation") } func TestNonce_AddTo_Invalid(t *testing.T) { m := New() n := make(Nonce, 1024) - if err := n.AddTo(m); !IsAttrSizeOverflow(err) { - t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err) - } - if err := n.GetFrom(m); !errors.Is(err, ErrAttributeNotFound) { - t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err) - } + assert.True(t, IsAttrSizeOverflow(n.AddTo(m)), "AddTo should return *AttrOverflowErr") + assert.ErrorIs(t, n.GetFrom(m), ErrAttributeNotFound) } func TestNonce_AddTo(t *testing.T) { m := New() n := Nonce("example.org") - if err := n.AddTo(m); err != nil { - t.Error(err) - } + assert.NoError(t, n.AddTo(m)) v, err := m.Get(AttrNonce) - if err != nil { - t.Error(err) - } - if string(v) != "example.org" { - t.Errorf("bad nonce %q", v) - } + assert.NoError(t, err) + assert.Equal(t, "example.org", string(v)) } func BenchmarkNonce_AddTo(b *testing.B) { diff --git a/uattrs_test.go b/uattrs_test.go index c70872c..b3000a8 100644 --- a/uattrs_test.go +++ b/uattrs_test.go @@ -5,6 +5,8 @@ package stun import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestUnknownAttributes(t *testing.T) { @@ -13,33 +15,19 @@ func TestUnknownAttributes(t *testing.T) { AttrDontFragment, AttrChannelNumber, } - if attr.String() != "DONT-FRAGMENT, CHANNEL-NUMBER" { - t.Error("bad String:", attr) - } - if (UnknownAttributes{}).String() != "" { - t.Error("bad blank string") - } - if err := attr.AddTo(msg); err != nil { - t.Error(err) - } + assert.Equal(t, "DONT-FRAGMENT, CHANNEL-NUMBER", attr.String()) + assert.Equal(t, "", (UnknownAttributes{}).String()) + assert.NoError(t, attr.AddTo(msg)) t.Run("GetFrom", func(t *testing.T) { attrs := make(UnknownAttributes, 10) - if err := attrs.GetFrom(msg); err != nil { - t.Error(err) - } + assert.NoError(t, attrs.GetFrom(msg)) for i, at := range *attr { - if at != attrs[i] { - t.Error("expected", at, "!=", attrs[i]) - } + assert.Equal(t, at, attrs[i]) } mBlank := new(Message) - if err := attrs.GetFrom(mBlank); err == nil { - t.Error("should error") - } + assert.Error(t, attrs.GetFrom(mBlank)) mBlank.Add(AttrUnknownAttributes, []byte{1, 2, 3}) - if err := attrs.GetFrom(mBlank); err == nil { - t.Error("should error") - } + assert.Error(t, attrs.GetFrom(mBlank)) }) } diff --git a/xoraddr_test.go b/xoraddr_test.go index f1017f8..73cfb3f 100644 --- a/xoraddr_test.go +++ b/xoraddr_test.go @@ -11,10 +11,11 @@ import ( "encoding/base64" "encoding/binary" "encoding/hex" - "errors" "io" "net" "testing" + + "github.com/stretchr/testify/assert" ) func BenchmarkXORMappedAddress_AddTo(b *testing.B) { @@ -31,82 +32,56 @@ func BenchmarkXORMappedAddress_AddTo(b *testing.B) { func BenchmarkXORMappedAddress_GetFrom(b *testing.B) { msg := New() transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er") - if err != nil { - b.Error(err) - } + assert.NoError(b, err) copy(msg.TransactionID[:], transactionID) addrValue, err := hex.DecodeString("00019cd5f49f38ae") - if err != nil { - b.Error(err) - } + assert.NoError(b, err) msg.Add(AttrXORMappedAddress, addrValue) addr := new(XORMappedAddress) b.ReportAllocs() for i := 0; i < b.N; i++ { - if err := addr.GetFrom(msg); err != nil { - b.Fatal(err) - } + assert.NoError(b, addr.GetFrom(msg)) } } func TestXORMappedAddress_GetFrom(t *testing.T) { m := New() transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er") - if err != nil { - t.Error(err) - } + assert.NoError(t, err) copy(m.TransactionID[:], transactionID) addrValue, err := hex.DecodeString("00019cd5f49f38ae") - if err != nil { - t.Error(err) - } + assert.NoError(t, err) m.Add(AttrXORMappedAddress, addrValue) addr := new(XORMappedAddress) - if err = addr.GetFrom(m); err != nil { - t.Error(err) - } - if !addr.IP.Equal(net.ParseIP("213.141.156.236")) { - t.Error("bad IP", addr.IP, "!=", "213.141.156.236") - } - if addr.Port != 48583 { - t.Error("bad Port", addr.Port, "!=", 48583) - } + assert.NoError(t, addr.GetFrom(m)) + assert.True(t, addr.IP.Equal(net.ParseIP("213.141.156.236"))) + assert.Equal(t, 48583, addr.Port) t.Run("UnexpectedEOF", func(t *testing.T) { m := New() // {0, 1} is correct addr family. m.Add(AttrXORMappedAddress, []byte{0, 1, 3, 4}) addr := new(XORMappedAddress) - if err = addr.GetFrom(m); !errors.Is(err, io.ErrUnexpectedEOF) { - t.Errorf("len(v) = 4 should render <%s> error, got <%s>", - io.ErrUnexpectedEOF, err, - ) - } + assert.ErrorIs(t, addr.GetFrom(m), io.ErrUnexpectedEOF, "len(v) = 4 should return io.ErrUnexpectedEOF") }) t.Run("AttrOverflowErr", func(t *testing.T) { m := New() // {0, 1} is correct addr family. m.Add(AttrXORMappedAddress, []byte{0, 1, 3, 4, 5, 6, 7, 8, 9, 1, 1, 1, 1, 1, 2, 3, 4}) addr := new(XORMappedAddress) - if err := addr.GetFrom(m); !IsAttrSizeOverflow(err) { - t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err) - } + assert.True(t, IsAttrSizeOverflow(addr.GetFrom(m)), "GetFrom should return *AttrOverflowErr") }) } func TestXORMappedAddress_GetFrom_Invalid(t *testing.T) { msg := New() transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er") - if err != nil { - t.Error(err) - } + assert.NoError(t, err) copy(msg.TransactionID[:], transactionID) expectedIP := net.ParseIP("213.141.156.236") expectedPort := 21254 addr := new(XORMappedAddress) - if err = addr.GetFrom(msg); err == nil { - t.Fatal(err, "should be nil") - } + assert.Error(t, addr.GetFrom(msg)) addr.IP = expectedIP addr.Port = expectedPort @@ -115,20 +90,15 @@ func TestXORMappedAddress_GetFrom_Invalid(t *testing.T) { mRes := New() binary.BigEndian.PutUint16(msg.Raw[20+4:20+4+2], 0x21) - if _, err = mRes.ReadFrom(bytes.NewReader(msg.Raw)); err != nil { - t.Fatal(err) - } - if err = addr.GetFrom(msg); err == nil { - t.Fatal(err, "should not be nil") - } + _, err = mRes.ReadFrom(bytes.NewReader(msg.Raw)) + assert.NoError(t, err) + assert.Error(t, addr.GetFrom(msg)) } func TestXORMappedAddress_AddTo(t *testing.T) { msg := New() transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er") - if err != nil { - t.Error(err) - } + assert.NoError(t, err) copy(msg.TransactionID[:], transactionID) expectedIP := net.ParseIP("213.141.156.236") expectedPort := 21254 @@ -136,31 +106,20 @@ func TestXORMappedAddress_AddTo(t *testing.T) { IP: net.ParseIP("213.141.156.236"), Port: expectedPort, } - if err = addr.AddTo(msg); err != nil { - t.Fatal(err) - } + assert.NoError(t, addr.AddTo(msg)) msg.WriteHeader() mRes := New() - if _, err = mRes.Write(msg.Raw); err != nil { - t.Fatal(err) - } - if err = addr.GetFrom(mRes); err != nil { - t.Fatal(err) - } - if !addr.IP.Equal(expectedIP) { - t.Errorf("%s (got) != %s (expected)", addr.IP, expectedIP) - } - if addr.Port != expectedPort { - t.Error("bad Port", addr.Port, "!=", expectedPort) - } + _, err = mRes.Write(msg.Raw) + assert.NoError(t, err) + assert.NoError(t, addr.GetFrom(mRes)) + assert.True(t, addr.IP.Equal(expectedIP), "Expected %s, got %s", expectedIP, addr.IP) + assert.Equal(t, expectedPort, addr.Port) } func TestXORMappedAddress_AddTo_IPv6(t *testing.T) { msg := New() transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er") - if err != nil { - t.Error(err) - } + assert.NoError(t, err) copy(msg.TransactionID[:], transactionID) expectedIP := net.ParseIP("fe80::dc2b:44ff:fe20:6009") expectedPort := 21254 @@ -172,19 +131,12 @@ func TestXORMappedAddress_AddTo_IPv6(t *testing.T) { msg.WriteHeader() mRes := New() - if _, err = mRes.ReadFrom(msg.reader()); err != nil { - t.Fatal(err) - } + _, err = mRes.ReadFrom(msg.reader()) + assert.NoError(t, err) gotAddr := new(XORMappedAddress) - if err = gotAddr.GetFrom(msg); err != nil { - t.Fatal(err) - } - if !gotAddr.IP.Equal(expectedIP) { - t.Error("bad IP", gotAddr.IP, "!=", expectedIP) - } - if gotAddr.Port != expectedPort { - t.Error("bad Port", gotAddr.Port, "!=", expectedPort) - } + assert.NoError(t, gotAddr.GetFrom(mRes)) + assert.True(t, gotAddr.IP.Equal(expectedIP), "Expected %s, got %s", expectedIP, gotAddr.IP) + assert.Equal(t, expectedPort, gotAddr.Port) } func TestXORMappedAddress_AddTo_Invalid(t *testing.T) { @@ -193,9 +145,7 @@ func TestXORMappedAddress_AddTo_Invalid(t *testing.T) { IP: []byte{1, 2, 3, 4, 5, 6, 7, 8}, Port: 21254, } - if err := addr.AddTo(m); !errors.Is(err, ErrBadIPLength) { - t.Errorf("AddTo should return %q, got: %v", ErrBadIPLength, err) - } + assert.ErrorIs(t, addr.AddTo(m), ErrBadIPLength) } func TestXORMappedAddress_String(t *testing.T) { @@ -219,12 +169,6 @@ func TestXORMappedAddress_String(t *testing.T) { }, } for i, c := range tt { - if got := c.in.String(); got != c.out { - t.Errorf("[%d]: XORMappesAddres.String() %s (got) != %s (expected)", - i, - got, - c.out, - ) - } + assert.Equalf(t, c.out, c.in.String(), "[%d]: XORMappesAddres.String()", i) } }