|
| 1 | +package remotedns |
| 2 | + |
| 3 | +import ( |
| 4 | + "net" |
| 5 | + "sync" |
| 6 | + "time" |
| 7 | + |
| 8 | + "github.com/jellydator/ttlcache/v3" |
| 9 | +) |
| 10 | + |
| 11 | +var ( |
| 12 | + ipToName = ttlcache.New[string, string]() |
| 13 | + nameToIP = ttlcache.New[string, net.IP]() |
| 14 | + mutex = sync.Mutex{} |
| 15 | + |
| 16 | + ip4NextAddress net.IP |
| 17 | + ip4BroadcastAddress net.IP |
| 18 | +) |
| 19 | + |
| 20 | +func findOrInsertNameAndReturnIP(ipVersion int, name string) net.IP { |
| 21 | + if ipVersion != 4 { |
| 22 | + panic("Method not implemented for IPv6") |
| 23 | + } |
| 24 | + mutex.Lock() |
| 25 | + defer mutex.Unlock() |
| 26 | + var result net.IP = nil |
| 27 | + var ipnet *net.IPNet |
| 28 | + var nextAddress *net.IP |
| 29 | + var broadcastAddress net.IP |
| 30 | + if ipVersion == 4 { |
| 31 | + ipnet = ip4net |
| 32 | + nextAddress = &ip4NextAddress |
| 33 | + broadcastAddress = ip4BroadcastAddress |
| 34 | + } |
| 35 | + |
| 36 | + nameToIP.DeleteExpired() |
| 37 | + ipToName.DeleteExpired() |
| 38 | + |
| 39 | + entry := nameToIP.Get(name) |
| 40 | + if entry != nil { |
| 41 | + ip := entry.Value() |
| 42 | + ipToName.Touch(ip.String()) |
| 43 | + return ip |
| 44 | + } |
| 45 | + |
| 46 | + // Beginning from the pointer to the next most likely free IP, loop through the IP address space |
| 47 | + // until either a free IP is found or the space is exhausted |
| 48 | + passedBroadcastAddress := false |
| 49 | + for result == nil { |
| 50 | + if nextAddress.Equal(broadcastAddress) { |
| 51 | + *nextAddress = getNetworkAddress(ipnet) |
| 52 | + *nextAddress = incrementIP(ipnet.IP) |
| 53 | + |
| 54 | + // We have seen the broadcast address twice during looping |
| 55 | + // This means that our IP address space is exhausted |
| 56 | + if passedBroadcastAddress { |
| 57 | + return nil |
| 58 | + } |
| 59 | + passedBroadcastAddress = true |
| 60 | + } |
| 61 | + |
| 62 | + // Skip the listen address if that is inside our pool range |
| 63 | + if nextAddress.Equal(listenAddress) { |
| 64 | + *nextAddress = incrementIP(*nextAddress) |
| 65 | + continue |
| 66 | + } |
| 67 | + |
| 68 | + // Do not touch entries that exist in the cache already. |
| 69 | + hasKey := ipToName.Has((*nextAddress).String()) |
| 70 | + if !hasKey { |
| 71 | + _ = ipToName.Set((*nextAddress).String(), name, time.Duration(dnsTTL)*time.Second+cacheGraceTime) |
| 72 | + _ = nameToIP.Set(name, *nextAddress, time.Duration(dnsTTL)*time.Second+cacheGraceTime) |
| 73 | + result = *nextAddress |
| 74 | + } |
| 75 | + |
| 76 | + *nextAddress = incrementIP(*nextAddress) |
| 77 | + } |
| 78 | + |
| 79 | + return result |
| 80 | +} |
| 81 | + |
| 82 | +func getCachedName(address net.IP) (string, bool) { |
| 83 | + mutex.Lock() |
| 84 | + defer mutex.Unlock() |
| 85 | + entry := ipToName.Get(address.String()) |
| 86 | + if entry == nil { |
| 87 | + return "", false |
| 88 | + } |
| 89 | + nameToIP.Touch(entry.Value()) |
| 90 | + return entry.Value(), true |
| 91 | +} |
0 commit comments