Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions util.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,42 @@ func NetFirstAndLastIP(networkCIDR string) (first, last net.IP, err error) {

return first, last, nil
}

// nextIp returns the next IP address after the given one.
// If the next address overflows, the sentinel values 0.0.0.0 (IPv4)
// or :: (IPv6) are returned.
func nextIP(ip net.IP) net.IP {
if ip == nil {
return nil
}

next := make(net.IP, len(ip))
copy(next, ip)

for i := len(next) - 1; i >= 0; i-- {
next[i]++
if next[i] != 0 {
return next
}
}

// All bytes overflowed to 0
return next
}

// NetInterval returns the half-open ([start, end)) interval of a CIDR string.
// This is the range that nftables uses for interval matching with set elements.
// Unlike NetFirstAndLastIP, the end value is one past the last IP in the
// network. If the last IP is overflowed, the end value will be a zero IP.
//
// For example, for the CIDR "10.0.0.0/24", NetInterval returns
// first=10.0.0.0 and last=10.0.1.0. Note that last is one more than the
// broadcast address of the CIDR.
func NetInterval(cidr string) (net.IP, net.IP, error) {
first, last, err := NetFirstAndLastIP(cidr)
if err != nil {
return first, last, err
}

return first, nextIP(last), nil
}
125 changes: 125 additions & 0 deletions util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,128 @@ func TestNetFirstAndLastIP(t *testing.T) {
})
}
}

func TestNetInterval(t *testing.T) {
tests := []struct {
name string
cidr string
wantFirstIP net.IP
wantLastIP net.IP
wantErr bool
}{
{
name: "Test Invalid",
cidr: "invalid-cidr",
wantFirstIP: nil,
wantLastIP: nil,
wantErr: true,
},
{
name: "Test IPV4 /0",
cidr: "0.0.0.0/0",
wantFirstIP: net.IP{0, 0, 0, 0},
wantLastIP: net.IP{0, 0, 0, 0},
wantErr: false,
},
{
name: "Test IPV4 /8",
cidr: "10.0.0.0/8",
wantFirstIP: net.IP{10, 0, 0, 0},
wantLastIP: net.IP{11, 0, 0, 0},
wantErr: false,
},
{
name: "Test IPV4 /16",
cidr: "10.0.0.0/16",
wantFirstIP: net.IP{10, 0, 0, 0},
wantLastIP: net.IP{10, 1, 0, 0},
wantErr: false,
},
{
name: "Test IPV4 /24",
cidr: "10.0.0.0/24",
wantFirstIP: net.IP{10, 0, 0, 0},
wantLastIP: net.IP{10, 0, 1, 0},
wantErr: false,
},
{
name: "Test IPV4 /31 near max",
cidr: "255.255.255.255/31",
wantFirstIP: net.IP{255, 255, 255, 254},
wantLastIP: net.IP{0, 0, 0, 0},
wantErr: false,
},
{
name: "Test IPV4 /32",
cidr: "10.0.0.1/32",
wantFirstIP: net.IP{10, 0, 0, 1},
wantLastIP: net.IP{10, 0, 0, 2},
wantErr: false,
},
{
name: "Test IPv4 /0 with max",
cidr: "255.255.255.255/0",
wantFirstIP: net.IP{0, 0, 0, 0},
wantLastIP: net.IP{0, 0, 0, 0},
wantErr: false,
},
{
name: "Test IPv6 /0",
cidr: "::/0",
wantFirstIP: net.ParseIP("::"),
wantLastIP: net.ParseIP("::"),
wantErr: false,
},
{
name: "Test IPv6 /48",
cidr: "2001:db8::/48",
wantFirstIP: net.ParseIP("2001:db8::"),
wantLastIP: net.ParseIP("2001:db8:1::"),
wantErr: false,
},
{
name: "Test IPv6 /64",
cidr: "2001:db8::/64",
wantFirstIP: net.ParseIP("2001:db8::"),
wantLastIP: net.ParseIP("2001:db8::1:0:0:0:0"),
wantErr: false,
},
{
name: "Test IPv6 /120 near max",
cidr: "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ff00/120",
wantFirstIP: net.ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ff00"),
wantLastIP: net.ParseIP("::"),
wantErr: false,
},
{
name: "Test IPv6 /128",
cidr: "2001:db8::1/128",
wantFirstIP: net.ParseIP("2001:db8::1"),
wantLastIP: net.ParseIP("2001:db8::2"),
wantErr: false,
},
{
name: "Test IPv6 /0 with max",
cidr: "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/0",
wantFirstIP: net.ParseIP("::"),
wantLastIP: net.ParseIP("::"),
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotFirstIP, gotLastIP, err := NetInterval(tt.cidr)
if (err != nil) != tt.wantErr {
t.Errorf("NetInterval() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(gotFirstIP, tt.wantFirstIP) {
t.Errorf("NetInterval() gotFirstIP = %v, want %v", gotFirstIP, tt.wantFirstIP)
}
if !reflect.DeepEqual(gotLastIP, tt.wantLastIP) {
t.Errorf("NetInterval() gotLastIP = %v, want %v", gotLastIP, tt.wantLastIP)
}
})
}
}