Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions pkg/protocol/extension/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,13 @@ var (
errInvalidCertificateAuthFormat = &protocol.FatalError{
Err: errors.New("invalid Certificate Authorities extension format"), //nolint:err113
}
errEmptyOIDFilter = &protocol.InternalError{
Err: errors.New("no oid set for a OID filter"), //nolint:err113
}
errOIDFiltersFormat = &protocol.FatalError{
Err: errors.New("invalid OID filters extension format"), //nolint:err113
}
errDuplicateOID = &protocol.FatalError{
Err: errors.New("duplicate OID filters"), //nolint:err113
}
)
1 change: 1 addition & 0 deletions pkg/protocol/extension/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ const (
CookieTypeValue TypeValue = 44
PskKeyExchangeModesTypeValue TypeValue = 45
CertificateAuthoritiesTypeValue TypeValue = 47
OIDFiltersTypeValue TypeValue = 48
PostHandshakeAuthTypeValue TypeValue = 49
SignatureAlgorithmsCertTypeValue TypeValue = 50
KeyShareTypeValue TypeValue = 51
Expand Down
105 changes: 105 additions & 0 deletions pkg/protocol/extension/oid_filters.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// SPDX-FileCopyrightText: 2026 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package extension

import "golang.org/x/crypto/cryptobyte"

// OIDFilters defines a DTLS 1.3 extension that is used to allow server to
// provide a set of OID/value pairs which it would like the client's
// certificate to match.
//
// https://datatracker.ietf.org/doc/html/rfc8446#section-4.2.5
type OIDFilters struct {
Filters []OIDFilter
}

type OIDFilter struct {
OID []byte
Values []byte
}

// TypeValue returns the extension TypeValue.
func (o OIDFilters) TypeValue() TypeValue {
return OIDFiltersTypeValue
}

// Marshal encodes the extension.
func (o *OIDFilters) Marshal() ([]byte, error) {
var out cryptobyte.Builder
out.AddUint16(uint16(o.TypeValue()))

out.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16LengthPrefixed(func(builder *cryptobyte.Builder) {
seen := map[string]struct{}{}
for _, filter := range o.Filters {
if len(filter.OID) < 1 {
builder.SetError(errEmptyOIDFilter)
}
if _, ok := seen[string(filter.OID)]; ok {
builder.SetError(errDuplicateOID)
}
seen[string(filter.OID)] = struct{}{}
builder.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(filter.OID)
})
builder.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
if filter.Values != nil {
b.AddBytes(filter.Values)
}
})
}
})
})

return out.Bytes()
}

// Unmarshal populates the extension from encoded data.
func (o *OIDFilters) Unmarshal(data []byte) error { //nolint:cyclop
val := cryptobyte.String(data)

var extension uint16
if !val.ReadUint16(&extension) || TypeValue(extension) != o.TypeValue() {
return errInvalidExtensionType
}

var extData cryptobyte.String
if !val.ReadUint16LengthPrefixed(&extData) {
return errBufferTooSmall
}

var filterList cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&filterList) || !extData.Empty() {
return errLengthMismatch
}

seen := map[string]struct{}{}

for !filterList.Empty() {
var filter OIDFilter

var oid cryptobyte.String
if !filterList.ReadUint8LengthPrefixed(&oid) || oid.Empty() {
return errOIDFiltersFormat
}
if _, ok := seen[string(oid)]; ok {
return errDuplicateOID
}
seen[string(oid)] = struct{}{}

filter.OID = make([]byte, len(oid))
copy(filter.OID, oid)

var values cryptobyte.String
if !filterList.ReadUint16LengthPrefixed(&values) {
return errOIDFiltersFormat
}
filter.Values = make([]byte, len(values))
copy(filter.Values, values)

o.Filters = append(o.Filters, filter)
}

return nil
}
211 changes: 211 additions & 0 deletions pkg/protocol/extension/oid_filters_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
// SPDX-FileCopyrightText: 2026 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package extension

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestOIDFilters(t *testing.T) {
oid := []byte{0x55, 0x04, 0x03}
values := []byte{0xde, 0xad, 0xbe, 0xef}
filter := OIDFilter{OID: oid, Values: values}
extension := OIDFilters{Filters: []OIDFilter{filter}}

raw, err := extension.Marshal()
assert.NoError(t, err)

expect := []byte{
0x00, 0x30, // extension type (48)
0x00, 0x0c, // extension data length
0x00, 0x0a, // filter list length
0x03, // OID length
0x55, 0x04, 0x03, // OID bytes (id-at-commonName)
0x00, 0x04, // values length
0xde, 0xad, 0xbe, 0xef, // values bytes
}
assert.Equal(t, expect, raw)

newExtension := OIDFilters{}
assert.NoError(t, newExtension.Unmarshal(expect))
assert.Len(t, newExtension.Filters, 1)
assert.Equal(t, oid, newExtension.Filters[0].OID)
assert.Equal(t, values, newExtension.Filters[0].Values)
}

