diff --git a/filter.go b/filter.go index bc8be51..087c1aa 100644 --- a/filter.go +++ b/filter.go @@ -1,17 +1,20 @@ package conntrack import ( + "maps" + "github.com/ti-mo/netfilter" ) // Filter is an object used to limit dump and flush operations to flows matching // certain fields. Use [NewFilter] to create a new filter, then chain methods to -// set filter fields. Methods mutate the Filter in place and return it for -// chaining purposes. +// set filter fields. +// +// Methods return a new Filter with the specified field set. // // Pass a filter to [Conn.DumpFilter] or [Conn.FlushFilter]. type Filter interface { - // Family sets the address (L3) family to filter on, similar to conntrack's + // Family sets the address (L3) family to filter on, similar to conntrack's. // -f/--family. // // Common values are [netfilter.ProtoIPv4] and [netfilter.ProtoIPv6]. @@ -74,8 +77,9 @@ type filter struct { } func (f *filter) Family(l3 netfilter.ProtoFamily) Filter { - f.l3 = l3 - return f + return f.withClone(func(cpy *filter) { + cpy.l3 = l3 + }) } func (f *filter) family() netfilter.ProtoFamily { @@ -83,28 +87,40 @@ func (f *filter) family() netfilter.ProtoFamily { } func (f *filter) Mark(mark uint32) Filter { - f.f[ctaMark] = netfilter.Uint32Bytes(mark) - return f + return f.withClone(func(cpy *filter) { + cpy.f[ctaMark] = netfilter.Uint32Bytes(mark) + }) } func (f *filter) MarkMask(mask uint32) Filter { - f.f[ctaMarkMask] = netfilter.Uint32Bytes(mask) - return f + return f.withClone(func(cpy *filter) { + cpy.f[ctaMarkMask] = netfilter.Uint32Bytes(mask) + }) } func (f *filter) Status(status Status) Filter { - f.f[ctaStatus] = netfilter.Uint32Bytes(uint32(status)) - return f + return f.withClone(func(cpy *filter) { + cpy.f[ctaStatus] = netfilter.Uint32Bytes(uint32(status)) + }) } func (f *filter) StatusMask(mask uint32) Filter { - f.f[ctaStatusMask] = netfilter.Uint32Bytes(mask) - return f + return f.withClone(func(cpy *filter) { + cpy.f[ctaStatusMask] = netfilter.Uint32Bytes(mask) + }) } func (f *filter) Zone(zone uint16) Filter { - f.f[ctaZone] = netfilter.Uint16Bytes(zone) - return f + return f.withClone(func(cpy *filter) { + cpy.f[ctaZone] = netfilter.Uint16Bytes(zone) + }) +} + +func (f *filter) withClone(fn func(cpy *filter)) *filter { + clone := *f + clone.f = maps.Clone(f.f) + fn(&clone) + return &clone } func (f *filter) marshal() []netfilter.Attribute { diff --git a/filter_test.go b/filter_test.go index b356304..972d6ef 100644 --- a/filter_test.go +++ b/filter_test.go @@ -45,3 +45,21 @@ func TestFilterMarshal(t *testing.T) { assert.Equal(t, want, got) } + +func TestFilterMutate(t *testing.T) { + f := NewFilter(). + Mark(1). + Family(1) + + mod := f. + Mark(2). + Family(2) + + // Ensure original filter is unchanged. + assert.NotEqual(t, f, mod) + assert.Equal(t, []byte{0, 0, 0, 1}, f.(*filter).f[ctaMark]) + assert.Equal(t, netfilter.ProtoFamily(1), f.(*filter).l3) + + assert.Equal(t, []byte{0, 0, 0, 2}, mod.(*filter).f[ctaMark]) + assert.Equal(t, netfilter.ProtoFamily(2), mod.(*filter).l3) +}