Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions controllers/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -693,10 +693,9 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
logic.AddNetworkToAllocatedIpMap(network.Name)
logic.CreateFallbackNameserver(network.Name)
if featureFlags.EnableOverlappingEgressRanges {
// assign virtual NAT pool fields
logic.AssignVirtualNATDefaults(&network, network.AddressRange)
// Update network with virtual NAT settings
if err := logic.UpsertNetwork(&network); err != nil {
if err := logic.AllocateUniqueVNATPool(&network); err != nil {
logger.Log(0, r.Header.Get("user"), "failed to allocate unique virtual NAT pool:", err.Error())
} else if err := logic.UpsertNetwork(&network); err != nil {
logger.Log(0, r.Header.Get("user"), "failed to update network with virtual NAT settings:", err.Error())
}
}
Expand Down Expand Up @@ -808,14 +807,14 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
return
}

netOld := &schema.Network{Name: payload.Name}
err = netOld.Get(r.Context())
currNet := &schema.Network{Name: payload.Name}
err = currNet.Get(r.Context())
if err != nil {
slog.Info("error fetching network", "user", r.Header.Get("user"), "err", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
err = logic.UpdateNetwork(netOld, &payload)
err = logic.UpdateNetwork(currNet, &payload)
if err != nil {
slog.Info("failed to update network", "user", r.Header.Get("user"), "err", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
Expand All @@ -824,5 +823,5 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
go mq.PublishPeerUpdate(false)
slog.Info("updated network", "network", payload.Name, "user", r.Header.Get("user"))
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(payload)
json.NewEncoder(w).Encode(currNet)
}
157 changes: 153 additions & 4 deletions logic/networks.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package logic

import (
"context"
"crypto/sha1"
"encoding/binary"
"errors"
"fmt"
"math/big"
"net"
"sort"
"strings"
Expand Down Expand Up @@ -242,6 +245,155 @@ func cidrOverlaps(a, b *net.IPNet) bool {
return a.Contains(b.IP) || b.Contains(a.IP)
}

const (
FallbackVNATPool = "198.18.0.0/15"
VNATPoolPrefixLen = 22
DefaultSitePrefixV4 = 24
CgnatCIDR = "100.64.0.0/10"
)

// AllocateUniqueVNATPool allocates a unique Virtual NAT pool for a network,
// ensuring it doesn't conflict with pools already assigned to other networks.
func AllocateUniqueVNATPool(network *schema.Network) error {
networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
return fmt.Errorf("failed to list networks: %w", err)
}

allocatedPools := make(map[string]struct{})
for _, n := range networks {
if n.VirtualNATSitePrefixLenIPv4 > 0 {
if _, _, err := net.ParseCIDR(n.VirtualNATPoolIPv4); err == nil {
allocatedPools[n.VirtualNATPoolIPv4] = struct{}{}
}
}
}

_, cgnatNet, err := net.ParseCIDR(CgnatCIDR)
if err != nil {
return fmt.Errorf("failed to parse CGNAT CIDR: %w", err)
}

_, fallbackNet, err := net.ParseCIDR(FallbackVNATPool)
if err != nil {
return fmt.Errorf("failed to parse fallback pool: %w", err)
}

vpnCIDR := network.AddressRange
needsUniquePool := false

if vpnCIDR == "" {
needsUniquePool = true
} else {
_, vpnNet, err := net.ParseCIDR(vpnCIDR)
if err != nil || vpnNet == nil {
needsUniquePool = true
} else if cidrOverlaps(vpnNet, cgnatNet) {
needsUniquePool = true
}
}

if needsUniquePool {
uniquePool := AllocateUniquePoolFromFallback(fallbackNet, VNATPoolPrefixLen, allocatedPools, network.Name)
if uniquePool == "" {
return fmt.Errorf("failed to allocate unique Virtual NAT pool for network %s: pool exhausted", network.Name)
}
network.VirtualNATPoolIPv4 = uniquePool
network.VirtualNATSitePrefixLenIPv4 = DefaultSitePrefixV4
} else {
AssignVirtualNATDefaults(network, vpnCIDR)
}

return nil
}

// AllocateUniquePoolFromFallback allocates a unique subnet of the given prefix length
// from the fallback pool, skipping any subnets already present in the allocated map.
func AllocateUniquePoolFromFallback(pool *net.IPNet, newPrefixLen int, allocated map[string]struct{}, seed string) string {
if pool == nil {
return ""
}

poolPrefixLen, bits := pool.Mask.Size()
if newPrefixLen < poolPrefixLen || newPrefixLen > bits {
return ""
}

total := 1 << uint(newPrefixLen-poolPrefixLen)
start := vnatHashIndex(seed, total)

for i := 0; i < total; i++ {
idx := (start + i) % total
cand := NthSubnet(pool, newPrefixLen, idx)
if cand == nil || cand.IP == nil {
continue
}
cs := cand.String()
if _, _, err := net.ParseCIDR(cs); err != nil {
continue
}
if _, used := allocated[cs]; !used {
return cs
}
}

return ""
}

// NthSubnet calculates the nth subnet of a given prefix length within a pool.
func NthSubnet(pool *net.IPNet, newPrefixLen int, n int) *net.IPNet {
if pool == nil {
return nil
}

poolPrefixLen, bits := pool.Mask.Size()
if newPrefixLen < poolPrefixLen || newPrefixLen > bits || n < 0 {
return nil
}

base := ipToBigInt(pool.IP)
size := new(big.Int).Lsh(big.NewInt(1), uint(bits-newPrefixLen))
offset := new(big.Int).Mul(big.NewInt(int64(n)), size)
ipInt := new(big.Int).Add(base, offset)
ip := bigIntToIP(ipInt, bits)

mask := net.CIDRMask(newPrefixLen, bits)
return &net.IPNet{IP: ip.Mask(mask), Mask: mask}
}

func ipToBigInt(ip net.IP) *big.Int {
if v4 := ip.To4(); v4 != nil {
return new(big.Int).SetBytes(v4)
}
if v6 := ip.To16(); v6 != nil {
return new(big.Int).SetBytes(v6)
}
return big.NewInt(0)
}

func bigIntToIP(i *big.Int, bits int) net.IP {
b := i.Bytes()
byteLen := bits / 8
if len(b) < byteLen {
pad := make([]byte, byteLen-len(b))
b = append(pad, b...)
}
ip := net.IP(b)
if bits == 32 {
return ip.To4()
}
return ip
}

func vnatHashIndex(seed string, mod int) int {
if mod <= 1 {
return 0
}
sum := sha1.Sum([]byte(seed))
v := binary.BigEndian.Uint32(sum[:4])
return int(v % uint32(mod))
}

// CreateNetwork - creates a network in database
func CreateNetwork(_network *schema.Network) error {
if _network.AddressRange != "" {
Expand Down Expand Up @@ -645,11 +797,8 @@ func UpdateNetwork(currentNetwork, newNetwork *schema.Network) error {
}
}
currentNetwork.VirtualNATSitePrefixLenIPv4 = newNetwork.VirtualNATSitePrefixLenIPv4
} else {
// If both are empty, clear the settings
currentNetwork.VirtualNATPoolIPv4 = newNetwork.VirtualNATPoolIPv4
currentNetwork.VirtualNATSitePrefixLenIPv4 = newNetwork.VirtualNATSitePrefixLenIPv4
}
// When both VNAT fields are omitted from the update, preserve existing settings
return currentNetwork.Update(db.WithContext(context.TODO()))
}

Expand Down
Loading