Skip to content

Commit 02e7d4f

Browse files
authored
Add NetInterval helper (#342)
When creating set elements that represent a network, the interval range must be half-open [start, end) rather than inclusive [start, end]. For example, for 10.0.0.0/24, the expected range is 10.0.0.0 to 10.0.1.0 instead of 10.0.0.0 to 10.0.0.255. This change introduces a NetInterval helper that returns the correct range given a CIDR string.
1 parent d8090e2 commit 02e7d4f

File tree

2 files changed

+164
-0
lines changed

2 files changed

+164
-0
lines changed

util.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,42 @@ func NetFirstAndLastIP(networkCIDR string) (first, last net.IP, err error) {
8787

8888
return first, last, nil
8989
}
90+
91+
// nextIp returns the next IP address after the given one.
92+
// If the next address overflows, the sentinel values 0.0.0.0 (IPv4)
93+
// or :: (IPv6) are returned.
94+
func nextIP(ip net.IP) net.IP {
95+
if ip == nil {
96+
return nil
97+
}
98+
99+
next := make(net.IP, len(ip))
100+
copy(next, ip)
101+
102+
for i := len(next) - 1; i >= 0; i-- {
103+
next[i]++
104+
if next[i] != 0 {
105+
return next
106+
}
107+
}
108+
109+
// All bytes overflowed to 0
110+
return next
111+
}
112+
113+
// NetInterval returns the half-open ([start, end)) interval of a CIDR string.
114+
// This is the range that nftables uses for interval matching with set elements.
115+
// Unlike NetFirstAndLastIP, the end value is one past the last IP in the
116+
// network. If the last IP is overflowed, the end value will be a zero IP.
117+
//
118+
// For example, for the CIDR "10.0.0.0/24", NetInterval returns
119+
// first=10.0.0.0 and last=10.0.1.0. Note that last is one more than the
120+
// broadcast address of the CIDR.
121+
func NetInterval(cidr string) (net.IP, net.IP, error) {
122+
first, last, err := NetFirstAndLastIP(cidr)
123+
if err != nil {
124+
return first, last, err
125+
}
126+
127+
return first, nextIP(last), nil
128+
}

util_test.go

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,128 @@ func TestNetFirstAndLastIP(t *testing.T) {
7676
})
7777
}
7878
}
79+
80+
func TestNetInterval(t *testing.T) {
81+
tests := []struct {
82+
name string
83+
cidr string
84+
wantFirstIP net.IP
85+
wantLastIP net.IP
86+
wantErr bool
87+
}{
88+
{
89+
name: "Test Invalid",
90+
cidr: "invalid-cidr",
91+
wantFirstIP: nil,
92+
wantLastIP: nil,
93+
wantErr: true,
94+
},
95+
{
96+
name: "Test IPV4 /0",
97+
cidr: "0.0.0.0/0",
98+
wantFirstIP: net.IP{0, 0, 0, 0},
99+
wantLastIP: net.IP{0, 0, 0, 0},
100+
wantErr: false,
101+
},
102+
{
103+
name: "Test IPV4 /8",
104+
cidr: "10.0.0.0/8",
105+
wantFirstIP: net.IP{10, 0, 0, 0},
106+
wantLastIP: net.IP{11, 0, 0, 0},
107+
wantErr: false,
108+
},
109+
{
110+
name: "Test IPV4 /16",
111+
cidr: "10.0.0.0/16",
112+
wantFirstIP: net.IP{10, 0, 0, 0},
113+
wantLastIP: net.IP{10, 1, 0, 0},
114+
wantErr: false,
115+
},
116+
{
117+
name: "Test IPV4 /24",
118+
cidr: "10.0.0.0/24",
119+
wantFirstIP: net.IP{10, 0, 0, 0},
120+
wantLastIP: net.IP{10, 0, 1, 0},
121+
wantErr: false,
122+
},
123+
{
124+
name: "Test IPV4 /31 near max",
125+
cidr: "255.255.255.255/31",
126+
wantFirstIP: net.IP{255, 255, 255, 254},
127+
wantLastIP: net.IP{0, 0, 0, 0},
128+
wantErr: false,
129+
},
130+
{
131+
name: "Test IPV4 /32",
132+
cidr: "10.0.0.1/32",
133+
wantFirstIP: net.IP{10, 0, 0, 1},
134+
wantLastIP: net.IP{10, 0, 0, 2},
135+
wantErr: false,
136+
},
137+
{
138+
name: "Test IPv4 /0 with max",
139+
cidr: "255.255.255.255/0",
140+
wantFirstIP: net.IP{0, 0, 0, 0},
141+
wantLastIP: net.IP{0, 0, 0, 0},
142+
wantErr: false,
143+
},
144+
{
145+
name: "Test IPv6 /0",
146+
cidr: "::/0",
147+
wantFirstIP: net.ParseIP("::"),
148+
wantLastIP: net.ParseIP("::"),
149+
wantErr: false,
150+
},
151+
{
152+
name: "Test IPv6 /48",
153+
cidr: "2001:db8::/48",
154+
wantFirstIP: net.ParseIP("2001:db8::"),
155+
wantLastIP: net.ParseIP("2001:db8:1::"),
156+
wantErr: false,
157+
},
158+
{
159+
name: "Test IPv6 /64",
160+
cidr: "2001:db8::/64",
161+
wantFirstIP: net.ParseIP("2001:db8::"),
162+
wantLastIP: net.ParseIP("2001:db8::1:0:0:0:0"),
163+
wantErr: false,
164+
},
165+
{
166+
name: "Test IPv6 /120 near max",
167+
cidr: "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ff00/120",
168+
wantFirstIP: net.ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ff00"),
169+
wantLastIP: net.ParseIP("::"),
170+
wantErr: false,
171+
},
172+
{
173+
name: "Test IPv6 /128",
174+
cidr: "2001:db8::1/128",
175+
wantFirstIP: net.ParseIP("2001:db8::1"),
176+
wantLastIP: net.ParseIP("2001:db8::2"),
177+
wantErr: false,
178+
},
179+
{
180+
name: "Test IPv6 /0 with max",
181+
cidr: "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/0",
182+
wantFirstIP: net.ParseIP("::"),
183+
wantLastIP: net.ParseIP("::"),
184+
wantErr: false,
185+
},
186+
}
187+
188+
for _, tt := range tests {
189+
t.Run(tt.name, func(t *testing.T) {
190+
gotFirstIP, gotLastIP, err := NetInterval(tt.cidr)
191+
if (err != nil) != tt.wantErr {
192+
t.Errorf("NetInterval() error = %v, wantErr %v", err, tt.wantErr)
193+
return
194+
}
195+
if !reflect.DeepEqual(gotFirstIP, tt.wantFirstIP) {
196+
t.Errorf("NetInterval() gotFirstIP = %v, want %v", gotFirstIP, tt.wantFirstIP)
197+
}
198+
if !reflect.DeepEqual(gotLastIP, tt.wantLastIP) {
199+
t.Errorf("NetInterval() gotLastIP = %v, want %v", gotLastIP, tt.wantLastIP)
200+
}
201+
})
202+
}
203+
}

0 commit comments

Comments
 (0)