Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions util.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,19 @@ package nftables

import (
"encoding/binary"
"errors"
"net"
"net/netip"

"github.com/google/nftables/binaryutil"
"golang.org/x/sys/unix"
)

var (
MaxIPv4 = net.IP{255, 255, 255, 255}
MaxIPv6 = net.IP{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
)

func extraHeader(family uint8, resID uint16) []byte {
return append([]byte{
family,
Expand Down Expand Up @@ -126,3 +133,89 @@ func NetInterval(cidr string) (net.IP, net.IP, error) {

return first, nextIP(last), nil
}

// endIp returns the last address in a given network.
func endIp(netIp net.IP, mask net.IPMask) net.IP {
ip := make(net.IP, len(netIp))
copy(ip, netIp)

for i := 0; i < len(mask); i++ {
ipIdx := len(ip) - i - 1
ip[ipIdx] = netIp[ipIdx] | ^mask[len(mask)-i-1]
}

return ip
}

// NetFromRange returns a CIDR IP network given a start and end address.
// If an exact match is found, ok will be true. If not, no IPNet will be returned, and ok will be false.
func NetFromRange(first net.IP, last net.IP) (*net.IPNet, bool, error) {
ip1 := net.IP(first)
ip2 := net.IP(last)

maxLen := 32
isIpv6 := ip1.To4() == nil

if isIpv6 && ip2.To4() != nil || !isIpv6 && ip2.To4() == nil {
return nil, false, errors.New("Cannot mix IPv4 and IPv6 or process empty IP.")
}

if isIpv6 {
maxLen = 128
}

var match *net.IPNet
for l := maxLen; l >= -1; l-- {
cidrmask := net.CIDRMask(l, maxLen)
ipmask := ip2.Mask(cidrmask)
ipnet := net.IPNet{
IP: ipmask,
Mask: cidrmask,
}

if ipnet.Contains(ip1) {
match = &ipnet
break
}

}

matchFirst := match.IP.Equal(ip1)

// short-circuit if first address is not start of the network
if !matchFirst {
return nil, matchFirst, nil
}

matchSecond := endIp(match.IP, match.Mask).Equal(ip2)

if !matchSecond {
return nil, matchSecond, nil
}

return match, true, nil
}

// NetFromInterval returns a CIDR IP network given a start and end address as found in intervals.
// This is similar to NetFromRange, but subtracts one address from the end of the range.
// If the resulting network is an exact match, ok will be true.
func NetFromInterval(first net.IP, last net.IP) (out *net.IPNet, ok bool, err error) {
var previous net.IP

if len(last) == 0 {
if first.To4() == nil {
previous = MaxIPv6
} else {
previous = MaxIPv4
}
} else {
ip2, ok := netip.AddrFromSlice(last)
if !ok {
return nil, false, errors.New("Failed to construct slice from network.")
}

previous = ip2.Prev().AsSlice()
}

return NetFromRange(first, previous)
}
226 changes: 226 additions & 0 deletions util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,229 @@ func TestNetInterval(t *testing.T) {
})
}
}

func TestEndIp(t *testing.T) {
tests := []struct {
network string
wantEndIp string
}{
{
network: "10.0.0.0/24",
wantEndIp: "10.0.0.255",
},
{
network: "192.168.4.32/27",
wantEndIp: "192.168.4.63",
},
{
network: "2001:db8:100::/64",
wantEndIp: "2001:db8:100:0:ffff:ffff:ffff:ffff",
},
{
network: "2001:db8:100:a:b::50/64",
wantEndIp: "2001:db8:100:a:ffff:ffff:ffff:ffff",
},
}
for _, tt := range tests {
taddr, tnet, err := net.ParseCIDR(tt.network)
if err != nil {
t.Fatalf("endIp() error parsing test CIDR = %v", err)
}

t.Run(tnet.String(), func(t *testing.T) {
gotEndIp := endIp(taddr, tnet.Mask)
if !gotEndIp.Equal(net.ParseIP(tt.wantEndIp)) {
t.Errorf("endIp() gotEndIp = %s, wantEndIp = %s", gotEndIp, tt.wantEndIp)
}
})
}
}

func TestNetFromRange(t *testing.T) {
tests := []struct {
name string
first string
last string
wantNet string
wantOk bool
wantErr bool
}{
{
first: "0.0.0.0",
last: "255.255.255.255",
wantNet: "0.0.0.0/0",
wantOk: true,
wantErr: false,
},
{
first: "0.0.0.1",
last: "255.255.255.254",
wantNet: "", // not exactly 0.0.0.0/0
wantOk: false,
wantErr: false,
},
{
first: "192.168.4.0",
last: "192.168.4.255",
wantNet: "192.168.4.0/24",
wantOk: true,
wantErr: false,
},
{
first: "192.0.2.16",
last: "192.0.2.30",
wantNet: "", // not exactly 192.0.2.16/28
wantOk: false,
wantErr: false,
},
{
first: "2001:db8:100::",
last: "2001:db8:100:ffff:ffff:ffff:ffff:ffff",
wantNet: "2001:db8:100::/48",
wantOk: true,
wantErr: false,
},
{
first: "2001:db8:100::100",
last: "2001:db8:100:0:ffff:ffff:ffff:ffff",
wantNet: "", // not exactly 2001:db8:100::/64
wantOk: false,
wantErr: false,
},
{
first: "2001:db8:100::",
last: "192.0.2.30",
wantNet: "",
wantOk: true,
wantErr: true,
},
{
first: "192.0.2.30",
last: "2001:db8:100::",
wantNet: "",
wantOk: true,
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.first+"-"+tt.last, func(t *testing.T) {
gotNet, gotOk, err := NetFromRange(net.ParseIP(tt.first), net.ParseIP(tt.last))
if (err != nil) != tt.wantErr {
t.Errorf("NetFromRange() error = %v, wantErr = %v", err, tt.wantErr)
}

if tt.wantNet == "" {
if gotNet != nil {
t.Errorf("NetFromInterval() gotNet = %v, wantNet = nil", gotNet)
}

return
}

_, wantNetParsed, err := net.ParseCIDR(tt.wantNet)
if err != nil {
t.Fatalf("NetFromRange() error parsing test network = %v", err)
}

if tt.wantOk != gotOk {
t.Errorf("NetFromRange() gotOk = %t, wantOk = %t", gotOk, tt.wantOk)
}

if !reflect.DeepEqual(gotNet, wantNetParsed) {
t.Errorf("NetFromRange() gotNet = %+v, wantNet = %+v", gotNet, wantNetParsed)
}
})
}
}

func TestNetFromInterval(t *testing.T) {
tests := []struct {
name string
first string
last string
wantNet string
wantOk bool
wantErr bool
}{
{
first: "192.0.2.16",
last: "192.0.2.32",
wantNet: "192.0.2.16/28",
wantOk: true,
wantErr: false,
},
{
first: "128.0.0.0",
last: "",
wantNet: "128.0.0.0/1",
wantOk: true,
wantErr: false,
},
{
first: "2001:db8:100::",
last: "2001:db8:101::",
wantNet: "2001:db8:100::/48",
wantOk: true,
wantErr: false,
},
{
first: "2001:db8:a1:11::",
last: "2001:db8:a1:12::",
wantNet: "2001:db8:a1:11::/64",
wantOk: true,
wantErr: false,
},
{
first: "2001:db8:100::100",
last: "2001:db8:100:0:ffff:ffff:ffff:ffff",
wantNet: "", // not exactly 2001:db8:100::/64
wantOk: false,
wantErr: false,
},
{
first: "2001:db8:100::",
last: "192.0.2.30",
wantNet: "",
wantOk: true,
wantErr: true,
},
{
first: "192.0.2.30",
last: "2001:db8:100::",
wantNet: "",
wantOk: true,
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.first+"-"+tt.last, func(t *testing.T) {
gotNet, gotOk, err := NetFromInterval(net.ParseIP(tt.first), net.ParseIP(tt.last))
if (err != nil) != tt.wantErr {
t.Errorf("NetFromInterval() error = %v, wantErr = %v", err, tt.wantErr)
}

if tt.wantNet == "" {
if gotNet != nil {
t.Errorf("NetFromInterval() gotNet = %v, wantNet = nil", gotNet)
}

return
}

_, wantNetParsed, err := net.ParseCIDR(tt.wantNet)
if err != nil {
t.Fatalf("NetFromInterval() error parsing test network = %v", err)
}

if tt.wantOk != gotOk {
t.Errorf("NetFromInterval() gotOk = %t, wantOk = %t", gotOk, tt.wantOk)
}

if !reflect.DeepEqual(gotNet, wantNetParsed) {
t.Errorf("NetFromInterval() gotNet = %+v, wantNet = %+v", gotNet, wantNetParsed)
}
})
}
}