func TestOIDFilters_MultipleFilters(t *testing.T) {
oid1 := []byte{0x55, 0x04}
values1 := []byte{0xaa, 0xbb}
oid2 := []byte{0x55, 0x05}
values2 := []byte{0x01, 0x02, 0x03, 0x04}
extension := OIDFilters{Filters: []OIDFilter{
{OID: oid1, Values: values1},
{OID: oid2, Values: values2},
}}

raw, err := extension.Marshal()
assert.NoError(t, err)

expect := []byte{
0x00, 0x30, // extension type
0x00, 0x12, // extension data length
0x00, 0x10, // filter list length
0x02, // OID length
0x55, 0x04, // OID bytes
0x00, 0x02, // values length
0xaa, 0xbb, // values bytes
0x02, // OID length
0x55, 0x05, // OID bytes
0x00, 0x04, // values length
0x01, 0x02, 0x03, 0x04, // values bytes
}
assert.Equal(t, expect, raw)

newExtension := OIDFilters{}
assert.NoError(t, newExtension.Unmarshal(expect))
assert.Len(t, newExtension.Filters, 2)
assert.Equal(t, oid1, newExtension.Filters[0].OID)
assert.Equal(t, values1, newExtension.Filters[0].Values)
assert.Equal(t, oid2, newExtension.Filters[1].OID)
assert.Equal(t, values2, newExtension.Filters[1].Values)
}

func TestOIDFilters_DuplicateFilters(t *testing.T) {
oid := []byte{0x55, 0x04}
values1 := []byte{0xaa, 0xbb}
values2 := []byte{0xcc, 0xdd}
extension := OIDFilters{Filters: []OIDFilter{
{OID: oid, Values: values1},
{OID: oid, Values: values2},
}}

_, err := extension.Marshal()
assert.ErrorIs(t, err, errDuplicateOID)

raw := []byte{
0x00, 0x30, // extension type
0x00, 0x10, // extension data length
0x00, 0x0e, // filter list length
0x02, // OID length
0x55, 0x04, // OID bytes
0x00, 0x02, // values length
0xaa, 0xbb, // values bytes
0x02, // OID length
0x55, 0x04, // OID bytes
0x00, 0x02, // values length
0xcc, 0xdd, // values bytes
}

newExtension := OIDFilters{}
assert.ErrorIs(t, newExtension.Unmarshal(raw), errDuplicateOID)
}

func TestOIDFilters_EmptyValues(t *testing.T) {
oid := []byte{0x55, 0x04, 0x03}
extension := OIDFilters{Filters: []OIDFilter{
{OID: oid, Values: []byte{}},
}}

raw, err := extension.Marshal()
assert.NoError(t, err)

expect := []byte{
0x00, 0x30, // extension type
0x00, 0x08, // extension data length
0x00, 0x06, // filter list length
0x03, // OID length
0x55, 0x04, 0x03, // OID bytes
0x00, 0x00, // values length (empty)
}
assert.Equal(t, expect, raw)

newExtension := OIDFilters{}
assert.NoError(t, newExtension.Unmarshal(expect))
assert.Len(t, newExtension.Filters, 1)
assert.Equal(t, oid, newExtension.Filters[0].OID)
assert.Empty(t, newExtension.Filters[0].Values)
}

func TestOIDFilters_EmptyFilterList(t *testing.T) {
extension := OIDFilters{Filters: []OIDFilter{}}

raw, err := extension.Marshal()
assert.NoError(t, err)

expect := []byte{
0x00, 0x30, // extension type
0x00, 0x02, // extension data length
0x00, 0x00, // filter list length (empty)
}
assert.Equal(t, expect, raw)

newExtension := OIDFilters{}
assert.NoError(t, newExtension.Unmarshal(expect))
assert.Empty(t, newExtension.Filters)
}

func TestOIDFilters_EmptyOID(t *testing.T) {
raw := []byte{
0x00, 0x30, // extension type
0x00, 0x04, // extension data length
0x00, 0x02, // filter list length
0x00, // OID length = 0 (invalid)
0x00, // start of values length
}
newExtension := OIDFilters{}
assert.ErrorIs(t, newExtension.Unmarshal(raw), errOIDFiltersFormat)
}

func TestOIDFilters_MarshalEmptyOID(t *testing.T) {
extension := OIDFilters{Filters: []OIDFilter{
{OID: []byte{}, Values: []byte{0x01}},
}}
_, err := extension.Marshal()
assert.ErrorIs(t, err, errEmptyOIDFilter)
}

func FuzzOIDFiltersUnmarshal(f *testing.F) {
f.Add([]byte{
0x00, 0x30,
0x00, 0x0c,
0x00, 0x0a,
0x03, 0x55, 0x04, 0x03,
0x00, 0x04, 0xde, 0xad, 0xbe, 0xef,
})
f.Add([]byte{
0x00, 0x30,
0x00, 0x02,
0x00, 0x00,
})
f.Add([]byte{
0x00, 0x30,
0x00, 0x04,
0x00, 0x02,
0x00, 0x00,
})
f.Add([]byte{0x00, 0x30})
f.Add([]byte{})

f.Fuzz(func(t *testing.T, a []byte) {
ext := OIDFilters{}
err := ext.Unmarshal(a)
if err == nil {
seen := map[string]struct{}{}
for _, filter := range ext.Filters {
assert.NotEmpty(t, filter.OID)
_, dup := seen[string(filter.OID)]
assert.False(t, dup)
seen[string(filter.OID)] = struct{}{}
}
// Check if extension is stable after parsing
marshaled, marshalErr := ext.Marshal()
assert.NoError(t, marshalErr)
ext2 := OIDFilters{}
assert.NoError(t, ext2.Unmarshal(marshaled))
assert.Equal(t, ext.Filters, ext2.Filters)
}
})
}
Loading