Skip to content

Commit 1db35da

Browse files
authored
Add ResetSetElements method (#345)
ResetSetElements resets the stateful expressions (e.g., counters) of all elements in the specified set. The reset is applied immediately (no Flush is required) and the returned elements reflect their state prior to the reset. This update also adds a reset flag to the internal getSetElements method, along with a safeguard for situations where the user provides no valid set identifier.
1 parent 4f65ea4 commit 1db35da

File tree

2 files changed

+92
-6
lines changed

2 files changed

+92
-6
lines changed

nftables_test.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8491,3 +8491,73 @@ func TestSetElementCounter(t *testing.T) {
84918491
t.Fatalf("got counter %v, want %v", gotCounter, counter)
84928492
}
84938493
}
8494+
8495+
func TestResetSetElements(t *testing.T) {
8496+
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
8497+
defer nftest.CleanupSystemConn(t, newNS)
8498+
defer conn.FlushRuleset()
8499+
8500+
table := conn.AddTable(&nftables.Table{
8501+
Name: "test-table",
8502+
Family: nftables.TableFamilyIPv4,
8503+
})
8504+
set := &nftables.Set{
8505+
Name: "test-set",
8506+
Table: table,
8507+
KeyType: nftables.TypeInetService,
8508+
}
8509+
elements := []nftables.SetElement{
8510+
{
8511+
Key: binaryutil.BigEndian.PutUint16(80),
8512+
Counter: &expr.Counter{
8513+
Bytes: 1024,
8514+
Packets: 10,
8515+
},
8516+
},
8517+
{
8518+
Key: binaryutil.BigEndian.PutUint16(443),
8519+
Counter: &expr.Counter{
8520+
Bytes: 2048,
8521+
Packets: 20,
8522+
},
8523+
},
8524+
}
8525+
if err := conn.AddSet(set, elements); err != nil {
8526+
t.Fatalf("failed to add set: %v", err)
8527+
}
8528+
if err := conn.Flush(); err != nil {
8529+
t.Fatalf("failed to flush: %v", err)
8530+
}
8531+
8532+
got, err := conn.ResetSetElements(set)
8533+
if err != nil {
8534+
t.Fatalf("failed to reset set elements: %v", err)
8535+
}
8536+
if len(got) != len(elements) {
8537+
t.Fatalf("got %d elements, want %d", len(got), len(elements))
8538+
}
8539+
8540+
for i, ge := range got {
8541+
want := elements[i]
8542+
if !bytes.Equal(ge.Key, want.Key) {
8543+
t.Errorf("element %d: got key %v, want %v", i, ge.Key, want.Key)
8544+
}
8545+
if !reflect.DeepEqual(ge.Counter, want.Counter) {
8546+
t.Errorf("element %d: got counter %v, want %v", i, ge.Counter, want.Counter)
8547+
}
8548+
}
8549+
8550+
resetEls, err := conn.GetSetElements(set)
8551+
if err != nil {
8552+
t.Fatalf("failed to get set elements after reset: %v", err)
8553+
}
8554+
for i, re := range resetEls {
8555+
want := elements[i]
8556+
if !bytes.Equal(re.Key, want.Key) {
8557+
t.Errorf("element %d: got key %v, want %v", i, re.Key, want.Key)
8558+
}
8559+
if re.Counter.Bytes != 0 || re.Counter.Packets != 0 {
8560+
t.Errorf("element %d: got counter %v, want zeroed counter", i, re.Counter)
8561+
}
8562+
}
8563+
}

set.go

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,7 +1043,9 @@ func (cc *Conn) GetSetByName(t *Table, name string) (*Set, error) {
10431043
// getSetElements retrieves elements from a set.
10441044
// If e is empty, all elements are retrieved. Otherwise, only the specified
10451045
// elements are retrieved if they exist in the set.
1046-
func (cc *Conn) getSetElements(s *Set, e []SetElement) ([]SetElement, error) {
1046+
// If reset is true, the stateful expressions (e.g., counters) of the elements
1047+
// being retrieved are reset.
1048+
func (cc *Conn) getSetElements(s *Set, e []SetElement, reset bool) ([]SetElement, error) {
10471049
conn, closer, err := cc.netlinkConn()
10481050
if err != nil {
10491051
return nil, err
@@ -1053,13 +1055,14 @@ func (cc *Conn) getSetElements(s *Set, e []SetElement) ([]SetElement, error) {
10531055
flags := netlink.Request
10541056
attrs := []netlink.Attribute{
10551057
{Type: unix.NFTA_SET_ELEM_LIST_TABLE, Data: []byte(s.Table.Name + "\x00")},
1056-
{Type: unix.NFTA_SET_ELEM_LIST_SET, Data: []byte(s.Name + "\x00")},
10571058
}
10581059

10591060
if s.Name != "" {
10601061
attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_SET_ELEM_LIST_SET, Data: []byte(s.Name + "\x00")})
1061-
} else {
1062+
} else if s.ID > 0 {
10621063
attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_SET_ELEM_LIST_SET_ID, Data: binaryutil.BigEndian.PutUint32(s.ID)})
1064+
} else {
1065+
return nil, fmt.Errorf("set must either have a valid name or ID")
10631066
}
10641067

10651068
if len(e) > 0 {
@@ -1086,9 +1089,14 @@ func (cc *Conn) getSetElements(s *Set, e []SetElement) ([]SetElement, error) {
10861089
return nil, err
10871090
}
10881091

1092+
msgType := nftMsgGetSetElem
1093+
if reset {
1094+
msgType = nftMsgGetSetElemReset
1095+
}
1096+
10891097
message := netlink.Message{
10901098
Header: netlink.Header{
1091-
Type: nftMsgGetSetElem.HeaderType(),
1099+
Type: msgType.HeaderType(),
10921100
Flags: flags,
10931101
},
10941102
Data: append(extraHeader(uint8(s.Table.Family), 0), data...),
@@ -1116,10 +1124,18 @@ func (cc *Conn) getSetElements(s *Set, e []SetElement) ([]SetElement, error) {
11161124

11171125
// GetSetElements returns the elements in the specified set.
11181126
func (cc *Conn) GetSetElements(s *Set) ([]SetElement, error) {
1119-
return cc.getSetElements(s, nil)
1127+
return cc.getSetElements(s, nil, false)
11201128
}
11211129

11221130
// FindSetElements returns the specified elements in the set.
11231131
func (cc *Conn) FindSetElements(s *Set, e []SetElement) ([]SetElement, error) {
1124-
return cc.getSetElements(s, e)
1132+
return cc.getSetElements(s, e, false)
1133+
}
1134+
1135+
// ResetSetElements resets the stateful expressions (e.g., counters) of all
1136+
// elements in the specified set. The reset is applied immediately
1137+
// (no Flush is required). The returned elements reflect their state prior to
1138+
// the reset. Requires a kernel version >= 6.3.
1139+
func (cc *Conn) ResetSetElements(s *Set) ([]SetElement, error) {
1140+
return cc.getSetElements(s, nil, true)
11251141
}

0 commit comments

Comments
 (0)