diff --git a/pkg/protocol/ack.go b/pkg/protocol/ack.go new file mode 100644 index 00000000..e85ff263 --- /dev/null +++ b/pkg/protocol/ack.go @@ -0,0 +1,65 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package protocol + +import ( + "golang.org/x/crypto/cryptobyte" +) + +// ACK is the DTLS 1.3 content type used to acknowledge receipt of +// handshake records. +// +// https://datatracker.ietf.org/doc/html/rfc9147#section-7 +type ACK struct { + // Records is the list of RecordNumbers being acknowledged. + Records []RecordNumber +} + +// RecordNumber identifies a specific DTLS record by its epoch and sequence number. +// The 128-bit value matches the unpacked RecordNumber structure from RFC 9147 Section 4.2. +type RecordNumber struct { + Epoch uint64 + SequenceNumber uint64 +} + +// ContentType returns the content type for ACK records (26). +func (a ACK) ContentType() ContentType { + return ContentTypeACK +} + +// Marshal encodes the ACK message to its wire format. +func (a *ACK) Marshal() ([]byte, error) { + var out cryptobyte.Builder + + out.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, rec := range a.Records { + b.AddUint64(rec.Epoch) + b.AddUint64(rec.SequenceNumber) + } + }) + + return out.Bytes() +} + +// Unmarshal decodes an ACK message from its wire format. +func (a *ACK) Unmarshal(data []byte) error { + val := cryptobyte.String(data) + + var recordList cryptobyte.String + if !val.ReadUint16LengthPrefixed(&recordList) || !val.Empty() { + return errLengthMismatch + } + + a.Records = make([]RecordNumber, 0) + + for !recordList.Empty() { + var rec RecordNumber + if !recordList.ReadUint64(&rec.Epoch) || !recordList.ReadUint64(&rec.SequenceNumber) { + return errInvalidACK + } + a.Records = append(a.Records, rec) + } + + return nil +} diff --git a/pkg/protocol/ack_test.go b/pkg/protocol/ack_test.go new file mode 100644 index 00000000..1807557c --- /dev/null +++ b/pkg/protocol/ack_test.go @@ -0,0 +1,111 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package protocol + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestACK(t *testing.T) { + ack := ACK{ + Records: []RecordNumber{ + {Epoch: 1, SequenceNumber: 42}, + }, + } + + raw, err := ack.Marshal() + assert.NoError(t, err) + + expect := []byte{ + 0x00, 0x10, // record list length (1 record × 16 bytes) + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, // epoch = 1 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2a, // sequence_number = 42 + } + assert.Equal(t, expect, raw) + + newACK := ACK{} + assert.NoError(t, newACK.Unmarshal(expect)) + assert.Len(t, newACK.Records, 1) + assert.Equal(t, uint64(1), newACK.Records[0].Epoch) + assert.Equal(t, uint64(42), newACK.Records[0].SequenceNumber) +} + +func TestACK_MultipleRecords(t *testing.T) { + ack := ACK{ + Records: []RecordNumber{ + {Epoch: 1, SequenceNumber: 1}, + {Epoch: 1, SequenceNumber: 2}, + {Epoch: 2, SequenceNumber: 0}, + }, + } + + raw, err := ack.Marshal() + assert.NoError(t, err) + + expect := []byte{ + 0x00, 0x30, // record list length (3 × 16 bytes) + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, // epoch = 1 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, // sequence_number = 1 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, // epoch = 1 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, // sequence_number = 2 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, // epoch = 2 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // sequence_number = 0 + } + assert.Equal(t, expect, raw) + + newACK := ACK{} + assert.NoError(t, newACK.Unmarshal(expect)) + assert.Len(t, newACK.Records, 3) + assert.Equal(t, uint64(1), newACK.Records[0].Epoch) + assert.Equal(t, uint64(1), newACK.Records[0].SequenceNumber) + assert.Equal(t, uint64(1), newACK.Records[1].Epoch) + assert.Equal(t, uint64(2), newACK.Records[1].SequenceNumber) + assert.Equal(t, uint64(2), newACK.Records[2].Epoch) + assert.Equal(t, uint64(0), newACK.Records[2].SequenceNumber) +} + +func TestACK_EmptyRecords(t *testing.T) { + ack := ACK{Records: []RecordNumber{}} + + raw, err := ack.Marshal() + assert.NoError(t, err) + + expect := []byte{ + 0x00, 0x00, // record list length (empty) + } + assert.Equal(t, expect, raw) + + newACK := ACK{} + assert.NoError(t, newACK.Unmarshal(expect)) + assert.Empty(t, newACK.Records) +} + +func TestACK_UnmarshalTruncatedRecord(t *testing.T) { + // Length prefix claims 16 bytes but only 7 are present. + raw := []byte{ + 0x00, 0x10, // record list length = 16 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // only 7 bytes of epoch + } + newACK := ACK{} + assert.ErrorIs(t, newACK.Unmarshal(raw), errLengthMismatch) +} + +func TestACK_UnmarshalTrailingData(t *testing.T) { + // Valid record list followed by unexpected trailing bytes. + raw := []byte{ + 0x00, 0x10, // record list length = 16 (one record) + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, // epoch = 1 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, // sequence_number = 1 + 0xde, 0xad, // trailing garbage + } + newACK := ACK{} + assert.ErrorIs(t, newACK.Unmarshal(raw), errLengthMismatch) +} + +func TestACK_UnmarshalEmpty(t *testing.T) { + newACK := ACK{} + assert.NoError(t, newACK.Unmarshal([]byte{0x00, 0x00})) +} diff --git a/pkg/protocol/content.go b/pkg/protocol/content.go index 58bbdc5b..fd16d44a 100644 --- a/pkg/protocol/content.go +++ b/pkg/protocol/content.go @@ -15,6 +15,7 @@ const ( ContentTypeHandshake ContentType = 22 ContentTypeApplicationData ContentType = 23 ContentTypeConnectionID ContentType = 25 + ContentTypeACK ContentType = 26 ) // Content is the top level distinguisher for a DTLS Datagram. diff --git a/pkg/protocol/errors.go b/pkg/protocol/errors.go index 6e69a750..1e344bef 100644 --- a/pkg/protocol/errors.go +++ b/pkg/protocol/errors.go @@ -12,6 +12,10 @@ import ( var ( errBufferTooSmall = &TemporaryError{Err: errors.New("buffer is too small")} //nolint:err113 errInvalidCipherSpec = &FatalError{Err: errors.New("cipher spec invalid")} //nolint:err113 + errInvalidACK = &FatalError{Err: errors.New("ack invalid")} //nolint:err113 + errLengthMismatch = &InternalError{ + Err: errors.New("data length and declared length do not match"), //nolint:err113 + } ) // FatalError indicates that the DTLS connection is no longer available.