diff --git a/.golangci.yml b/.golangci.yml index 88cb4fbf9..0dfe0d604 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -19,12 +19,16 @@ linters-settings: recommendations: - errors forbidigo: + analyze-types: true forbid: - ^fmt.Print(f|ln)?$ - ^log.(Panic|Fatal|Print)(f|ln)?$ - ^os.Exit$ - ^panic$ - ^print(ln)?$ + - p: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$ + pkg: ^testing$ + msg: "use testify/assert instead" varnamelen: max-distance: 12 min-name-length: 2 @@ -123,13 +127,18 @@ linters: - wsl # Whitespace Linter - Forces you to use empty lines! issues: + max-issues-per-linter: 0 + max-same-issues: 0 exclude-use-default: false exclude-dirs-use-default: false exclude-rules: # Allow complex tests and examples, better to be self contained - - path: (examples|main\.go|_test\.go) + - path: (examples|main\.go) linters: + - gocognit - forbidigo + - path: _test\.go + linters: - gocognit # Allow forbidden identifiers in CLI commands diff --git a/bench_test.go b/bench_test.go index 8d90786cb..ac700b93b 100644 --- a/bench_test.go +++ b/bench_test.go @@ -15,6 +15,7 @@ import ( "github.com/pion/logging" "github.com/pion/transport/v3/dpipe" "github.com/pion/transport/v3/test" + "github.com/stretchr/testify/assert" ) func TestSimpleReadWrite(t *testing.T) { @@ -25,9 +26,7 @@ func TestSimpleReadWrite(t *testing.T) { ca, cb := dpipe.Pipe() certificate, err := selfsign.GenerateSelfSigned() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) gotHello := make(chan struct{}) go func() { @@ -35,41 +34,30 @@ func TestSimpleReadWrite(t *testing.T) { Certificates: []tls.Certificate{certificate}, LoggerFactory: logging.NewDefaultLoggerFactory(), }, false) - if sErr != nil { - t.Error(sErr) + assert.NoError(t, sErr) - return - } buf := make([]byte, 1024) - if _, sErr = server.Read(buf); sErr != nil { //nolint:contextcheck - t.Error(sErr) - } + _, sErr = server.Read(buf) //nolint:contextcheck + assert.NoError(t, sErr) + gotHello <- struct{}{} - if sErr = server.Close(); sErr != nil { //nolint:contextcheck - t.Error(sErr) - } + assert.NoError(t, server.Close()) //nolint:contextcheck }() client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ LoggerFactory: logging.NewDefaultLoggerFactory(), InsecureSkipVerify: true, }, false) - if err != nil { - t.Fatal(err) - } - if _, err = client.Write([]byte("hello")); err != nil { - t.Error(err) - } + assert.NoError(t, err) + _, err = client.Write([]byte("hello")) + assert.NoError(t, err) select { case <-gotHello: // OK case <-time.After(time.Second * 5): - t.Error("timeout") - } - - if err = client.Close(); err != nil { - t.Error(err) + assert.Fail(t, "timeout") } + assert.NoError(t, client.Close()) } func benchmarkConn(b *testing.B, payloadSize int64) { @@ -80,21 +68,18 @@ func benchmarkConn(b *testing.B, payloadSize int64) { ca, cb := dpipe.Pipe() certificate, err := selfsign.GenerateSelfSigned() + assert.NoError(b, err) server := make(chan *Conn) + go func() { s, sErr := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{certificate}, }, false) - if err != nil { - b.Error(sErr) + assert.NoError(b, sErr) - return - } server <- s }() - if err != nil { - b.Fatal(err) - } + hw := make([]byte, payloadSize) b.ReportAllocs() b.SetBytes(int64(len(hw))) @@ -102,21 +87,17 @@ func benchmarkConn(b *testing.B, payloadSize int64) { client, cErr := testClient( ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{InsecureSkipVerify: true}, false, ) - if cErr != nil { - b.Error(err) - } + assert.NoError(b, cErr) for { - if _, cErr = client.Write(hw); cErr != nil { //nolint:contextcheck - b.Error(err) - } + _, cErr = client.Write(hw) //nolint:contextcheck + assert.NoError(b, cErr) } }() s := <-server buf := make([]byte, 2048) for i := 0; i < b.N; i++ { - if _, err = s.Read(buf); err != nil { - b.Error(err) - } + _, err = s.Read(buf) + assert.NoError(b, err) } }) } diff --git a/certificate_test.go b/certificate_test.go index 37598a639..ab96fe40e 100644 --- a/certificate_test.go +++ b/certificate_test.go @@ -5,27 +5,21 @@ package dtls import ( "crypto/tls" - "reflect" "testing" "github.com/pion/dtls/v3/pkg/crypto/selfsign" + "github.com/stretchr/testify/assert" ) func TestGetCertificate(t *testing.T) { certificateWildcard, err := selfsign.GenerateSelfSignedWithDNS("*.test.test") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) certificateTest, err := selfsign.GenerateSelfSignedWithDNS("test.test", "www.test.test", "pop.test.test") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) certificateRandom, err := selfsign.GenerateSelfSigned() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) testCases := []struct { localCertificates []tls.Certificate @@ -92,13 +86,8 @@ func TestGetCertificate(t *testing.T) { localGetCertificate: test.getCertificate, } cert, err := cfg.getCertificate(&ClientHelloInfo{ServerName: test.serverName}) - if err != nil { - t.Fatal(err) - } - - if !reflect.DeepEqual(cert.Leaf, test.expectedCertificate.Leaf) { - t.Fatalf("Certificate does not match: expected(%v) actual(%v)", test.expectedCertificate.Leaf, cert.Leaf) - } + assert.NoError(t, err) + assert.Equal(t, test.expectedCertificate.Leaf, cert.Leaf, "Certificate Leaf should match expected") }) } } diff --git a/cipher_suite_go114_test.go b/cipher_suite_go114_test.go index e93b760c5..86b3e56e7 100644 --- a/cipher_suite_go114_test.go +++ b/cipher_suite_go114_test.go @@ -8,48 +8,29 @@ package dtls import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestInsecureCipherSuites(t *testing.T) { - r := InsecureCipherSuites() - - if len(r) != 0 { - t.Fatalf("Expected no insecure ciphersuites, got %d", len(r)) - } + assert.Empty(t, InsecureCipherSuites(), "Expected no insecure ciphersuites") } func TestCipherSuites(t *testing.T) { ours := allCipherSuites() theirs := CipherSuites() - - if len(ours) != len(theirs) { - t.Fatalf("Expected %d CipherSuites, got %d", len(ours), len(theirs)) - } + assert.Equal(t, len(ours), len(theirs)) for i, s := range ours { i := i s := s t.Run(s.String(), func(t *testing.T) { cipher := theirs[i] - if cipher.ID != uint16(s.ID()) { - t.Fatalf("Expected ID: 0x%04X, got 0x%04X", s.ID(), cipher.ID) - } - - if cipher.Name != s.String() { - t.Fatalf("Expected Name: %s, got %s", s.String(), cipher.Name) - } - - if len(cipher.SupportedVersions) != 1 { - t.Fatalf("Expected %d SupportedVersion, got %d", 1, len(cipher.SupportedVersions)) - } - - if cipher.SupportedVersions[0] != VersionDTLS12 { - t.Fatalf("Expected SupportedVersions 0x%04X, got 0x%04X", VersionDTLS12, cipher.SupportedVersions[0]) - } - - if cipher.Insecure { - t.Fatalf("Expected Insecure %t, got %t", false, cipher.Insecure) - } + assert.Equal(t, cipher.ID, uint16(s.ID())) + assert.Equal(t, cipher.Name, s.String()) + assert.Equal(t, 1, len(cipher.SupportedVersions), "Expected SupportedVersion to be 1") + assert.Equal(t, uint16(VersionDTLS12), cipher.SupportedVersions[0], "Expected SupportedVersion to match") + assert.False(t, cipher.Insecure, "Expected Insecure") }) } } diff --git a/cipher_suite_test.go b/cipher_suite_test.go index c4fd4840a..7ec02f41a 100644 --- a/cipher_suite_test.go +++ b/cipher_suite_test.go @@ -12,6 +12,7 @@ import ( dtlsnet "github.com/pion/dtls/v3/pkg/net" "github.com/pion/transport/v3/dpipe" "github.com/pion/transport/v3/test" + "github.com/stretchr/testify/assert" ) func TestCipherSuiteName(t *testing.T) { @@ -24,18 +25,12 @@ func TestCipherSuiteName(t *testing.T) { } for _, testCase := range testCases { - res := CipherSuiteName(testCase.suite) - if res != testCase.expected { - t.Fatalf("Expected: %s, got %s", testCase.expected, res) - } + assert.Equal(t, testCase.expected, CipherSuiteName(testCase.suite)) } } func TestAllCipherSuites(t *testing.T) { - actual := len(allCipherSuites()) - if actual == 0 { - t.Fatal() - } + assert.NotEmpty(t, allCipherSuites()) } // CustomCipher that is just used to assert Custom IDs work. @@ -84,18 +79,10 @@ func TestCustomCipherSuite(t *testing.T) { }, true) clientResult := <-resultCh - - if err != nil { - t.Error(err) - } else { - _ = server.Close() - } - - if clientResult.err != nil { - t.Error(clientResult.err) - } else { - _ = clientResult.c.Close() - } + assert.NoError(t, err) + assert.NoError(t, server.Close()) + assert.Nil(t, clientResult.err) + assert.NoError(t, clientResult.c.Close()) } t.Run("Custom ID", func(*testing.T) { diff --git a/config_test.go b/config_test.go index b01de1442..126c424e7 100644 --- a/config_test.go +++ b/config_test.go @@ -12,31 +12,32 @@ import ( "testing" "github.com/pion/dtls/v3/pkg/crypto/selfsign" + "github.com/stretchr/testify/assert" ) -func TestValidateConfig(t *testing.T) { //nolint:cyclop +func TestValidateConfig(t *testing.T) { cert, err := selfsign.GenerateSelfSigned() if err != nil { - t.Fatalf("TestValidateConfig: Config validation error(%v), self signed certificate not generated", err) + assert.NoError(t, err, "TestValidateConfig: Config validation error, self signed certificate not generated") return } dsaPrivateKey := &dsa.PrivateKey{} err = dsa.GenerateParameters(&dsaPrivateKey.Parameters, rand.Reader, dsa.L1024N160) if err != nil { - t.Fatalf("TestValidateConfig: Config validation error(%v), DSA parameters not generated", err) + assert.NoError(t, err, "TestValidateConfig: Config validation error, DSA parameters not generated") return } err = dsa.GenerateKey(dsaPrivateKey, rand.Reader) if err != nil { - t.Fatalf("TestValidateConfig: Config validation error(%v), DSA private key not generated", err) + assert.NoError(t, err, "TestValidateConfig: Config validation error, DSA private key not generated") return } rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { - t.Fatalf("TestValidateConfig: Config validation error(%v), RSA private key not generated", err) + assert.NoError(t, err, "TestValidateConfig: Config validation error, RSA private key not generated") return } @@ -133,11 +134,9 @@ func TestValidateConfig(t *testing.T) { //nolint:cyclop err := validateConfig(testCase.config) if testCase.expErr != nil || testCase.wantAnyErr { if testCase.expErr != nil && !errors.Is(err, testCase.expErr) { - t.Fatalf("TestValidateConfig: Config validation error exp(%v) failed(%v)", testCase.expErr, err) - } - if err == nil { - t.Fatalf("TestValidateConfig: Config validation expected an error") + assert.ErrorIs(t, err, testCase.expErr, "TestValidateConfig") } + assert.Error(t, err, "TestValidateConfig: Config validation expected an error") } }) } diff --git a/conn_go_test.go b/conn_go_test.go index b22e7c71e..8fa1004aa 100644 --- a/conn_go_test.go +++ b/conn_go_test.go @@ -7,7 +7,6 @@ package dtls import ( - "bytes" "context" "crypto/tls" "errors" @@ -19,6 +18,7 @@ import ( dtlsnet "github.com/pion/dtls/v3/pkg/net" "github.com/pion/transport/v3/dpipe" "github.com/pion/transport/v3/test" + "github.com/stretchr/testify/assert" ) func TestContextConfig(t *testing.T) { //nolint:cyclop @@ -30,27 +30,20 @@ func TestContextConfig(t *testing.T) { //nolint:cyclop defer report() addrListen, err := net.ResolveUDPAddr("udp", "localhost:0") - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + assert.NoError(t, err) // Dummy listener listen, err := net.ListenUDP("udp", addrListen) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + assert.NoError(t, err) defer func() { _ = listen.Close() }() addr, ok := listen.LocalAddr().(*net.UDPAddr) - if !ok { - t.Fatal("Failed to cast net.UDPAddr") - } + assert.True(t, ok) cert, err := selfsign.GenerateSelfSigned() - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + assert.NoError(t, err) + config := &Config{ Certificates: []tls.Certificate{cert}, } @@ -127,7 +120,7 @@ func TestContextConfig(t *testing.T) { //nolint:cyclop defer cancel() var netError net.Error if !errors.As(err, &netError) || !netError.Temporary() { //nolint:staticcheck - t.Errorf("Client error exp(Temporary network error) failed(%v)", err) + assert.Fail(t, "Dial failed with unexpected error", "err: %v", err) close(done) return @@ -156,9 +149,7 @@ func TestContextConfig(t *testing.T) { //nolint:cyclop } } }() - if !bytes.Equal(dial.order, order) { - t.Errorf("Invalid cancel timing, expected: %v, got: %v", dial.order, order) - } + assert.Equal(t, dial.order, order, "Invalid cancel timing") }) } } diff --git a/conn_test.go b/conn_test.go index ed821a57c..28bece05c 100644 --- a/conn_test.go +++ b/conn_test.go @@ -39,6 +39,7 @@ import ( "github.com/pion/logging" "github.com/pion/transport/v3/dpipe" "github.com/pion/transport/v3/test" + "github.com/stretchr/testify/assert" ) var ( @@ -66,19 +67,11 @@ func stressDuplex(t *testing.T) { t.Helper() ca, cb, err := pipeMemory() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) defer func() { - err = ca.Close() - if err != nil { - t.Fatal(err) - } - err = cb.Close() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, ca.Close()) + assert.NoError(t, cb.Close()) }() opt := test.Options{ @@ -86,10 +79,7 @@ func stressDuplex(t *testing.T) { MsgCount: 100, } - err = test.StressDuplex(ca, cb, opt) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, test.StressDuplex(ca, cb, opt)) } func TestRoutineLeakOnClose(t *testing.T) { @@ -102,24 +92,17 @@ func TestRoutineLeakOnClose(t *testing.T) { defer report() ca, cb, err := pipeMemory() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) - if _, err := ca.Write(make([]byte, 100)); err != nil { - t.Fatal(err) - } - if err := cb.Close(); err != nil { - t.Fatal(err) - } - if err := ca.Close(); err != nil { - t.Fatal(err) - } + _, err = ca.Write(make([]byte, 100)) + assert.NoError(t, err) + assert.NoError(t, cb.Close()) + assert.NoError(t, ca.Close()) // Packet is sent, but not read. // inboundLoop routine should not be leaked. } -func TestReadWriteDeadline(t *testing.T) { //nolint:cyclop +func TestReadWriteDeadline(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(5 * time.Second) defer lim.Stop() @@ -131,52 +114,26 @@ func TestReadWriteDeadline(t *testing.T) { //nolint:cyclop var netErr net.Error ca, cb, err := pipeMemory() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + assert.NoError(t, ca.SetDeadline(time.Unix(0, 1))) - if err := ca.SetDeadline(time.Unix(0, 1)); err != nil { - t.Fatal(err) - } _, werr := ca.Write(make([]byte, 100)) - if errors.As(werr, &netErr) { - if !netErr.Timeout() { - t.Error("Deadline exceeded Write must return Timeout error") - } - if !netErr.Temporary() { //nolint:staticcheck - t.Error("Deadline exceeded Write must return Temporary error") - } - } else { - t.Error("Write must return net.Error error") - } - _, rerr := ca.Read(make([]byte, 100)) - if errors.As(rerr, &netErr) { - if !netErr.Timeout() { - t.Error("Deadline exceeded Read must return Timeout error") - } - if !netErr.Temporary() { //nolint:staticcheck - t.Error("Deadline exceeded Read must return Temporary error") - } - } else { - t.Error("Read must return net.Error error") - } - if err := ca.SetDeadline(time.Time{}); err != nil { - t.Error(err) - } + assert.ErrorAs(t, werr, &netErr, "Write must return net.Error") + assert.True(t, netErr.Timeout(), "Deadline exceeded Write must return Timeout") + assert.True(t, netErr.Temporary(), "Deadline exceeded Write must return Temporary") //nolint:staticcheck - if err := ca.Close(); err != nil { - t.Error(err) - } - if err := cb.Close(); err != nil { - t.Error(err) - } - - if _, err := ca.Write(make([]byte, 100)); !errors.Is(err, ErrConnClosed) { - t.Errorf("Write must return %v after close, got %v", ErrConnClosed, err) - } - if _, err := ca.Read(make([]byte, 100)); !errors.Is(err, io.EOF) { - t.Errorf("Read must return %v after close, got %v", io.EOF, err) - } + _, rerr := ca.Read(make([]byte, 100)) + assert.ErrorAs(t, rerr, &netErr, "Read must return net.Error") + assert.True(t, netErr.Timeout(), "Deadline exceeded Read must return Timeout") + assert.True(t, netErr.Temporary(), "Deadline exceeded Read must return Temporary") //nolint:staticcheck + assert.NoError(t, ca.SetDeadline(time.Time{})) + assert.NoError(t, ca.Close()) + assert.NoError(t, cb.Close()) + + _, err = ca.Write(make([]byte, 100)) + assert.ErrorIs(t, err, ErrConnClosed) + _, err = ca.Read(make([]byte, 100)) + assert.ErrorIs(t, err, io.EOF) } func TestSequenceNumberOverflow(t *testing.T) { @@ -190,30 +147,20 @@ func TestSequenceNumberOverflow(t *testing.T) { t.Run("ApplicationData", func(t *testing.T) { ca, cb, err := pipeMemory() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) atomic.StoreUint64(&ca.state.localSequenceNumber[1], recordlayer.MaxSequenceNumber) - if _, werr := ca.Write(make([]byte, 100)); werr != nil { - t.Errorf("Write must send message with maximum sequence number, but errord: %v", werr) - } - if _, werr := ca.Write(make([]byte, 100)); !errors.Is(werr, errSequenceNumberOverflow) { - t.Errorf("Write must abandonsend message with maximum sequence number, but errord: %v", werr) - } + _, werr := ca.Write(make([]byte, 100)) + assert.NoError(t, werr, "Write must send message with maximum sequence number") + _, werr = ca.Write(make([]byte, 100)) + assert.ErrorIs(t, werr, errSequenceNumberOverflow, "Write must abandonsend message with maximum sequence number") - if err := ca.Close(); err != nil { - t.Error(err) - } - if err := cb.Close(); err != nil { - t.Error(err) - } + assert.NoError(t, ca.Close()) + assert.NoError(t, cb.Close()) }) t.Run("Handshake", func(t *testing.T) { ca, cb, err := pipeMemory() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -221,7 +168,7 @@ func TestSequenceNumberOverflow(t *testing.T) { atomic.StoreUint64(&ca.state.localSequenceNumber[0], recordlayer.MaxSequenceNumber+1) // Try to send handshake packet. - if werr := ca.writePackets(ctx, []*packet{ + werr := ca.writePackets(ctx, []*packet{ { record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ @@ -237,16 +184,11 @@ func TestSequenceNumberOverflow(t *testing.T) { }, }, }, - }); !errors.Is(werr, errSequenceNumberOverflow) { - t.Errorf("Connection must fail on handshake packet reaches maximum sequence number") - } - - if err := ca.Close(); err != nil { - t.Error(err) - } - if err := cb.Close(); err != nil { - t.Error(err) - } + }) + assert.ErrorIs(t, werr, errSequenceNumberOverflow, + "Connection must fail when handshake packet reaches maximum sequence num") + assert.NoError(t, ca.Close()) + assert.NoError(t, cb.Close()) }) } @@ -421,14 +363,8 @@ func TestHandshakeWithAlert(t *testing.T) { }() _, errServer := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), testCase.configServer, true) - if !errors.Is(errServer, testCase.errServer) { - t.Fatalf("Server error exp(%v) failed(%v)", testCase.errServer, errServer) - } - - errClient := <-clientErr - if !errors.Is(errClient, testCase.errClient) { - t.Fatalf("Client error exp(%v) failed(%v)", testCase.errClient, errClient) - } + assert.ErrorIs(t, errServer, testCase.errServer) + assert.ErrorIs(t, <-clientErr, testCase.errClient) }) } } @@ -457,9 +393,8 @@ func TestHandshakeWithInvalidRecord(t *testing.T) { // Send invalid record after first message caWithInvalidRecord.onWrite = func([]byte) { if msgSeq.Add(1) == 2 { - if _, err := ca.Write([]byte{0x01, 0x02}); err != nil { - t.Fatal(err) - } + _, err := ca.Write([]byte{0x01, 0x02}) + assert.NoError(t, err) } } go func() { @@ -481,28 +416,19 @@ func TestHandshakeWithInvalidRecord(t *testing.T) { defer func() { if server != nil { - if err := server.Close(); err != nil { - t.Fatal(err) - } + assert.NoError(t, server.Close()) } if errClient.c != nil { - if err := errClient.c.Close(); err != nil { - t.Fatal(err) - } + assert.NoError(t, errClient.c.Close()) } }() - if errServer != nil { - t.Fatalf("Server failed(%v)", errServer) - } - - if errClient.err != nil { - t.Fatalf("Client failed(%v)", errClient.err) - } + assert.NoError(t, errServer) + assert.NoError(t, errClient.err) } -func TestExportKeyingMaterial(t *testing.T) { //nolint:cyclop +func TestExportKeyingMaterial(t *testing.T) { // Check for leaking routines report := test.CheckRoutines(t) defer report() @@ -525,60 +451,43 @@ func TestExportKeyingMaterial(t *testing.T) { //nolint:cyclop conn.setRemoteEpoch(0) state, ok := conn.ConnectionState() - if !ok { - t.Fatal("ConnectionState failed") - } + assert.True(t, ok) + _, err := state.ExportKeyingMaterial(exportLabel, nil, 0) - if !errors.Is(err, errHandshakeInProgress) { - t.Errorf("ExportKeyingMaterial when epoch == 0: expected '%s' actual '%s'", errHandshakeInProgress, err) - } + assert.ErrorIs(t, err, errHandshakeInProgress, "ExportKeyingMaterial when epoch == 0 error mismatch") conn.setLocalEpoch(1) state, ok = conn.ConnectionState() - if !ok { - t.Fatal("ConnectionState failed") - } + assert.True(t, ok) + _, err = state.ExportKeyingMaterial(exportLabel, []byte{0x00}, 0) - if !errors.Is(err, errContextUnsupported) { - t.Errorf("ExportKeyingMaterial with context: expected '%s' actual '%s'", errContextUnsupported, err) - } + assert.ErrorIs(t, err, errContextUnsupported, "ExportKeyingMaterial with context mismatch") for k := range invalidKeyingLabels() { state, ok = conn.ConnectionState() - if !ok { - t.Fatal("ConnectionState failed") - } + assert.True(t, ok) + _, err = state.ExportKeyingMaterial(k, nil, 0) - if !errors.Is(err, errReservedExportKeyingMaterial) { - t.Errorf("ExportKeyingMaterial reserved label: expected '%s' actual '%s'", errReservedExportKeyingMaterial, err) - } + assert.ErrorIs(t, err, errReservedExportKeyingMaterial, "ExportKeyingMaterial reserved label mismatch") } state, ok = conn.ConnectionState() - if !ok { - t.Fatal("ConnectionState failed") - } + assert.True(t, ok) + keyingMaterial, err := state.ExportKeyingMaterial(exportLabel, nil, 10) - if err != nil { - t.Errorf("ExportKeyingMaterial as server: unexpected error '%s'", err) - } else if !bytes.Equal(keyingMaterial, expectedServerKey) { - t.Errorf("ExportKeyingMaterial client export: expected (% 02x) actual (% 02x)", expectedServerKey, keyingMaterial) - } + assert.NoError(t, err, "ExportingKeyingMaterial as server error") + assert.Equal(t, expectedServerKey, keyingMaterial, "ExportKeyingMaterial client export mismatch") conn.state.isClient = true state, ok = conn.ConnectionState() - if !ok { - t.Fatal("ConnectionState failed") - } + assert.True(t, ok) + keyingMaterial, err = state.ExportKeyingMaterial(exportLabel, nil, 10) - if err != nil { - t.Errorf("ExportKeyingMaterial as server: unexpected error '%s'", err) - } else if !bytes.Equal(keyingMaterial, expectedClientKey) { - t.Errorf("ExportKeyingMaterial client export: expected (% 02x) actual (% 02x)", expectedClientKey, keyingMaterial) - } + assert.NoError(t, err) + assert.Equal(t, expectedClientKey, keyingMaterial, "ExportKeyingMaterial client report mismatch") } -func TestPSK(t *testing.T) { //nolint:cyclop +func TestPSK(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -688,7 +597,7 @@ func TestPSK(t *testing.T) { //nolint:cyclop config := &Config{ PSK: func(hint []byte) ([]byte, error) { - fmt.Println(hint) + t.Log(hint) if !bytes.Equal(test.ClientIdentity, hint) { return nil, fmt.Errorf("%w: expected(% 02x) actual(% 02x)", errTestPSKInvalidIdentity, test.ClientIdentity, hint) } @@ -703,40 +612,29 @@ func TestPSK(t *testing.T) { //nolint:cyclop server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, false) if test.WantFail { res := <-clientRes - if err == nil || !strings.Contains(err.Error(), test.ExpectedServerErr) { - t.Fatalf("TestPSK: Server expected(%v) actual(%v)", test.ExpectedServerErr, err) - } - if res.err == nil || !strings.Contains(res.err.Error(), test.ExpectedClientErr) { - t.Fatalf("TestPSK: Client expected(%v) actual(%v)", test.ExpectedClientErr, res.err) - } + assert.Error(t, err) + assert.True(t, strings.Contains(err.Error(), test.ExpectedServerErr), "TestPSK: Server expected error mismatch") + assert.Error(t, res.err, "TestPSK: Client expected error mismatch") + assert.True(t, strings.Contains(res.err.Error(), test.ExpectedClientErr), + "TestPSK: Client expeected error mismatch") return } - if err != nil { - t.Fatalf("TestPSK: Server failed(%v)", err) - } + assert.NoError(t, err) state, ok := server.ConnectionState() - if !ok { - t.Fatalf("TestPSK: Server ConnectionState failed") - } + assert.True(t, ok, "TestPSK: Server ConnectionState failed") + actualPSKIdentityHint := state.IdentityHint - if !bytes.Equal(actualPSKIdentityHint, test.ClientIdentity) { - t.Errorf( - "TestPSK: Server ClientPSKIdentity Mismatch '%s': expected(%v) actual(%v)", - test.Name, test.ClientIdentity, actualPSKIdentityHint, - ) - } + assert.Equal(t, test.ClientIdentity, actualPSKIdentityHint, "TestPSK: Server ClientPSKIdentity Mismatch") defer func() { _ = server.Close() }() res := <-clientRes - if res.err != nil { - t.Fatal(res.err) - } - _ = res.c.Close() + assert.NoError(t, res.err) + assert.NoError(t, res.c.Close()) }) } } @@ -779,19 +677,13 @@ func TestPSKHintFail(t *testing.T) { CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, } - if _, err := testServer( - ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, false, - ); !errors.Is(err, serverAlertError) { - t.Fatalf("TestPSK: Server error exp(%v) failed(%v)", serverAlertError, err) - } - - if err := <-clientErr; !errors.Is(err, pskRejected) { - t.Fatalf("TestPSK: Client error exp(%v) failed(%v)", pskRejected, err) - } + _, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, false) + assert.ErrorIs(t, err, serverAlertError, "TestPSK: Server should fail with alert error") + assert.ErrorIs(t, <-clientErr, pskRejected, "TestPSK: Client should fail with pskRejected error") } // Assert that ServerKeyExchange is only sent if Identity is set on server side. -func TestPSKServerKeyExchange(t *testing.T) { //nolint:cyclop +func TestPSKServerKeyExchange(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -824,9 +716,7 @@ func TestPSKServerKeyExchange(t *testing.T) { //nolint:cyclop cbAnalyzer := &connWithCallback{Conn: cb} cbAnalyzer.onWrite = func(in []byte) { messages, err := recordlayer.UnpackDatagram(in) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) for i := range messages { h := &handshake.Handshake{} @@ -864,26 +754,11 @@ func TestPSKServerKeyExchange(t *testing.T) { //nolint:cyclop config.PSKIdentityHint = []byte{0xAB, 0xC1, 0x23} } - if server, err := testServer( - ctx, dtlsnet.PacketConnFromConn(cbAnalyzer), cbAnalyzer.RemoteAddr(), config, false, - ); err != nil { - t.Fatalf("TestPSK: Server error %v", err) - } else { - if err = server.Close(); err != nil { - t.Fatal(err) - } - } - - if err := <-clientErr; err != nil { - t.Fatalf("TestPSK: Client error %v", err) - } - - if gotServerKeyExchange != test.SetIdentity { - t.Fatalf( - "Mismatch between setting Identity and getting a ServerKeyExchange exp(%t) actual(%t)", - test.SetIdentity, gotServerKeyExchange, - ) - } + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cbAnalyzer), cbAnalyzer.RemoteAddr(), config, false) + assert.NoError(t, err) + assert.NoError(t, server.Close()) + assert.NoError(t, <-clientErr, "TestPSK: Client erro") + assert.Equal(t, test.SetIdentity, gotServerKeyExchange) }) } } @@ -916,12 +791,11 @@ func TestClientTimeout(t *testing.T) { // no server! err := <-clientErr var netErr net.Error - if !errors.As(err, &netErr) || !netErr.Timeout() { - t.Fatalf("Client error exp(Temporary network error) failed(%v)", err) - } + assert.ErrorAs(t, err, &netErr, "Client error exp(Temporary network error) failed") + assert.True(t, netErr.Timeout(), "Client error exp(Timeout) failed") } -func TestSRTPConfiguration(t *testing.T) { //nolint:cyclop +func TestSRTPConfiguration(t *testing.T) { // Check for leaking routines report := test.CheckRoutines(t) defer report() @@ -1007,12 +881,8 @@ func TestSRTPConfiguration(t *testing.T) { //nolint:cyclop server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ SRTPProtectionProfiles: test.ServerSRTP, SRTPMasterKeyIdentifier: test.ClientSRTPMasterKeyIdentifier, }, true) - if !errors.Is(err, test.WantServerError) { - t.Errorf( - "TestSRTPConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", - test.Name, test.WantServerError, err, - ) - } + assert.ErrorIs(t, err, test.WantServerError, "TestSRTPConfiguration: Server Error Mismatch") + if err == nil { defer func() { _ = server.Close() @@ -1025,47 +895,26 @@ func TestSRTPConfiguration(t *testing.T) { //nolint:cyclop _ = res.c.Close() }() } - if !errors.Is(res.err, test.WantClientError) { - t.Fatalf( - "TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", - test.Name, test.WantClientError, res.err, - ) - } + assert.ErrorIsf(t, res.err, test.WantClientError, "TestSRTPConfiguration: Client Error Mismatch '%s'", test.Name) if res.c == nil { return } actualClientSRTP, _ := res.c.SelectedSRTPProtectionProfile() - if actualClientSRTP != test.ExpectedProfile { - t.Errorf( - "TestSRTPConfiguration: Client SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", - test.Name, test.ExpectedProfile, actualClientSRTP, - ) - } + assert.Equalf(t, test.ExpectedProfile, actualClientSRTP, + "TestSRTPConfiguration: Client SRTPProtectionProfile Mismatch '%s'", test.Name) actualServerSRTP, _ := server.SelectedSRTPProtectionProfile() - if actualServerSRTP != test.ExpectedProfile { - t.Errorf( - "TestSRTPConfiguration: Server SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", - test.Name, test.ExpectedProfile, actualServerSRTP, - ) - } + assert.Equalf(t, test.ExpectedProfile, actualServerSRTP, + "TestSRTPConfiguration: Server SRTPProtectionProfile Mismatch '%s'", test.Name) actualServerMKI, _ := server.RemoteSRTPMasterKeyIdentifier() - if !bytes.Equal(actualServerMKI, test.ServerSRTPMasterKeyIdentifier) { - t.Errorf( - "TestSRTPConfiguration: Server SRTPMKI Mismatch '%s': expected(%v) actual(%v)", - test.Name, test.ServerSRTPMasterKeyIdentifier, actualServerMKI, - ) - } + assert.Truef(t, bytes.Equal(test.ServerSRTPMasterKeyIdentifier, actualServerMKI), + "TestSRTPConfiguration: Server SRTPMKI Mismatch '%s'", test.Name) actualClientMKI, _ := res.c.RemoteSRTPMasterKeyIdentifier() - if !bytes.Equal(actualClientMKI, test.ClientSRTPMasterKeyIdentifier) { - t.Errorf( - "TestSRTPConfiguration: Client SRTPMKI Mismatch '%s': expected(%v) actual(%v)", - test.Name, test.ClientSRTPMasterKeyIdentifier, actualClientMKI, - ) - } + assert.Truef(t, bytes.Equal(test.ClientSRTPMasterKeyIdentifier, actualClientMKI), + "TestSRTPConfiguration: Client SRTPMKI Mismatch '%s'", test.Name) } } @@ -1075,24 +924,20 @@ func TestClientCertificate(t *testing.T) { //nolint:gocyclo,cyclop,maintidx defer report() srvCert, err := selfsign.GenerateSelfSigned() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + srvCAPool := x509.NewCertPool() srvCertificate, err := x509.ParseCertificate(srvCert.Certificate[0]) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + srvCAPool.AddCert(srvCertificate) cert, err := selfsign.GenerateSelfSigned() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + certificate, err := x509.ParseCertificate(cert.Certificate[0]) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + caPool := x509.NewCertPool() caPool.AddCert(certificate) @@ -1269,31 +1114,21 @@ func TestClientCertificate(t *testing.T) { //nolint:gocyclo,cyclop,maintidx }() if tt.wantErr { - if err != nil || hserr != nil { - // Error expected, test succeeded - return - } - t.Error("Error expected") - } - if err != nil { - t.Errorf("Server failed(%v)", err) - } + assert.True(t, err != nil || hserr != nil, "Error expected") - if res.err != nil { - t.Errorf("Client failed(%v)", res.err) + return // Error expected, test succeeded } + assert.NoError(t, err) + assert.NoError(t, res.err) state, ok := server.ConnectionState() - if !ok { - t.Error("Server connection state not available") - } + assert.True(t, ok, "Server connection state not available") + actualClientCert := state.PeerCertificates //nolint:nestif if tt.serverCfg.ClientAuth == RequireAnyClientCert || tt.serverCfg.ClientAuth == RequireAndVerifyClientCert { - if actualClientCert == nil { - t.Errorf("Client did not provide a certificate") - } + assert.NotNil(t, actualClientCert, "Client did not provide a certificate") var cfgCert [][]byte if len(tt.clientCfg.Certificates) > 0 { @@ -1301,43 +1136,35 @@ func TestClientCertificate(t *testing.T) { //nolint:gocyclo,cyclop,maintidx } if tt.clientCfg.GetClientCertificate != nil { crt, err := tt.clientCfg.GetClientCertificate(&CertificateRequestInfo{}) - if err != nil { - t.Errorf("Server configuration did not provide a certificate") - } + assert.NoError(t, err, "Server configuration did not provide a certificate") + cfgCert = crt.Certificate } - if len(cfgCert) == 0 || !bytes.Equal(cfgCert[0], actualClientCert[0]) { - t.Errorf("Client certificate was not communicated correctly") - } + + assert.NotEmpty(t, cfgCert, "Client certificate was not communicated correctly") + assert.Equal(t, actualClientCert[0], cfgCert[0], "Client certificate was not communicated correctly") } if tt.serverCfg.ClientAuth == NoClientCert { - if actualClientCert != nil { - t.Errorf("Client certificate wasn't expected") - } + assert.Nil(t, actualClientCert, "Client certificate wasn't expected") } clientState, ok := res.c.ConnectionState() - if !ok { - t.Error("Client connection state not available") - } + assert.True(t, ok, "Client connection state not available") + actualServerCert := clientState.PeerCertificates - if actualServerCert == nil { - t.Errorf("Server did not provide a certificate") - } + assert.NotNil(t, actualServerCert, "server did not provide a certificate") + var cfgCert [][]byte if len(tt.serverCfg.Certificates) > 0 { cfgCert = tt.serverCfg.Certificates[0].Certificate } if tt.serverCfg.GetCertificate != nil { crt, err := tt.serverCfg.GetCertificate(&ClientHelloInfo{}) - if err != nil { - t.Errorf("Server configuration did not provide a certificate") - } + assert.NoError(t, err, "Server configuration did not provide a certificate") cfgCert = crt.Certificate } - if len(cfgCert) == 0 || !bytes.Equal(cfgCert[0], actualServerCert[0]) { - t.Errorf("Server certificate was not communicated correctly") - } + assert.NotEmpty(t, cfgCert, "Server certificate was not communicated correctly") + assert.Equal(t, actualServerCert[0], cfgCert[0], "Server certificate was not communicated correctly") }) } }) @@ -1425,13 +1252,10 @@ func TestConnectionID(t *testing.T) { }() server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), tt.serverCfg, true) - if err != nil { - t.Fatalf("Unexpected server error: %v", err) - } + assert.NoError(t, err) + res := <-c - if res.err != nil { - t.Fatalf("Unexpected client error: %v", res.err) - } + assert.NoError(t, res.err) defer func() { if err == nil { _ = server.Close() @@ -1441,30 +1265,14 @@ func TestConnectionID(t *testing.T) { } }() - if !bytes.Equal(res.c.state.getLocalConnectionID(), tt.clientConnectionID) { - t.Errorf( - "Unexpected client local connection ID\nwant: %v\ngot:%v", - tt.clientConnectionID, res.c.state.localConnectionID, - ) - } - if !bytes.Equal(res.c.state.remoteConnectionID, tt.serverConnectionID) { - t.Errorf( - "Unexpected client remote connection ID\nwant: %v\ngot:%v", - tt.serverConnectionID, res.c.state.remoteConnectionID, - ) - } - if !bytes.Equal(server.state.getLocalConnectionID(), tt.serverConnectionID) { - t.Errorf( - "Unexpected server local connection ID\nwant: %v\ngot:%v", - tt.serverConnectionID, server.state.localConnectionID, - ) - } - if !bytes.Equal(server.state.remoteConnectionID, tt.clientConnectionID) { - t.Errorf( - "Unexpected server remote connection ID\nwant: %v\ngot:%v", - tt.clientConnectionID, server.state.remoteConnectionID, - ) - } + assert.True(t, bytes.Equal(tt.clientConnectionID, res.c.state.getLocalConnectionID()), + "Unexpected client local connection ID") + assert.True(t, bytes.Equal(tt.serverConnectionID, res.c.state.remoteConnectionID), + "Unexpected client remote connection ID") + assert.True(t, bytes.Equal(tt.serverConnectionID, server.state.getLocalConnectionID()), + "Unexpected server local connection ID") + assert.True(t, bytes.Equal(tt.clientConnectionID, server.state.remoteConnectionID), + "Unexpected server remote connection ID") }) } } @@ -1599,14 +1407,8 @@ func TestExtendedMasterSecret(t *testing.T) { _ = res.c.Close() } }() - - if !errors.Is(res.err, tt.expectedClientErr) { - t.Errorf("Client error expected: \"%v\" but got \"%v\"", tt.expectedClientErr, res.err) - } - - if !errors.Is(err, tt.expectedServerErr) { - t.Errorf("Server error expected: \"%v\" but got \"%v\"", tt.expectedServerErr, err) - } + assert.ErrorIs(t, res.err, tt.expectedClientErr) + assert.ErrorIs(t, err, tt.expectedServerErr) }) } } @@ -1617,13 +1419,11 @@ func TestServerCertificate(t *testing.T) { //nolint:cyclop defer report() cert, err := selfsign.GenerateSelfSigned() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + certificate, err := x509.ParseCertificate(cert.Certificate[0]) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + caPool := x509.NewCertPool() caPool.AddCert(certificate) @@ -1715,11 +1515,11 @@ func TestServerCertificate(t *testing.T) { //nolint:cyclop if err == nil { _ = cli.Close() } - if !tt.wantErr && (err != nil || hserr != nil) { - t.Errorf("Client failed(%v, %v)", err, hserr) - } - if tt.wantErr && err == nil && hserr == nil { - t.Fatal("Error expected") + if tt.wantErr { + assert.True(t, err != nil || hserr != nil, "Expected error") + } else { + assert.NoError(t, err, "Client connection failed") + assert.NoError(t, hserr, "Client handshake failed") } srv := <-srvCh @@ -1827,29 +1627,17 @@ func TestCipherSuiteConfiguration(t *testing.T) { _ = server.Close() }() } - if !errors.Is(err, test.WantServerError) { - t.Errorf( - "TestCipherSuiteConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", - test.Name, test.WantServerError, err, - ) - } + assert.ErrorIsf(t, err, test.WantServerError, "TestCipherSuiteConfiguration: Server Error Mismatch '%s'", test.Name) res := <-resultCh - if res.err == nil { - _ = server.Close() - _ = res.c.Close() - } - if !errors.Is(res.err, test.WantClientError) { - t.Errorf( - "TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", - test.Name, test.WantClientError, res.err, - ) + if err == nil { + assert.NoError(t, server.Close()) + assert.NoError(t, res.c.Close()) } - if test.WantSelectedCipherSuite != 0x00 && res.c.state.cipherSuite.ID() != test.WantSelectedCipherSuite { - t.Errorf( - "TestCipherSuiteConfiguration: Server Selected Bad Cipher Suite '%s': expected(%v) actual(%v)", - test.Name, test.WantSelectedCipherSuite, res.c.state.cipherSuite.ID(), - ) + assert.ErrorIsf(t, res.err, test.WantClientError, "TestCipherSuiteConfiguration: Client Error Mismatch '%s'") + if test.WantSelectedCipherSuite != 0x00 { + assert.Equal(t, test.WantSelectedCipherSuite, res.c.state.cipherSuite.ID(), + "TestCipherSuiteConfiguration: Server Selected Bad Cipher Suite '%s'", test.Name) } }) } @@ -1907,24 +1695,17 @@ func TestCertificateAndPSKServer(t *testing.T) { } server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true) - if err == nil { + assert.NoErrorf(t, err, "TestCertificateAndPSKServer: Server Error Mismatch '%s'", test.Name) + if err != nil { defer func() { - _ = server.Close() + assert.NoError(t, server.Close()) }() - } else { - t.Errorf("TestCertificateAndPSKServer: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, nil, err) } res := <-resultCh - if res.err == nil { - _ = server.Close() - _ = res.c.Close() - } else { - t.Errorf( - "TestCertificateAndPSKServer: Client Error Mismatch '%s': expected(%v) actual(%v)", - test.Name, nil, res.err, - ) - } + assert.NoErrorf(t, res.err, "TestCertificateAndPSKServer: Server Error Mismatch '%s'", test.Name) + assert.NoError(t, server.Close()) + assert.NoError(t, res.c.Close()) }) } } @@ -2020,22 +1801,14 @@ func TestPSKConfiguration(t *testing.T) { //nolint:cyclop ) if err != nil || test.WantServerError != nil { if !(err != nil && test.WantServerError != nil && err.Error() == test.WantServerError.Error()) { - t.Fatalf( - "TestPSKConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", - test.Name, test.WantServerError, err, - ) + assert.Failf(t, "TestPSKConfiguration", "Server Error Mismatch '%s'", test.Name) } } res := <-resultCh if res.err != nil || test.WantClientError != nil { if !(res.err != nil && test.WantClientError != nil && res.err.Error() == test.WantClientError.Error()) { - t.Fatalf( - "TestPSKConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", - test.Name, - test.WantClientError, - res.err, - ) + assert.Failf(t, "TestPSKConfiguration", "Client Error Mismatch '%s'", test.Name) } } } @@ -2052,9 +1825,7 @@ func TestServerTimeout(t *testing.T) { //nolint:cyclop cookie := make([]byte, 20) _, err := rand.Read(cookie) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) var rand [28]byte random := handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: rand} @@ -2105,16 +1876,11 @@ func TestServerTimeout(t *testing.T) { //nolint:cyclop } packet, err := record.Marshal() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) ca, cb := dpipe.Pipe() defer func() { - err := ca.Close() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, ca.Close()) }() // Client reader @@ -2157,20 +1923,19 @@ func TestServerTimeout(t *testing.T) { //nolint:cyclop _, serverErr := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true) var netErr net.Error - if !errors.As(serverErr, &netErr) || !netErr.Timeout() { - t.Fatalf("Client error exp(Temporary network error) failed(%v)", serverErr) - } + assert.ErrorAsf(t, serverErr, &netErr, "Client error exp(Temporary network error) failed(%v)", serverErr) + assert.Truef(t, netErr.Timeout(), "Client error exp(Temporary network error) failed(%v)", serverErr) // Wait a little longer to ensure no additional messages have been sent by the server time.Sleep(300 * time.Millisecond) select { case msg := <-caReadChan: - t.Fatalf("Expected no additional messages from server, got: %+v", msg) + assert.Fail(t, "Expected no additional messages from server", "got: %+v", msg) default: } } -func TestProtocolVersionValidation(t *testing.T) { //nolint:cyclop,maintidx +func TestProtocolVersionValidation(t *testing.T) { //nolint:maintidx // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -2180,9 +1945,8 @@ func TestProtocolVersionValidation(t *testing.T) { //nolint:cyclop,maintidx defer report() cookie := make([]byte, 20) - if _, err := rand.Read(cookie); err != nil { - t.Fatal(err) - } + _, err := rand.Read(cookie) + assert.NoError(t, err) var rand [28]byte random := handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: rand} @@ -2256,10 +2020,7 @@ func TestProtocolVersionValidation(t *testing.T) { //nolint:cyclop,maintidx t.Run(name, func(t *testing.T) { ca, cb := dpipe.Pipe() defer func() { - err := ca.Close() - if err != nil { - t.Error(err) - } + assert.NoError(t, ca.Close()) }() ctx, cancel := context.WithTimeout(context.Background(), time.Second) @@ -2270,15 +2031,14 @@ func TestProtocolVersionValidation(t *testing.T) { //nolint:cyclop,maintidx defer wg.Wait() go func() { defer wg.Done() - if _, err := testServer( + _, err := testServer( ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true, - ); !errors.Is(err, errUnsupportedProtocolVersion) { - t.Errorf("Client error exp(%v) failed(%v)", errUnsupportedProtocolVersion, err) - } + ) + assert.ErrorIs(t, err, errUnsupportedProtocolVersion) }() time.Sleep(50 * time.Millisecond) @@ -2286,26 +2046,20 @@ func TestProtocolVersionValidation(t *testing.T) { //nolint:cyclop,maintidx resp := make([]byte, 1024) for _, record := range serverCase.records { packet, err := record.Marshal() - if err != nil { - t.Fatal(err) - } - if _, werr := ca.Write(packet); werr != nil { - t.Fatal(werr) - } + assert.NoError(t, err) + + _, werr := ca.Write(packet) + assert.NoError(t, werr) + n, rerr := ca.Read(resp[:cap(resp)]) - if rerr != nil { - t.Fatal(rerr) - } + assert.NoError(t, rerr) + resp = resp[:n] } h := &recordlayer.Header{} - if err := h.Unmarshal(resp); err != nil { - t.Fatal("Failed to unmarshal response") - } - if h.ContentType != protocol.ContentTypeAlert { - t.Errorf("Peer must return alert to unsupported protocol version") - } + assert.NoError(t, h.Unmarshal(resp)) + assert.Equal(t, protocol.ContentTypeAlert, h.ContentType, "Peer must return alert to unsupported protocol version") }) } }) @@ -2356,10 +2110,7 @@ func TestProtocolVersionValidation(t *testing.T) { //nolint:cyclop,maintidx t.Run(name, func(t *testing.T) { ca, cb := dpipe.Pipe() defer func() { - err := ca.Close() - if err != nil { - t.Error(err) - } + assert.NoError(t, ca.Close()) }() ctx, cancel := context.WithTimeout(context.Background(), time.Second) @@ -2370,48 +2121,37 @@ func TestProtocolVersionValidation(t *testing.T) { //nolint:cyclop,maintidx defer wg.Wait() go func() { defer wg.Done() - if _, err := testClient(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true); !errors.Is( - err, errUnsupportedProtocolVersion, - ) { - t.Errorf("Server error exp(%v) failed(%v)", errUnsupportedProtocolVersion, err) - } + _, err := testClient(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true) + assert.ErrorIs(t, err, errUnsupportedProtocolVersion) }() time.Sleep(50 * time.Millisecond) for _, record := range clientCase.records { - if _, err := ca.Read(make([]byte, 1024)); err != nil { - t.Fatal(err) - } + _, err := ca.Read(make([]byte, 1024)) + assert.NoError(t, err) packet, err := record.Marshal() - if err != nil { - t.Fatal(err) - } - if _, err := ca.Write(packet); err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + + _, err = ca.Write(packet) + assert.NoError(t, err) } resp := make([]byte, 1024) n, err := ca.Read(resp) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + resp = resp[:n] h := &recordlayer.Header{} - if err := h.Unmarshal(resp); err != nil { - t.Fatal("Failed to unmarshal response") - } - if h.ContentType != protocol.ContentTypeAlert { - t.Errorf("Peer must return alert to unsupported protocol version") - } + assert.NoError(t, h.Unmarshal(resp)) + assert.Equal(t, protocol.ContentTypeAlert, h.ContentType, "Peer must return alert to unsupported protocol version") }) } }) } -func TestMultipleHelloVerifyRequest(t *testing.T) { //nolint:cyclop +func TestMultipleHelloVerifyRequest(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -2427,14 +2167,14 @@ func TestMultipleHelloVerifyRequest(t *testing.T) { //nolint:cyclop var packets [][]byte for i := 0; i < 2; i++ { cookie := make([]byte, 20) - if _, err := rand.Read(cookie); err != nil { - t.Fatal(err) - } + _, err := rand.Read(cookie) + assert.NoError(t, err) + cookies = append(cookies, cookie) record := &recordlayer.RecordLayer{ Header: recordlayer.Header{ - SequenceNumber: uint64(i), + SequenceNumber: uint64(i), //nolint:gosec // G101 Version: protocol.Version1_2, }, Content: &handshake.Handshake{ @@ -2448,18 +2188,14 @@ func TestMultipleHelloVerifyRequest(t *testing.T) { //nolint:cyclop }, } packet, err := record.Marshal() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + packets = append(packets, packet) } ca, cb := dpipe.Pipe() defer func() { - err := ca.Close() - if err != nil { - t.Error(err) - } + assert.NoError(t, ca.Close()) }() ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) @@ -2477,28 +2213,20 @@ func TestMultipleHelloVerifyRequest(t *testing.T) { //nolint:cyclop // read client hello resp := make([]byte, 1024) n, err := cb.Read(resp) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + record := &recordlayer.RecordLayer{} - if err := record.Unmarshal(resp[:n]); err != nil { - t.Fatal(err) - } - clientHello, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageClientHello) - if !ok { - t.Fatal("Failed to cast MessageClientHello") - } + assert.NoError(t, record.Unmarshal(resp[:n])) - if !bytes.Equal(clientHello.Cookie, cookie) { - t.Fatalf("Wrong cookie, expected: %x, got: %x", clientHello.Cookie, cookie) - } + clientHello, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageClientHello) + assert.True(t, ok) + assert.Equal(t, cookie, clientHello.Cookie) if len(packets) <= i { break } // write hello verify request - if _, err := cb.Write(packets[i]); err != nil { - t.Fatal(err) - } + _, err = cb.Write(packets[i]) + assert.NoError(t, err) } cancel() } @@ -2533,24 +2261,21 @@ func TestRenegotationInfo(t *testing.T) { //nolint:cyclop t.Run(test.Name, func(t *testing.T) { ca, cb := dpipe.Pipe() defer func() { - if err := ca.Close(); err != nil { - t.Error(err) - } + assert.NoError(t, ca.Close()) }() ctx, cancel := context.WithCancel(context.Background()) defer cancel() go func() { - if _, err := testServer( + _, err := testServer( ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{}, true, - ); !errors.Is(err, context.Canceled) { - t.Error(err) - } + ) + assert.ErrorIs(t, err, context.Canceled) }() time.Sleep(50 * time.Millisecond) @@ -2562,44 +2287,29 @@ func TestRenegotationInfo(t *testing.T) { //nolint:cyclop }) } err := sendClientHello([]byte{}, ca, 0, extensions) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + n, err := ca.Read(resp) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + record := &recordlayer.RecordLayer{} - if err = record.Unmarshal(resp[:n]); err != nil { - t.Fatal(err) - } + assert.NoError(t, record.Unmarshal(resp[:n])) helloVerifyRequest, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest) - if !ok { - t.Fatal("Failed to cast MessageHelloVerifyRequest") - } + assert.True(t, ok) err = sendClientHello(helloVerifyRequest.Cookie, ca, 1, extensions) - if err != nil { - t.Fatal(err) - } - if n, err = ca.Read(resp); err != nil { - t.Fatal(err) - } + assert.NoError(t, err) - messages, err := recordlayer.UnpackDatagram(resp[:n]) - if err != nil { - t.Fatal(err) - } + n, err = ca.Read(resp) + assert.NoError(t, err) - if err := record.Unmarshal(messages[0]); err != nil { - t.Fatal(err) - } + messages, err := recordlayer.UnpackDatagram(resp[:n]) + assert.NoError(t, err) + assert.NoError(t, record.Unmarshal(messages[0])) serverHello, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello) - if !ok { - t.Fatal("Failed to cast MessageServerHello") - } + assert.True(t, ok) gotNegotationInfo := false for _, v := range serverHello.Extensions { @@ -2608,9 +2318,7 @@ func TestRenegotationInfo(t *testing.T) { //nolint:cyclop } } - if !gotNegotationInfo { - t.Fatalf("Received ServerHello without RenegotiationInfo") - } + assert.True(t, gotNegotationInfo, "Expected RenegotiationInfo extension in ServerHello") }) } } @@ -2666,18 +2374,13 @@ func TestServerNameIndicationExtension(t *testing.T) { // Receive ClientHello resp := make([]byte, 1024) n, err := cb.Read(resp) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + r := &recordlayer.RecordLayer{} - if err = r.Unmarshal(resp[:n]); err != nil { - t.Fatal(err) - } + assert.NoError(t, r.Unmarshal(resp[:n])) clientHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageClientHello) - if !ok { - t.Fatal("Failed to cast MessageClientHello") - } + assert.True(t, ok) gotSNI := false var actualServerName string @@ -2685,26 +2388,19 @@ func TestServerNameIndicationExtension(t *testing.T) { if _, ok := v.(*extension.ServerName); ok { gotSNI = true extensionServerName, ok := v.(*extension.ServerName) - if !ok { - t.Fatal("Failed to cast extension.ServerName") - } + assert.True(t, ok) actualServerName = extensionServerName.ServerName } } - if gotSNI != test.IncludeSNI { - t.Errorf("TestSNI: unexpected SNI inclusion '%s': expected(%v) actual(%v)", test.Name, test.IncludeSNI, gotSNI) - } - - if !bytes.Equal([]byte(actualServerName), test.Expected) { - t.Errorf("TestSNI: server name mismatch '%s': expected(%v) actual(%v)", test.Name, test.Expected, actualServerName) - } + assert.Equalf(t, test.IncludeSNI, gotSNI, "TestSNI: expected SNI inclusion '%s'", test.Name) + assert.Equalf(t, test.Expected, []byte(actualServerName), "TestSNI: server name mismatch '%s'", test.Name) }) } } -func TestALPNExtension(t *testing.T) { //nolint:cyclop,maintidx +func TestALPNExtension(t *testing.T) { //nolint:maintidx // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -2784,9 +2480,7 @@ func TestALPNExtension(t *testing.T) { //nolint:cyclop,maintidx // Receive ClientHello resp := make([]byte, 1024) n, err := cb.Read(resp) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) ctx2, cancel2 := context.WithTimeout(context.Background(), 10*time.Second) defer cancel2() @@ -2796,87 +2490,60 @@ func TestALPNExtension(t *testing.T) { //nolint:cyclop,maintidx conf := &Config{ SupportedProtocols: test.ServerProtocolNameList, } - if _, err2 := testServer(ctx2, dtlsnet.PacketConnFromConn(cb2), cb2.RemoteAddr(), conf, true); !errors.Is( - err2, context.Canceled, - ) { - if test.ExpectAlertFromServer { //nolint - // Assert the error type? - } else { - t.Error(err2) - } + _, err2 := testServer(ctx2, dtlsnet.PacketConnFromConn(cb2), cb2.RemoteAddr(), conf, true) + if test.ExpectAlertFromServer { + assert.NotErrorIs(t, err2, context.Canceled) } }() time.Sleep(50 * time.Millisecond) // Forward ClientHello - if _, err = ca2.Write(resp[:n]); err != nil { - t.Fatal(err) - } + _, err = ca2.Write(resp[:n]) + assert.NoError(t, err) // Receive HelloVerify resp2 := make([]byte, 1024) n, err = ca2.Read(resp2) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) // Forward HelloVerify - if _, err = cb.Write(resp2[:n]); err != nil { - t.Fatal(err) - } + _, err = cb.Write(resp2[:n]) + assert.NoError(t, err) // Receive ClientHello resp3 := make([]byte, 1024) n, err = cb.Read(resp3) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) // Forward ClientHello - if _, err = ca2.Write(resp3[:n]); err != nil { - t.Fatal(err) - } + _, err = ca2.Write(resp3[:n]) + assert.NoError(t, err) // Receive ServerHello resp4 := make([]byte, 1024) n, err = ca2.Read(resp4) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) messages, err := recordlayer.UnpackDatagram(resp4[:n]) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) record := &recordlayer.RecordLayer{} - if err := record.Unmarshal(messages[0]); err != nil { - t.Fatal(err) - } + assert.NoError(t, record.Unmarshal(messages[0])) if test.ExpectAlertFromServer { //nolint:nestif a, ok := record.Content.(*alert.Alert) - if !ok { - t.Fatal("Failed to cast alert.Alert") - } - - if a.Description != test.Alert { - t.Errorf("ALPN %v: expected(%v) actual(%v)", test.Name, test.Alert, a.Description) - } + assert.True(t, ok) + assert.Equalf(t, test.Alert, a.Description, "ALPN %v", test.Name) } else { serverHello, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello) - if !ok { - t.Fatal("Failed to cast handshake.MessageServerHello") - } + assert.True(t, ok) var negotiatedProtocol string for _, v := range serverHello.Extensions { if _, ok := v.(*extension.ALPN); ok { e, ok := v.(*extension.ALPN) - if !ok { - t.Fatal("Failed to cast extension.ALPN") - } + assert.True(t, ok) negotiatedProtocol = e.ProtocolNameList[0] @@ -2887,40 +2554,26 @@ func TestALPNExtension(t *testing.T) { //nolint:cyclop,maintidx } } - if negotiatedProtocol != test.ExpectedProtocol { - t.Errorf("ALPN %v: expected(%v) actual(%v)", test.Name, test.ExpectedProtocol, negotiatedProtocol) - } + assert.Equalf(t, test.ExpectedProtocol, negotiatedProtocol, "ALPN %v", test.Name) s, err := record.Marshal() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) // Forward ServerHello - if _, err = cb.Write(s); err != nil { - t.Fatal(err) - } + _, err = cb.Write(s) + assert.NoError(t, err) if test.ExpectAlertFromClient { resp5 := make([]byte, 1024) n, err = cb.Read(resp5) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) r2 := &recordlayer.RecordLayer{} - if err := r2.Unmarshal(resp5[:n]); err != nil { - t.Fatal(err) - } + assert.NoError(t, r2.Unmarshal(resp5[:n])) a, ok := r2.Content.(*alert.Alert) - if !ok { - t.Fatal("Failed to cast alert.Alert") - } - - if a.Description != test.Alert { - t.Errorf("ALPN %v: expected(%v) actual(%v)", test.Name, test.Alert, a.Description) - } + assert.True(t, ok) + assert.Equalf(t, test.Alert, a.Description, "ALPN %v", test.Name) } } @@ -2930,7 +2583,7 @@ func TestALPNExtension(t *testing.T) { //nolint:cyclop,maintidx } // Make sure the supported_groups extension is not included in the ServerHello. -func TestSupportedGroupsExtension(t *testing.T) { //nolint:cyclop +func TestSupportedGroupsExtension(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -2945,11 +2598,8 @@ func TestSupportedGroupsExtension(t *testing.T) { //nolint:cyclop ca, cb := dpipe.Pipe() go func() { - if _, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{}, true); !errors.Is( - err, context.Canceled, - ) { - t.Error(err) - } + _, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{}, true) + assert.ErrorIs(t, err, context.Canceled) }() extensions := []extension.Extension{ &extension.SupportedEllipticCurves{ @@ -2964,46 +2614,30 @@ func TestSupportedGroupsExtension(t *testing.T) { //nolint:cyclop resp := make([]byte, 1024) err := sendClientHello([]byte{}, ca, 0, extensions) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) // Receive ServerHello n, err := ca.Read(resp) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + record := &recordlayer.RecordLayer{} - if err = record.Unmarshal(resp[:n]); err != nil { - t.Fatal(err) - } + assert.NoError(t, record.Unmarshal(resp[:n])) helloVerifyRequest, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest) - if !ok { - t.Fatal("Failed to cast MessageHelloVerifyRequest") - } + assert.True(t, ok, "Failed to cast MessageHelloVerifyRequest") err = sendClientHello(helloVerifyRequest.Cookie, ca, 1, extensions) - if err != nil { - t.Fatal(err) - } - if n, err = ca.Read(resp); err != nil { - t.Fatal(err) - } + assert.NoError(t, err) - messages, err := recordlayer.UnpackDatagram(resp[:n]) - if err != nil { - t.Fatal(err) - } + n, err = ca.Read(resp) + assert.NoError(t, err) - if err := record.Unmarshal(messages[0]); err != nil { - t.Fatal(err) - } + messages, err := recordlayer.UnpackDatagram(resp[:n]) + assert.NoError(t, err) + assert.NoError(t, record.Unmarshal(messages[0])) serverHello, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello) - if !ok { - t.Fatal("Failed to cast MessageServerHello") - } + assert.True(t, ok, "TestSupportedGroups: Failed to cast MessageServerHello") gotGroups := false for _, v := range serverHello.Extensions { @@ -3012,13 +2646,11 @@ func TestSupportedGroupsExtension(t *testing.T) { //nolint:cyclop } } - if gotGroups { - t.Errorf("TestSupportedGroups: supported_groups extension was sent in ServerHello") - } + assert.False(t, gotGroups, "TestSupportedGroups: supported_groups extension was sent in ServerHello") }) } -func TestSessionResume(t *testing.T) { //nolint:cyclop +func TestSessionResume(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -3069,32 +2701,23 @@ func TestSessionResume(t *testing.T) { //nolint:cyclop MTU: 100, } server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true) - if err != nil { - t.Fatalf("TestSessionResume: Server failed(%v)", err) - } + assert.NoError(t, err) state, ok := server.ConnectionState() - if !ok { - t.Fatal("TestSessionResume: ConnectionState failed") - } + assert.True(t, ok) + actualSessionID := state.SessionID actualMasterSecret := state.masterSecret - if !bytes.Equal(actualSessionID, id) { - t.Errorf("TestSessionResumetion: SessionID Mismatch: expected(%v) actual(%v)", id, actualSessionID) - } - if !bytes.Equal(actualMasterSecret, secret) { - t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected(%v) actual(%v)", secret, actualMasterSecret) - } + assert.Equal(t, actualSessionID, id, "TestSessionResumetion SessionID mismatch") + assert.Equal(t, actualMasterSecret, secret, "TestSessionResumetion masterSecret mismatch") defer func() { - _ = server.Close() + assert.NoError(t, server.Close()) }() res := <-clientRes - if res.err != nil { - t.Fatal(res.err) - } - _ = res.c.Close() + assert.NoError(t, res.err) + assert.NoError(t, res.c.Close()) }) t.Run("new session", func(t *testing.T) { @@ -3124,34 +2747,25 @@ func TestSessionResume(t *testing.T) { //nolint:cyclop SessionStore: s2, } server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true) - if err != nil { - t.Fatalf("TestSessionResumetion: Server failed(%v)", err) - } + assert.NoError(t, err) state, ok := server.ConnectionState() - if !ok { - t.Fatal("TestSessionResumetion: ConnectionState failed") - } + assert.True(t, ok) actualSessionID := state.SessionID actualMasterSecret := state.masterSecret ss, _ := s2.Get(actualSessionID) - if !bytes.Equal(actualMasterSecret, ss.Secret) { - t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected(%v) actual(%v)", ss.Secret, actualMasterSecret) - } + assert.Equal(t, actualMasterSecret, ss.Secret, "TestSessionResumetion masterSecret mismatch") defer func() { - _ = server.Close() + assert.NoError(t, server.Close()) }() res := <-clientRes - if res.err != nil { - t.Fatal(res.err) - } + assert.NoError(t, res.err) + cs, _ := s1.Get([]byte(ca.RemoteAddr().String() + "_example.com")) - if !bytes.Equal(actualMasterSecret, cs.Secret) { - t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected(%v) actual(%v)", ss.Secret, actualMasterSecret) - } - _ = res.c.Close() + assert.Equal(t, actualMasterSecret, cs.Secret, "TestSessionResumetion mismatch") + assert.NoError(t, res.c.Close()) }) } @@ -3239,58 +2853,46 @@ func TestCipherSuiteMatchesCertificateType(t *testing.T) { //nolint:cyclop ) if test.generateRSA { - if signer, err = rsa.GenerateKey(rand.Reader, 2048); err != nil { - t.Fatal(err) - } + signer, err = rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) } else { - if signer, err = ecdsa.GenerateKey(cryptoElliptic.P256(), rand.Reader); err != nil { - t.Fatal(err) - } + signer, err = ecdsa.GenerateKey(cryptoElliptic.P256(), rand.Reader) + assert.NoError(t, err) } serverCert, err := selfsign.SelfSign(signer) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) - if s, err := testServer(context.TODO(), dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ + s, err := testServer(context.TODO(), dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ CipherSuites: test.cipherList, Certificates: []tls.Certificate{serverCert}, - }, false); err != nil { - t.Fatal(err) - } else if err = s.Close(); err != nil { - t.Fatal(err) - } + }, false) + assert.NoError(t, err) + assert.NoError(t, s.Close()) - if c, err := <-client, <-clientErr; err != nil { - t.Fatal(err) - } else if err := c.Close(); err != nil { - t.Fatal(err) - } else if state, ok := c.ConnectionState(); !ok || state.cipherSuite.ID() != test.expectedCipher { - t.Fatalf("Expected(%s) and Actual(%s) CipherSuite do not match", test.expectedCipher, state.cipherSuite.ID()) - } + c := <-client + assert.NoError(t, <-clientErr) + assert.NoError(t, c.Close()) + + state, ok := c.ConnectionState() + assert.True(t, ok) + assert.Equal(t, test.expectedCipher, state.cipherSuite.ID()) }) } } // Test that we return the proper certificate if we are serving multiple ServerNames on a single Server. -func TestMultipleServerCertificates(t *testing.T) { //nolint:cyclop +func TestMultipleServerCertificates(t *testing.T) { fooCert, err := selfsign.GenerateSelfSignedWithDNS("foo") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) barCert, err := selfsign.GenerateSelfSignedWithDNS("bar") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) caPool := x509.NewCertPool() for _, cert := range []tls.Certificate{fooCert, barCert} { certificate, err := x509.ParseCertificate(cert.Certificate[0]) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) caPool.AddCert(certificate) } @@ -3338,24 +2940,18 @@ func TestMultipleServerCertificates(t *testing.T) { //nolint:cyclop client <- clientConn }() - if s, err := testServer(context.TODO(), dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ + s, err := testServer(context.TODO(), dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{fooCert, barCert}, - }, false); err != nil { - t.Fatal(err) - } else if err = s.Close(); err != nil { - t.Fatal(err) - } - - if c, err := <-client, <-clientErr; err != nil { - t.Fatal(err) - } else if err := c.Close(); err != nil { - t.Fatal(err) - } + }, false) + assert.NoError(t, err) + assert.NoError(t, s.Close()) + assert.NoError(t, <-clientErr) + assert.NoError(t, (<-client).Close()) }) } } -func TestEllipticCurveConfiguration(t *testing.T) { //nolint:cyclop +func TestEllipticCurveConfiguration(t *testing.T) { // Check for leaking routines report := test.CheckRoutines(t) defer report() @@ -3403,47 +2999,25 @@ func TestEllipticCurveConfiguration(t *testing.T) { //nolint:cyclop CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, EllipticCurves: test.ConfigCurves, }, true) - if err != nil { - t.Fatalf("Server error: %v", err) - } + assert.NoError(t, err) - if len(test.ConfigCurves) == 0 && len(test.HandshakeCurves) != len(server.fsm.cfg.ellipticCurves) { - t.Fatalf( - "Failed to default Elliptic curves, expected %d, got: %d", - len(test.HandshakeCurves), - len(server.fsm.cfg.ellipticCurves), - ) - } + ok := len(test.ConfigCurves) == 0 || len(test.ConfigCurves) == len(test.HandshakeCurves) + assert.True(t, ok, "Failed to default Elliptic curves") if len(test.ConfigCurves) != 0 { - if len(test.HandshakeCurves) != len(server.fsm.cfg.ellipticCurves) { - t.Fatalf( - "Failed to configure Elliptic curves, expect %d, got %d", - len(test.HandshakeCurves), - len(server.fsm.cfg.ellipticCurves), - ) - } + assert.Equal(t, len(test.HandshakeCurves), len(server.fsm.cfg.ellipticCurves), "Failed to configure Elliptic curves") + for i, c := range test.ConfigCurves { - if c != server.fsm.cfg.ellipticCurves[i] { - t.Fatalf("Failed to maintain Elliptic curve order, expected %s, got %s", c, server.fsm.cfg.ellipticCurves[i]) - } + assert.Equal(t, c, server.fsm.cfg.ellipticCurves[i], "Failed to maintain Elliptic curve order") } } res := <-resultCh - if res.err != nil { - t.Fatalf("Client error; %v", err) - } + assert.NoError(t, res.err, "Client error") defer func() { - err = server.Close() - if err != nil { - t.Fatal(err) - } - err = res.c.Close() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, server.Close()) + assert.NoError(t, res.c.Close()) }() } } @@ -3456,9 +3030,7 @@ func TestSkipHelloVerify(t *testing.T) { ca, cb := dpipe.Pipe() certificate, err := selfsign.GenerateSelfSigned() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) gotHello := make(chan struct{}) go func() { @@ -3467,41 +3039,31 @@ func TestSkipHelloVerify(t *testing.T) { LoggerFactory: logging.NewDefaultLoggerFactory(), InsecureSkipVerifyHello: true, }, false) - if sErr != nil { - t.Error(sErr) + assert.NoError(t, sErr) - return - } buf := make([]byte, 1024) - if _, sErr = server.Read(buf); sErr != nil { //nolint:contextcheck - t.Error(sErr) - } + _, sErr = server.Read(buf) //nolint:contextcheck + assert.NoError(t, sErr) gotHello <- struct{}{} - if sErr = server.Close(); sErr != nil { //nolint:contextcheck - t.Error(sErr) - } + assert.NoError(t, server.Close()) //nolint:contextcheck }() client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ LoggerFactory: logging.NewDefaultLoggerFactory(), InsecureSkipVerify: true, }, false) - if err != nil { - t.Fatal(err) - } - if _, err = client.Write([]byte("hello")); err != nil { - t.Error(err) - } + assert.NoError(t, err) + + _, err = client.Write([]byte("hello")) + assert.NoError(t, err) + select { case <-gotHello: // OK case <-time.After(time.Second * 5): - t.Error("timeout") - } - - if err = client.Close(); err != nil { - t.Error(err) + assert.Fail(t, "timeout") } + assert.NoError(t, client.Close()) } type connWithCallback struct { @@ -3530,50 +3092,41 @@ func TestApplicationDataQueueLimited(t *testing.T) { defer cancel() ca, cb := dpipe.Pipe() - defer ca.Close() //nolint:errcheck - defer cb.Close() //nolint:errcheck + defer func() { + assert.NoError(t, ca.Close()) + }() + defer func() { + assert.NoError(t, cb.Close()) + }() done := make(chan struct{}) go func() { serverCert, err := selfsign.GenerateSelfSigned() - if err != nil { - t.Error(err) + assert.NoError(t, err) - return - } cfg := &Config{} cfg.Certificates = []tls.Certificate{serverCert} dconn, err := createConn(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), cfg, false, nil) - if err != nil { - t.Error(err) + assert.NoError(t, err) - return - } go func() { for i := 0; i < 5; i++ { dconn.lock.RLock() qlen := len(dconn.encryptedPackets) dconn.lock.RUnlock() - if qlen > maxAppDataPacketQueueSize { - t.Error("too many encrypted packets enqueued", len(dconn.encryptedPackets)) - } + assert.GreaterOrEqual(t, maxAppDataPacketQueueSize, qlen, "too many encrypted packets enqueued") time.Sleep(1 * time.Second) } }() - if err := dconn.HandshakeContext(ctx); err == nil { - t.Error("expected handshake to fail") - } + assert.Error(t, dconn.HandshakeContext(ctx)) close(done) }() extensions := []extension.Extension{} time.Sleep(50 * time.Millisecond) - err := sendClientHello([]byte{}, ca, 0, extensions) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, sendClientHello([]byte{}, ca, 0, extensions)) time.Sleep(50 * time.Millisecond) @@ -3589,20 +3142,19 @@ func TestApplicationDataQueueLimited(t *testing.T) { Data: []byte{1, 2, 3, 4}, }, }).Marshal() - if err != nil { - t.Fatal(err) - } - ca.Write(packet) // nolint + assert.NoError(t, err) + _, err = ca.Write(packet) + assert.NoError(t, err) if i%100 == 0 { time.Sleep(10 * time.Millisecond) } } time.Sleep(1 * time.Second) - ca.Close() // nolint + assert.NoError(t, ca.Close()) <-done } -func TestHelloRandom(t *testing.T) { //nolint:cyclop +func TestHelloRandom(t *testing.T) { report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -3610,16 +3162,12 @@ func TestHelloRandom(t *testing.T) { //nolint:cyclop ca, cb := dpipe.Pipe() certificate, err := selfsign.GenerateSelfSigned() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) gotHello := make(chan struct{}) chRandom := [handshake.RandomBytesLength]byte{} _, err = rand.Read(chRandom[:]) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) go func() { server, sErr := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ @@ -3627,28 +3175,20 @@ func TestHelloRandom(t *testing.T) { //nolint:cyclop if len(chi.CipherSuites) == 0 { return &certificate, nil } - - if !bytes.Equal(chi.RandomBytes[:], chRandom[:]) { - t.Error("client hello random differs") - } + assert.Equal(t, chRandom[:], chi.RandomBytes[:]) return &certificate, nil }, LoggerFactory: logging.NewDefaultLoggerFactory(), }, false) - if sErr != nil { - t.Error(sErr) + assert.NoError(t, sErr) - return - } buf := make([]byte, 1024) - if _, sErr = server.Read(buf); sErr != nil { //nolint:contextcheck - t.Error(sErr) - } + _, sErr = server.Read(buf) //nolint:contextcheck + assert.NoError(t, sErr) + gotHello <- struct{}{} - if sErr = server.Close(); sErr != nil { //nolint:contextcheck - t.Error(sErr) - } + assert.NoError(t, server.Close()) //nolint:contextcheck }() client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ @@ -3658,22 +3198,19 @@ func TestHelloRandom(t *testing.T) { //nolint:cyclop }, InsecureSkipVerify: true, }, false) - if err != nil { - t.Fatal(err) - } - if _, err = client.Write([]byte("hello")); err != nil { - t.Error(err) - } + assert.NoError(t, err) + + _, err = client.Write([]byte("hello")) + assert.NoError(t, err) + select { case <-gotHello: // OK case <-time.After(time.Second * 5): - t.Error("timeout") + assert.Fail(t, "timeout") } - if err = client.Close(); err != nil { - t.Error(err) - } + assert.NoError(t, client.Close()) } func TestOnConnectionAttempt(t *testing.T) { @@ -3688,9 +3225,7 @@ func TestOnConnectionAttempt(t *testing.T) { _, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ OnConnectionAttempt: func(in net.Addr) error { clientOnConnectionAttempt.Store(1) - if in == nil { - t.Fatal("net.Addr is nil") //nolint: govet - } + assert.NotNil(t, in) return nil }, @@ -3699,30 +3234,18 @@ func TestOnConnectionAttempt(t *testing.T) { }() expectedErr := &FatalError{} - if _, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ + _, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ OnConnectionAttempt: func(in net.Addr) error { serverOnConnectionAttempt.Store(1) - if in == nil { - t.Fatal("net.Addr is nil") //nolint: govet - } + assert.NotNil(t, in) return expectedErr }, - }, true); !errors.Is(err, expectedErr) { - t.Fatal(err) - } - - if err := <-clientErr; err == nil { - t.Fatal(err) - } - - if v := serverOnConnectionAttempt.Load(); v != 1 { - t.Fatal("OnConnectionAttempt did not fire for server") - } - - if v := clientOnConnectionAttempt.Load(); v != 0 { - t.Fatal("OnConnectionAttempt fired for client") - } + }, true) + assert.ErrorIs(t, err, expectedErr) + assert.Error(t, <-clientErr) + assert.Equal(t, int32(1), serverOnConnectionAttempt.Load(), "OnConnectionAttempt did not fire for server") + assert.Equal(t, int32(0), clientOnConnectionAttempt.Load(), "OnConnectionAttempt fired for client") } func TestFragmentBuffer_Retransmission(t *testing.T) { @@ -3732,21 +3255,16 @@ func TestFragmentBuffer_Retransmission(t *testing.T) { 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01, } - if _, isRetransmission, err := fragmentBuffer.push(frag); err != nil { - t.Fatal(err) - } else if isRetransmission { - t.Fatal("fragment should not be retransmission") - } + _, isRetransmission, err := fragmentBuffer.push(frag) + assert.NoError(t, err) + assert.False(t, isRetransmission) - if v, _ := fragmentBuffer.pop(); v == nil { - t.Fatal("Failed to pop fragment") - } + v, _ := fragmentBuffer.pop() + assert.NotNil(t, v) - if _, isRetransmission, err := fragmentBuffer.push(frag); err != nil { - t.Fatal(err) - } else if !isRetransmission { - t.Fatal("fragment should be retransmission") - } + _, isRetransmission, err = fragmentBuffer.push(frag) + assert.NoError(t, err) + assert.True(t, isRetransmission) } func TestConnectionState(t *testing.T) { @@ -3755,23 +3273,18 @@ func TestConnectionState(t *testing.T) { // Setup client clientCfg := &Config{} clientCert, err := selfsign.GenerateSelfSigned() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + clientCfg.Certificates = []tls.Certificate{clientCert} clientCfg.InsecureSkipVerify = true client, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), clientCfg) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) defer func() { _ = client.Close() }() _, ok := client.ConnectionState() - if ok { - t.Fatal("ConnectionState should be nil") - } + assert.False(t, ok) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -3783,22 +3296,17 @@ func TestConnectionState(t *testing.T) { // Setup server server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{}, true) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + defer func() { _ = server.Close() }() err = <-errorChannel - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) _, ok = client.ConnectionState() - if !ok { - t.Fatal("ConnectionState should not be nil") - } + assert.True(t, ok) } func TestMultiHandshake(t *testing.T) { @@ -3807,46 +3315,28 @@ func TestMultiHandshake(t *testing.T) { ca, cb := dpipe.Pipe() serverCert, err := selfsign.GenerateSelfSigned() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{serverCert}, }) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) go func() { _ = server.Handshake() }() clientCert, err := selfsign.GenerateSelfSigned() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + client, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ Certificates: []tls.Certificate{clientCert}, }) - if err != nil { - t.Fatal(err) - } - - if err = client.Handshake(); err == nil { - t.Fatal(err) - } - - if err = client.Handshake(); err == nil { - t.Fatal(err) - } - - if err = server.Close(); err != nil { - t.Fatal(err) - } - - if err = client.Close(); err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + assert.Error(t, client.Handshake()) + assert.Error(t, client.Handshake()) + assert.NoError(t, server.Close()) + assert.NoError(t, client.Close()) } func TestCloseDuringHandshake(t *testing.T) { @@ -3854,18 +3344,14 @@ func TestCloseDuringHandshake(t *testing.T) { defer test.TimeOut(time.Second * 10).Stop() serverCert, err := selfsign.GenerateSelfSigned() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) for i := 0; i < 100; i++ { _, cb := dpipe.Pipe() server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{serverCert}, }) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) waitChan := make(chan struct{}) go func() { @@ -3874,9 +3360,7 @@ func TestCloseDuringHandshake(t *testing.T) { }() <-waitChan - if err = server.Close(); err != nil { - t.Fatal(err) - } + assert.NoError(t, server.Close()) } } @@ -3885,17 +3369,12 @@ func TestCloseWithoutHandshake(t *testing.T) { defer test.TimeOut(time.Second * 10).Stop() serverCert, err := selfsign.GenerateSelfSigned() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + _, cb := dpipe.Pipe() server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{serverCert}, }) - if err != nil { - t.Fatal(err) - } - if err = server.Close(); err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + assert.NoError(t, server.Close()) } diff --git a/connection_id_test.go b/connection_id_test.go index aba5f72d0..26d0917f5 100644 --- a/connection_id_test.go +++ b/connection_id_test.go @@ -11,6 +11,7 @@ import ( "github.com/pion/dtls/v3/pkg/protocol/extension" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" + "github.com/stretchr/testify/assert" ) func TestRandomConnectionIDGenerator(t *testing.T) { @@ -29,9 +30,7 @@ func TestRandomConnectionIDGenerator(t *testing.T) { } for name, tc := range cases { t.Run(name, func(t *testing.T) { - if cidLen := len(RandomCIDGenerator(tc.size)()); cidLen != tc.size { - t.Errorf("%s\nRandomCIDGenerator: expected CID length %d, but got %d.", tc.reason, tc.size, cidLen) - } + assert.Equal(t, tc.size, len(RandomCIDGenerator(tc.size)()), "%s\nRandomCIDGenerator mismatch", tc.reason) }) } } @@ -46,9 +45,7 @@ func TestOnlySendCIDGenerator(t *testing.T) { } for name, tc := range cases { t.Run(name, func(t *testing.T) { - if cidLen := len(OnlySendCIDGenerator()()); cidLen != 0 { - t.Errorf("%s\nOnlySendCIDGenerator: expected CID length %d, but got %d.", tc.reason, 0, cidLen) - } + assert.Equalf(t, 0, len(OnlySendCIDGenerator()()), "%s\nOnlySendCIDGenerator mismatch", tc.reason) }) } } @@ -65,22 +62,19 @@ func TestCIDDatagramRouter(t *testing.T) { Data: []byte("application data"), }, }).Marshal() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + appData, err := (&protocol.ApplicationData{ Data: []byte("some data"), }).Marshal() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + inner, err := (&recordlayer.InnerPlaintext{ Content: appData, RealType: protocol.ContentTypeApplicationData, }).Marshal() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + cidHeader, err := (&recordlayer.Header{ Epoch: 1, Version: protocol.Version1_2, @@ -89,9 +83,8 @@ func TestCIDDatagramRouter(t *testing.T) { ConnectionID: cid, SequenceNumber: 1, }).Marshal() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + cases := map[string]struct { reason string size int @@ -140,9 +133,7 @@ func TestCIDDatagramRouter(t *testing.T) { ConnectionID: []byte("abcd"), SequenceNumber: 1, }).Marshal() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) return append(altCIDHeader, inner...) }(), @@ -170,9 +161,7 @@ func TestCIDDatagramRouter(t *testing.T) { ConnectionID: []byte("1234abcd"), SequenceNumber: 1, }).Marshal() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) return append(altCIDHeader, inner...) }()...), cidHeader...), inner...), @@ -183,12 +172,8 @@ func TestCIDDatagramRouter(t *testing.T) { for name, tc := range cases { t.Run(name, func(t *testing.T) { cid, ok := cidDatagramRouter(tc.size)(tc.datagram) - if ok != tc.ok { - t.Errorf("%s\ncidDatagramRouter: expected ok %t, but got %t.", tc.reason, tc.ok, ok) - } - if cid != tc.want { - t.Errorf("%s\ncidDatagramRouter: expected CID %s, but got %s.", tc.reason, tc.want, cid) - } + assert.Equal(t, tc.ok, ok, "%s\ncidDatagramRouter mismatch", tc.reason) + assert.Equal(t, tc.want, cid, "%s\ncidDatagramRouter mismatch", tc.reason) }) } } @@ -216,9 +201,8 @@ func TestCIDConnIdentifier(t *testing.T) { }, }, }).Marshal() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + appRecord, err := (&recordlayer.RecordLayer{ Header: recordlayer.Header{ Epoch: 1, @@ -228,9 +212,8 @@ func TestCIDConnIdentifier(t *testing.T) { Data: []byte("application data"), }, }).Marshal() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + cases := map[string]struct { reason string datagram []byte @@ -279,12 +262,8 @@ func TestCIDConnIdentifier(t *testing.T) { for name, tc := range cases { t.Run(name, func(t *testing.T) { cid, ok := cidConnIdentifier()(tc.datagram) - if ok != tc.ok { - t.Errorf("%s\ncidConnIdentifier: expected ok %t, but got %t.", tc.reason, tc.ok, ok) - } - if cid != tc.want { - t.Errorf("%s\ncidConnIdentifier: expected CID %s, but got %s.", tc.reason, tc.want, cid) - } + assert.Equalf(t, tc.ok, ok, "%s\ncidConnIdentifier mismatch", tc.reason) + assert.Equalf(t, tc.want, cid, "%s\ncidConnIdentifier mismatch", tc.reason) }) } } diff --git a/crypto_test.go b/crypto_test.go index 249ca2cdc..6e6a31229 100644 --- a/crypto_test.go +++ b/crypto_test.go @@ -4,13 +4,13 @@ package dtls import ( - "bytes" "crypto/x509" "encoding/pem" "testing" "github.com/pion/dtls/v3/pkg/crypto/elliptic" "github.com/pion/dtls/v3/pkg/crypto/hash" + "github.com/stretchr/testify/assert" ) // nolint: gosec @@ -47,9 +47,7 @@ JPhfPySIPG4UmwE4gW8t79vfOKxnUu2fDD1ZXUYopan6EckACNH/ func TestGenerateKeySignature(t *testing.T) { block, _ := pem.Decode([]byte(rawPrivateKey)) key, err := x509.ParsePKCS1PrivateKey(block.Bytes) - if err != nil { - t.Error(err) - } + assert.NoError(t, err) clientRandom := []byte{ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, @@ -82,9 +80,6 @@ func TestGenerateKeySignature(t *testing.T) { } signature, err := generateKeySignature(clientRandom, serverRandom, publicKey, elliptic.X25519, key, hash.SHA256) - if err != nil { - t.Error(err) - } else if !bytes.Equal(expectedSignature, signature) { - t.Errorf("Signature generation failed \nexp % 02x \nactual % 02x ", expectedSignature, signature) - } + assert.NoError(t, err) + assert.Equal(t, expectedSignature, signature) } diff --git a/e2e/e2e_lossy_test.go b/e2e/e2e_lossy_test.go index f49cb78f4..07370a930 100644 --- a/e2e/e2e_lossy_test.go +++ b/e2e/e2e_lossy_test.go @@ -14,6 +14,7 @@ import ( "github.com/pion/dtls/v3/pkg/crypto/selfsign" dtlsnet "github.com/pion/dtls/v3/pkg/net" transportTest "github.com/pion/transport/v3/test" + "github.com/stretchr/testify/assert" ) const ( @@ -34,14 +35,10 @@ func TestPionE2ELossy(t *testing.T) { //nolint:cyclop } serverCert, err := selfsign.GenerateSelfSigned() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) clientCert, err := selfsign.GenerateSelfSigned() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) for _, test := range []struct { LossChanceRange int @@ -146,9 +143,7 @@ func TestPionE2ELossy(t *testing.T) { //nolint:cyclop clientDone := make(chan runResult) br := transportTest.NewBridge() - if err = br.SetLossChance(chosenLoss); err != nil { - t.Fatal(err) - } + assert.NoError(t, br.SetLossChance(chosenLoss)) go func() { cfg := &dtls.Config{ @@ -191,14 +186,10 @@ func TestPionE2ELossy(t *testing.T) { //nolint:cyclop var serverConn, clientConn *dtls.Conn defer func() { if serverConn != nil { - if err = serverConn.Close(); err != nil { - t.Error(err) - } + assert.NoError(t, serverConn.Close()) } if clientConn != nil { - if err = clientConn.Close(); err != nil { - t.Error(err) - } + assert.NoError(t, clientConn.Close()) } }() @@ -211,10 +202,8 @@ func TestPionE2ELossy(t *testing.T) { //nolint:cyclop select { case serverResult := <-serverDone: if serverResult.err != nil { - t.Errorf( - "Fail, serverError: clientComplete(%t) serverComplete(%t) LossChance(%d) error(%v)", - clientConn != nil, serverConn != nil, chosenLoss, serverResult.err, - ) + assert.Failf(t, "Fail, serverError", "clientComplete(%t) serverComplete(%t) LossChance(%d) error(%v)", + clientConn != nil, serverConn != nil, chosenLoss, serverResult.err) return } @@ -222,20 +211,16 @@ func TestPionE2ELossy(t *testing.T) { //nolint:cyclop serverConn = serverResult.dtlsConn case clientResult := <-clientDone: if clientResult.err != nil { - t.Errorf( - "Fail, clientError: clientComplete(%t) serverComplete(%t) LossChance(%d) error(%v)", - clientConn != nil, serverConn != nil, chosenLoss, clientResult.err, - ) + assert.Failf(t, "Fail, clientError", "clientComplete(%t) serverComplete(%t) LossChance(%d) error(%v)", + clientConn != nil, serverConn != nil, chosenLoss, clientResult.err) return } clientConn = clientResult.dtlsConn case <-testTimer.C: - t.Errorf( - "Test expired: clientComplete(%t) serverComplete(%t) LossChance(%d)", - clientConn != nil, serverConn != nil, chosenLoss, - ) + assert.Failf(t, "Test expired", "clientComplete(%t) serverComplete(%t) LossChance(%d)", + clientConn != nil, serverConn != nil, chosenLoss) return case <-time.After(10 * time.Millisecond): diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index f02d1507f..a93614c74 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -27,6 +27,7 @@ import ( "github.com/pion/dtls/v3/pkg/protocol/extension" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/transport/v3/test" + "github.com/stretchr/testify/assert" ) const ( @@ -44,9 +45,8 @@ var ( func randomPort(tb testing.TB) int { tb.Helper() conn, err := net.ListenPacket("udp4", "127.0.0.1:0") - if err != nil { - tb.Fatalf("failed to pickPort: %v", err) - } + assert.NoError(tb, err, "failed to pick port") + defer func() { _ = conn.Close() }() @@ -54,7 +54,7 @@ func randomPort(tb testing.TB) int { case *net.UDPAddr: return addr.Port default: - tb.Fatalf("unknown addr type %T", addr) + assert.Fail(tb, "failed to acquire port", "unknown addr type %T", addr) return 0 } @@ -147,19 +147,13 @@ func (c *comm) assert(t *testing.T) { //nolint:cyclop defer func() { if c.clientConn != nil { - if err := c.clientConn.Close(); err != nil { - t.Fatal(err) - } + assert.NoError(t, c.clientConn.Close()) } if c.serverConn != nil { - if err := c.serverConn.Close(); err != nil { - t.Fatal(err) - } + assert.NoError(t, c.serverConn.Close()) } if c.serverListener != nil { - if err := c.serverListener.Close(); err != nil { - t.Fatal(err) - } + assert.NoError(t, c.serverListener.Close()) } }() @@ -168,22 +162,18 @@ func (c *comm) assert(t *testing.T) { //nolint:cyclop for { select { case err := <-c.errChan: - t.Fatal(err) + assert.NoError(t, err) case <-time.After(testTimeLimit): - t.Fatalf("Test timeout, seenClient %t seenServer %t", seenClient, seenServer) + assert.Failf(t, "Test timeout", "seenClient %t seenServer %t", seenClient, seenServer) case clientMsg := <-c.clientChan: - if clientMsg != testMessage { - t.Fatalf("clientMsg does not equal test message: %s %s", clientMsg, testMessage) - } + assert.Equal(t, testMessage, clientMsg) seenClient = true if seenClient && seenServer { return } case serverMsg := <-c.serverChan: - if serverMsg != testMessage { - t.Fatalf("serverMsg does not equal test message: %s %s", serverMsg, testMessage) - } + assert.Equal(t, testMessage, serverMsg) seenServer = true if seenClient && seenServer { @@ -194,30 +184,26 @@ func (c *comm) assert(t *testing.T) { //nolint:cyclop }() } -func (c *comm) cleanup(t *testing.T) { //nolint:cyclop +func (c *comm) cleanup(t *testing.T) { t.Helper() clientDone, serverDone := false, false for { select { case err := <-c.clientDone: - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) clientDone = true if clientDone && serverDone { return } case err := <-c.serverDone: - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) serverDone = true if clientDone && serverDone { return } case <-time.After(testTimeLimit): - t.Fatalf("Test timeout waiting for server shutdown") + assert.Fail(t, "Test timeout waiting for server shutdown") } } } @@ -323,9 +309,7 @@ func testPionE2ESimple(t *testing.T, server, client func(*comm), opts ...dtlsCon defer cancel() cert, err := selfsign.GenerateSelfSignedWithDNS("localhost") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) cfg := &dtls.Config{ Certificates: []tls.Certificate{cert}, @@ -402,9 +386,7 @@ func testPionE2EMTUs(t *testing.T, server, client func(*comm), opts ...dtlsConfO defer cancel() cert, err := selfsign.GenerateSelfSignedWithDNS("localhost") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) cfg := &dtls.Config{ Certificates: []tls.Certificate{cert}, @@ -445,13 +427,9 @@ func testPionE2ESimpleED25519(t *testing.T, server, client func(*comm), opts ... defer cancel() _, key, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) cert, err := selfsign.SelfSign(key) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) cfg := &dtls.Config{ Certificates: []tls.Certificate{cert}, @@ -482,22 +460,14 @@ func testPionE2ESimpleED25519ClientCert(t *testing.T, server, client func(*comm) defer cancel() _, skey, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) scert, err := selfsign.SelfSign(skey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) _, ckey, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) ccert, err := selfsign.SelfSign(ckey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) scfg := &dtls.Config{ Certificates: []tls.Certificate{scert}, @@ -532,20 +502,14 @@ func testPionE2ESimpleECDSAClientCert(t *testing.T, server, client func(*comm), defer cancel() scert, err := selfsign.GenerateSelfSigned() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) ccert, err := selfsign.GenerateSelfSigned() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) clientCAs := x509.NewCertPool() caCert, err := x509.ParseCertificate(ccert.Certificate[0]) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) clientCAs.AddCert(caCert) scfg := &dtls.Config{ @@ -582,22 +546,14 @@ func testPionE2ESimpleRSAClientCert(t *testing.T, server, client func(*comm), op defer cancel() spriv, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) scert, err := selfsign.SelfSign(spriv) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) cpriv, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) ccert, err := selfsign.SelfSign(cpriv) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) scfg := &dtls.Config{ Certificates: []tls.Certificate{scert}, @@ -633,9 +589,7 @@ func testPionE2ESimpleClientHelloHook(t *testing.T, server, client func(*comm), defer cancel() cert, err := selfsign.GenerateSelfSignedWithDNS("localhost") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) modifiedCipher := dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA supportedList := []dtls.CipherSuiteID{ @@ -692,9 +646,7 @@ func testPionE2ESimpleServerHelloHook(t *testing.T, server, client func(*comm), defer cancel() cert, err := selfsign.GenerateSelfSignedWithDNS("localhost") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) supportedList := []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM} diff --git a/errors_errno_test.go b/errors_errno_test.go index e2957a691..ed2bdbbbf 100644 --- a/errors_errno_test.go +++ b/errors_errno_test.go @@ -10,35 +10,27 @@ package dtls import ( - "errors" "net" "testing" + + "github.com/stretchr/testify/assert" ) func TestErrorsTemporary(t *testing.T) { // Allocate a UDP port no one is listening on. addrListen, err := net.ResolveUDPAddr("udp", "localhost:0") - if err != nil { - t.Fatalf("Unexpected failure to resolve: %v", err) - } + assert.NoError(t, err) + listener, err := net.ListenUDP("udp", addrListen) - if err != nil { - t.Fatalf("Unexpected failure to listen: %v", err) - } + assert.NoError(t, err) + raddr, ok := listener.LocalAddr().(*net.UDPAddr) - if !ok { - t.Fatal("Unexpedted type assertion error") - } - err = listener.Close() - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + assert.True(t, ok) + assert.NoError(t, listener.Close()) // Server is not listening. conn, errDial := net.DialUDP("udp", nil, raddr) - if errDial != nil { - t.Fatalf("Unexpected error: %v", errDial) - } + assert.NoError(t, errDial) _, _ = conn.Write([]byte{0x00}) // trigger _, err = conn.Read(make([]byte, 10)) @@ -49,14 +41,7 @@ func TestErrorsTemporary(t *testing.T) { } var ne net.Error - if !errors.As(netError(err), &ne) { - t.Fatalf("netError must return net.Error") - } - - if ne.Timeout() { - t.Errorf("%v must not be timeout error", err) - } - if !ne.Temporary() { //nolint:staticcheck - t.Errorf("%v must be temporary error", err) - } + assert.ErrorAs(t, netError(err), &ne) + assert.False(t, ne.Timeout()) + assert.True(t, ne.Temporary()) //nolint:staticcheck } diff --git a/errors_test.go b/errors_test.go index db3bffc59..47f21a345 100644 --- a/errors_test.go +++ b/errors_test.go @@ -8,6 +8,8 @@ import ( "fmt" "net" "testing" + + "github.com/stretchr/testify/assert" ) var errExample = errors.New("an example error") @@ -43,10 +45,7 @@ func TestErrorUnwrap(t *testing.T) { t.Run(fmt.Sprintf("%T", c.err), func(t *testing.T) { err := c.err for _, unwrapped := range c.errUnwrapped { - e := errors.Unwrap(err) - if !errors.Is(e, unwrapped) { - t.Errorf("Unwrapped error is expected to be '%v', got '%v'", unwrapped, e) - } + assert.ErrorIs(t, errors.Unwrap(err), unwrapped) } }) } @@ -69,18 +68,10 @@ func TestErrorNetError(t *testing.T) { testCase := testCase t.Run(fmt.Sprintf("%T", testCase.err), func(t *testing.T) { var ne net.Error - if !errors.As(testCase.err, &ne) { - t.Fatalf("%T doesn't implement net.Error", testCase.err) - } - if ne.Timeout() != testCase.timeout { - t.Errorf("%T.Timeout() should be %v", testCase.err, testCase.timeout) - } - if ne.Temporary() != testCase.temporary { //nolint:staticcheck - t.Errorf("%T.Temporary() should be %v", testCase.err, testCase.temporary) - } - if ne.Error() != testCase.str { - t.Errorf("%T.Error() should be %v", testCase.err, testCase.str) - } + assert.ErrorAs(t, testCase.err, &ne) + assert.Equal(t, testCase.timeout, ne.Timeout()) + assert.Equal(t, testCase.temporary, ne.Temporary()) //nolint:staticcheck + assert.Equal(t, testCase.str, ne.Error()) }) } } diff --git a/flight1handler_test.go b/flight1handler_test.go index 457ee413b..edc49d917 100644 --- a/flight1handler_test.go +++ b/flight1handler_test.go @@ -13,6 +13,7 @@ import ( "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/logging" "github.com/pion/transport/v3/test" + "github.com/stretchr/testify/assert" ) type flight1TestMockFlightConn struct{} @@ -33,7 +34,7 @@ type flight1TestMockCipherSuite struct { } func (f *flight1TestMockCipherSuite) IsInitialized() bool { - f.t.Fatal("IsInitialized called with Certificate but not CertificateVerify") + assert.Fail(f.t, "IsInitialized called with Certificate but not CertificateVerify") return true } @@ -269,17 +270,13 @@ func TestFlight1_Process_ServerHelloLateArrival(t *testing.T) { //nolint:maintid cache.push(certificateRequest, 0, 4, handshake.TypeCertificateRequest, false) cache.push(serverHelloDone, 0, 5, handshake.TypeServerHelloDone, false) - if _, alt, err := flight1Parse(context.TODO(), mockConn, state, cache, cfg); err != nil { - t.Fatal(err) - } else if alt != nil { - t.Fatal(alt.String()) - } + _, alt, err := flight1Parse(context.TODO(), mockConn, state, cache, cfg) + assert.NoError(t, err) + assert.Nil(t, alt) cache.push(serverHello, 0, 0, handshake.TypeServerHello, false) cache.push(certificate1, 0, 1, handshake.TypeCertificate, false) - if _, alt, err := flight1Parse(context.TODO(), mockConn, state, cache, cfg); err != nil { - t.Fatal(err) - } else if alt != nil { - t.Fatal(alt.String()) - } + _, alt, err = flight1Parse(context.TODO(), mockConn, state, cache, cfg) + assert.NoError(t, err) + assert.Nil(t, alt) } diff --git a/flight3handler_test.go b/flight3handler_test.go index af7374d6d..44328831e 100644 --- a/flight3handler_test.go +++ b/flight3handler_test.go @@ -16,10 +16,11 @@ import ( "github.com/pion/dtls/v3/pkg/protocol/recordlayer" "github.com/pion/transport/v3/dpipe" "github.com/pion/transport/v3/test" + "github.com/stretchr/testify/assert" ) // Assert that SupportedEllipticCurves is only sent when a ECC CipherSuite is available. -func TestSupportedEllipticCurves(t *testing.T) { //nolint:cyclop +func TestSupportedEllipticCurves(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -43,9 +44,7 @@ func TestSupportedEllipticCurves(t *testing.T) { //nolint:cyclop caAnalyzer := &connWithCallback{Conn: ca} caAnalyzer.onWrite = func(in []byte) { messages, err := recordlayer.UnpackDatagram(in) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) for i := range messages { h := &handshake.Handshake{} @@ -55,11 +54,8 @@ func TestSupportedEllipticCurves(t *testing.T) { //nolint:cyclop clientHello := &handshake.MessageClientHello{} msg, err := h.Message.Marshal() - if err != nil { - t.Fatal(err) - } else if err = clientHello.Unmarshal(msg); err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + assert.NoError(t, clientHello.Unmarshal(msg)) for _, e := range clientHello.Extensions { if e.TypeValue() == extension.SupportedEllipticCurvesTypeValue { @@ -87,7 +83,7 @@ func TestSupportedEllipticCurves(t *testing.T) { //nolint:cyclop ); err != nil { clientErr <- err } else { - clientErr <- client.Close() //nolint + clientErr <- client.Close() // nolint:errcheck,contextcheck } }() @@ -95,21 +91,12 @@ func TestSupportedEllipticCurves(t *testing.T) { //nolint:cyclop CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, } - if server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true); err != nil { - t.Fatalf("Server error %v", err) - } else { - if err = server.Close(); err != nil { - t.Fatal(err) - } - } - - if err := <-clientErr; err != nil { - t.Fatalf("Client error %v", err) - } + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true) + assert.NoError(t, err) + assert.NoError(t, server.Close()) + assert.NoError(t, <-clientErr) for i := range expectedCurves { - if expectedCurves[i] != actualCurves[i] { - t.Fatal("List of curves in SupportedEllipticCurves does not match config") - } + assert.Equal(t, expectedCurves[i], actualCurves[i], "curves in SupportedEllipticCurves mismatch") } } diff --git a/flight4handler_test.go b/flight4handler_test.go index 458292b69..46e7bb1d9 100644 --- a/flight4handler_test.go +++ b/flight4handler_test.go @@ -6,7 +6,6 @@ package dtls import ( "context" "crypto/tls" - "errors" "testing" "time" @@ -17,12 +16,11 @@ import ( "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/transport/v3/test" + "github.com/stretchr/testify/assert" ) type flight4TestMockFlightConn struct{} -var errHookCertReqFailed = errors.New("hook failed to modify SignatureHashAlgorithms") - func (f *flight4TestMockFlightConn) notify(context.Context, alert.Level, alert.Description) error { return nil } @@ -39,7 +37,7 @@ type flight4TestMockCipherSuite struct { } func (f *flight4TestMockCipherSuite) IsInitialized() bool { - f.t.Fatal("IsInitialized called with Certificate but not CertificateVerify") + assert.Fail(f.t, "IsInitialized called with Certificate but not CertificateVerify") return true } @@ -121,9 +119,8 @@ func TestFlight4_Process_CertificateVerify(t *testing.T) { cache.push(rawCertificate, 0, 0, handshake.TypeCertificate, true) cache.push(rawClientKeyExchange, 0, 1, handshake.TypeClientKeyExchange, true) - if _, _, err := flight4Parse(context.TODO(), mockConn, state, cache, cfg); err != nil { - t.Fatal(err) - } + _, _, err := flight4Parse(context.TODO(), mockConn, state, cache, cfg) + assert.NoError(t, err) } func TestFlight4_CertificateRequestHook(t *testing.T) { @@ -136,9 +133,7 @@ func TestFlight4_CertificateRequestHook(t *testing.T) { defer report() localKeypair, err := elliptic.GenerateKeypair(elliptic.P256) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) mockConn := &flight4TestMockFlightConn{} state := &State{ @@ -147,9 +142,7 @@ func TestFlight4_CertificateRequestHook(t *testing.T) { } cert, err := selfsign.GenerateSelfSignedWithDNS("localhost") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) cfg := &handshakeConfig{ localCertificates: []tls.Certificate{cert}, @@ -163,27 +156,20 @@ func TestFlight4_CertificateRequestHook(t *testing.T) { } pkts, _, err := flight4Generate(mockConn, state, nil, cfg) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) for _, p := range pkts { if h, ok := p.record.Content.(*handshake.Handshake); ok { //nolint:nestif if h.Message.Type() == handshake.TypeCertificateRequest { mcr := &handshake.MessageCertificateRequest{} msg, err := h.Message.Marshal() - if err != nil { - t.Fatal(err) - } - err = mcr.Unmarshal(msg) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + assert.NoError(t, mcr.Unmarshal(msg)) if len(mcr.SignatureHashAlgorithms) == 0 { return } } } } - t.Fatal(errHookCertReqFailed) + assert.Fail(t, "hook failed to modify SignatureHashAlgorithms") } diff --git a/fragment_buffer_test.go b/fragment_buffer_test.go index 9e842b0a2..06c57181f 100644 --- a/fragment_buffer_test.go +++ b/fragment_buffer_test.go @@ -4,9 +4,9 @@ package dtls import ( - "errors" - "reflect" "testing" + + "github.com/stretchr/testify/assert" ) func TestFragmentBuffer(t *testing.T) { @@ -125,26 +125,18 @@ func TestFragmentBuffer(t *testing.T) { fragmentBuffer := newFragmentBuffer() for _, frag := range test.In { status, _, err := fragmentBuffer.push(frag) - if err != nil { - t.Error(err) - } else if !status { - t.Errorf("fragmentBuffer didn't accept fragments for '%s'", test.Name) - } + assert.NoError(t, err) + assert.Truef(t, status, "fragmentBuffer didn't accept fragments for '%s'", test.Name) } for _, expected := range test.Expected { out, epoch := fragmentBuffer.pop() - if !reflect.DeepEqual(out, expected) { - t.Errorf("fragmentBuffer '%s' push/pop: got % 02x, want % 02x", test.Name, out, expected) - } - if epoch != test.Epoch { - t.Errorf("fragmentBuffer returned wrong epoch: got %d, want %d", epoch, test.Epoch) - } + assert.Equalf(t, expected, out, "fragmentBuffer '%s' pop should return expected output", test.Name) + assert.Equalf(t, test.Epoch, epoch, "fragmentBuffer returend wrong epoch") } - if frag, _ := fragmentBuffer.pop(); frag != nil { - t.Errorf("fragmentBuffer popped single buffer multiple times for '%s'", test.Name) - } + frag, _ := fragmentBuffer.pop() + assert.Nilf(t, frag, "fragmentBuffer '%s' pop should return nil when no more fragments are available", test.Name) } } @@ -152,16 +144,14 @@ func TestFragmentBuffer_Overflow(t *testing.T) { fragmentBuffer := newFragmentBuffer() // Push a buffer that doesn't exceed size limits - if _, _, err := fragmentBuffer.push([]byte{ + _, _, err := fragmentBuffer.push([]byte{ 0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00, - }); err != nil { - t.Fatal(err) - } + }) + assert.NoError(t, err) // Allocate a buffer that exceeds cache size largeBuffer := make([]byte, fragmentBufferMaxSize) - if _, _, err := fragmentBuffer.push(largeBuffer); !errors.Is(err, errFragmentBufferOverflow) { - t.Fatalf("Pushing a large buffer returned (%s) expected(%s)", err, errFragmentBufferOverflow) - } + _, _, err = fragmentBuffer.push(largeBuffer) + assert.ErrorIs(t, err, errFragmentBufferOverflow, "Pushing a large buffer should return an overflow error") } diff --git a/go.mod b/go.mod index 0d974fc9d..fe1e3261c 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,15 @@ module github.com/pion/dtls/v3 require ( github.com/pion/logging v0.2.3 github.com/pion/transport/v3 v3.0.7 + github.com/stretchr/testify v1.10.0 golang.org/x/crypto v0.32.0 golang.org/x/net v0.34.0 ) +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + go 1.20 diff --git a/go.sum b/go.sum index 51c8f7b21..5a56bef6b 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,18 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pion/logging v0.2.3 h1:gHuf0zpoh1GW67Nr6Gj4cv5Z9ZscU7g/EaoC/Ke/igI= github.com/pion/logging v0.2.3/go.mod h1:z8YfknkquMe1csOrxK5kc+5/ZPAzMxbKLX5aXpbpC90= github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0= github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= -golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= \ No newline at end of file +golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/handshake_cache_test.go b/handshake_cache_test.go index b655ac166..2a9e1ce08 100644 --- a/handshake_cache_test.go +++ b/handshake_cache_test.go @@ -4,11 +4,11 @@ package dtls import ( - "bytes" "testing" "github.com/pion/dtls/v3/internal/ciphersuite" "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/stretchr/testify/assert" ) func TestHandshakeCacheSinglePush(t *testing.T) { @@ -121,9 +121,7 @@ func TestHandshakeCacheSinglePush(t *testing.T) { h.push(i.data, i.epoch, i.messageSequence, i.typ, i.isClient) } verifyData := h.pullAndMerge(test.Rule...) - if !bytes.Equal(verifyData, test.Expected) { - t.Errorf("handshakeCache '%s' exp: % 02x actual % 02x", test.Name, test.Expected, verifyData) - } + assert.Equal(t, test.Expected, verifyData) } } @@ -215,11 +213,7 @@ func TestHandshakeCacheSessionHash(t *testing.T) { cipherSuite := ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{} verifyData, err := h.sessionHash(cipherSuite.HashFunc(), 0) - if err != nil { - t.Error(err) - } - if !bytes.Equal(verifyData, test.Expected) { - t.Errorf("handshakeCacheSesssionHassh '%s' exp: % 02x actual % 02x", test.Name, test.Expected, verifyData) - } + assert.NoError(t, err) + assert.Equal(t, test.Expected, verifyData, "handshakeCacheSessionHash") } } diff --git a/handshake_test.go b/handshake_test.go index 8c97d20b2..fe4936d02 100644 --- a/handshake_test.go +++ b/handshake_test.go @@ -4,13 +4,13 @@ package dtls import ( - "reflect" "testing" "time" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/extension" "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/stretchr/testify/assert" ) func TestHandshakeMessage(t *testing.T) { @@ -44,16 +44,10 @@ func TestHandshakeMessage(t *testing.T) { } h := &handshake.Handshake{} - if err := h.Unmarshal(rawHandshakeMessage); err != nil { - t.Error(err) - } else if !reflect.DeepEqual(h, parsedHandshake) { - t.Errorf("handshakeMessageClientHello unmarshal: got %#v, want %#v", h, parsedHandshake) - } + assert.NoError(t, h.Unmarshal(rawHandshakeMessage)) + assert.Equal(t, parsedHandshake, h, "handshakeMessageClientHello unmarshal") raw, err := h.Marshal() - if err != nil { - t.Error(err) - } else if !reflect.DeepEqual(raw, rawHandshakeMessage) { - t.Errorf("handshakeMessageClientHello marshal: got %#v, want %#v", raw, rawHandshakeMessage) - } + assert.NoError(t, err) + assert.Equal(t, rawHandshakeMessage, raw, "handshakeMessageClientHello marshal") } diff --git a/handshaker_test.go b/handshaker_test.go index 88e69ee0c..19215638a 100644 --- a/handshaker_test.go +++ b/handshaker_test.go @@ -19,6 +19,7 @@ import ( "github.com/pion/dtls/v3/pkg/protocol/recordlayer" "github.com/pion/logging" "github.com/pion/transport/v3/test" + "github.com/stretchr/testify/assert" ) const nonZeroRetransmitInterval = 100 * time.Millisecond @@ -35,9 +36,7 @@ func TestWriteKeyLog(t *testing.T) { // Secrets follow the format