Skip to content

Commit 9d983f1

Browse files
authored
Merge pull request #216 from SenseUnit/bump_def_dns_cache_ttl
Rework DNS resolving
2 parents 566a9a5 + 5954fef commit 9d983f1

2 files changed

Lines changed: 48 additions & 102 deletions

File tree

dialer/rescache.go

Lines changed: 22 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,12 @@ package dialer
22

33
import (
44
"context"
5-
"errors"
6-
"fmt"
7-
"net"
85
"net/netip"
96
"strings"
107
"time"
118

129
"codeberg.org/yarmak/secache"
13-
"github.com/hashicorp/go-multierror"
1410
"golang.org/x/sync/singleflight"
15-
16-
"github.com/SenseUnit/dumbproxy/dialer/dto"
1711
)
1812

1913
type resolverCacheKey struct {
@@ -27,21 +21,18 @@ type resolverCacheValue struct {
2721
err error
2822
}
2923

30-
type NameResolveCachingDialer struct {
31-
resolver Resolver
32-
cache secache.Cache[resolverCacheKey, *resolverCacheValue]
33-
sf singleflight.Group
34-
posTTL time.Duration
35-
negTTL time.Duration
36-
timeout time.Duration
37-
next Dialer
24+
type CachingResolver struct {
25+
next Resolver
26+
cache secache.Cache[resolverCacheKey, *resolverCacheValue]
27+
sf singleflight.Group
28+
posTTL time.Duration
29+
negTTL time.Duration
30+
timeout time.Duration
3831
}
3932

40-
func NewNameResolveCachingDialer(next Dialer, resolver Resolver, posTTL, negTTL, timeout time.Duration) *NameResolveCachingDialer {
41-
// func(c *ttlcache.Cache[resolverCacheKey, resolverCacheValue], key resolverCacheKey) *ttlcache.Item[resolverCacheKey, resolverCacheValue] {
42-
// },
43-
return &NameResolveCachingDialer{
44-
resolver: resolver,
33+
func NewCachingResolver(next Resolver, posTTL, negTTL, timeout time.Duration) *CachingResolver {
34+
return &CachingResolver{
35+
next: next,
4536
cache: *(secache.New[resolverCacheKey, *resolverCacheValue](
4637
3,
4738
func(key resolverCacheKey, item *resolverCacheValue) bool {
@@ -51,62 +42,40 @@ func NewNameResolveCachingDialer(next Dialer, resolver Resolver, posTTL, negTTL,
5142
posTTL: posTTL,
5243
negTTL: negTTL,
5344
timeout: timeout,
54-
next: next,
5545
}
5646
}
5747

58-
func (nrcd *NameResolveCachingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
59-
if WantsHostname(ctx, network, address, nrcd.next) {
60-
return nrcd.next.DialContext(ctx, network, address)
61-
}
62-
63-
host, port, err := net.SplitHostPort(address)
64-
if err != nil {
65-
return nil, fmt.Errorf("failed to extract host and port from %s: %w", address, err)
66-
}
67-
48+
func (r *CachingResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
6849
if addr, err := netip.ParseAddr(host); err == nil {
6950
// literal IP address, just do unmapping
70-
return nrcd.next.DialContext(ctx, network, net.JoinHostPort(addr.Unmap().String(), port))
71-
}
72-
73-
var resolveNetwork string
74-
switch network {
75-
case "udp4", "tcp4", "ip4":
76-
resolveNetwork = "ip4"
77-
case "udp6", "tcp6", "ip6":
78-
resolveNetwork = "ip6"
79-
case "udp", "tcp", "ip":
80-
resolveNetwork = "ip"
81-
default:
82-
return nil, fmt.Errorf("resolving dial %q: unsupported network %q", address, network)
51+
return r.next.LookupNetIP(ctx, network, addr.Unmap().String())
8352
}
8453

8554
host = strings.ToLower(host)
8655
key := resolverCacheKey{
87-
network: resolveNetwork,
56+
network: network,
8857
host: host,
8958
}
9059

91-
res, ok := nrcd.cache.GetValidOrDelete(key)
60+
res, ok := r.cache.GetValidOrDelete(key)
9261
if !ok {
93-
v, _, _ := nrcd.sf.Do(key.network+":"+key.host, func() (any, error) {
94-
ctx, cl := context.WithTimeout(context.Background(), nrcd.timeout)
62+
v, _, _ := r.sf.Do(key.network+":"+key.host, func() (any, error) {
63+
ctx, cl := context.WithTimeout(context.Background(), r.timeout)
9564
defer cl()
96-
res, err := nrcd.resolver.LookupNetIP(ctx, key.network, key.host)
65+
res, err := r.next.LookupNetIP(ctx, key.network, key.host)
9766
for i := range res {
9867
res[i] = res[i].Unmap()
9968
}
100-
setTTL := nrcd.negTTL
69+
setTTL := r.negTTL
10170
if err == nil {
102-
setTTL = nrcd.posTTL
71+
setTTL = r.posTTL
10372
}
10473
item := &resolverCacheValue{
10574
expires: time.Now().Add(setTTL),
10675
addrs: res,
10776
err: err,
10877
}
109-
nrcd.cache.Set(key, item)
78+
r.cache.Set(key, item)
11079
return item, nil
11180
})
11281
res = v.(*resolverCacheValue)
@@ -116,33 +85,7 @@ func (nrcd *NameResolveCachingDialer) DialContext(ctx context.Context, network,
11685
return nil, res.err
11786
}
11887

119-
ctx = dto.OrigDstToContext(ctx, address)
120-
121-
var dialErr error
122-
var conn net.Conn
123-
124-
for _, ip := range res.addrs {
125-
conn, err = nrcd.next.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
126-
if err == nil {
127-
return conn, nil
128-
}
129-
dialErr = multierror.Append(dialErr, err)
130-
var sae dto.StopAddressIteration
131-
if errors.As(err, &sae) {
132-
break
133-
}
134-
}
135-
136-
return nil, fmt.Errorf("failed to dial %s: %w", address, dialErr)
137-
}
138-
139-
func (nrcd *NameResolveCachingDialer) Dial(network, address string) (net.Conn, error) {
140-
return nrcd.DialContext(context.Background(), network, address)
141-
}
142-
143-
func (nrcd *NameResolveCachingDialer) WantsHostname(ctx context.Context, net, address string) bool {
144-
return WantsHostname(ctx, net, address, nrcd.next)
88+
return res.addrs, nil
14589
}
14690

147-
var _ Dialer = new(NameResolveCachingDialer)
148-
var _ HostnameWanter = new(NameResolveCachingDialer)
91+
var _ Resolver = new(CachingResolver)

main.go

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,7 @@ func parse_args() *CLIArgs {
475475
return nil
476476
})
477477
flag.Var(&args.dnsPreferAddress, "dns-prefer-address", "address resolution preference (none/ipv4/ipv6)")
478-
flag.DurationVar(&args.dnsCacheTTL, "dns-cache-ttl", 0, "enable DNS cache with specified fixed TTL")
478+
flag.DurationVar(&args.dnsCacheTTL, "dns-cache-ttl", 10, "enable DNS cache with specified fixed TTL")
479479
flag.DurationVar(&args.dnsCacheNegTTL, "dns-cache-neg-ttl", time.Second, "TTL for negative responses of DNS cache")
480480
flag.DurationVar(&args.dnsCacheTimeout, "dns-cache-timeout", 5*time.Second, "timeout for shared resolves of DNS cache")
481481
flag.DurationVar(&args.reqHeaderTimeout, "req-header-timeout", 30*time.Second, "amount of time allowed to read request headers")
@@ -660,8 +660,29 @@ func run() int {
660660
filterRoot = access.NewDstAddrFilter(args.denyDstAddr.Value(), filterRoot)
661661
}
662662

663+
// setup name resolution
664+
var nameResolver dialer.Resolver = net.DefaultResolver
665+
if len(args.dnsServers) > 0 {
666+
nameResolver, err = resolver.FastFromURLs(args.dnsServers...)
667+
if err != nil {
668+
mainLogger.Critical("Failed to create name resolver: %v", err)
669+
return 3
670+
}
671+
}
672+
if args.dnsCacheTTL > 0 {
673+
nameResolver = dialer.NewCachingResolver(
674+
nameResolver,
675+
args.dnsCacheTTL,
676+
args.dnsCacheNegTTL,
677+
args.dnsCacheTimeout,
678+
)
679+
}
680+
nameResolver = resolver.Prefer(nameResolver, args.dnsPreferAddress.Value())
681+
663682
// construct dialers
664683
var dialerRoot dialer.Dialer = dialer.NewBoundDialer(new(net.Dialer), args.sourceIPHints)
684+
// this resolving dialer resolves dials unconditionally, for sake of cache or resolving privacy
685+
dialerRoot = dialer.NewNameResolvingDialer(dialerRoot, nameResolver)
665686
if len(args.proxy) > 0 {
666687
for _, proxy := range args.proxy {
667688
if proxy.literal {
@@ -692,28 +713,10 @@ func run() int {
692713
}
693714
}
694715

695-
dialerRoot = dialer.NewFilterDialer(filterRoot.Access, dialerRoot) // must follow after resolving in chain
696-
697-
var nameResolver dialer.Resolver = net.DefaultResolver
698-
if len(args.dnsServers) > 0 {
699-
nameResolver, err = resolver.FastFromURLs(args.dnsServers...)
700-
if err != nil {
701-
mainLogger.Critical("Failed to create name resolver: %v", err)
702-
return 3
703-
}
704-
}
705-
nameResolver = resolver.Prefer(nameResolver, args.dnsPreferAddress.Value())
706-
if args.dnsCacheTTL > 0 {
707-
dialerRoot = dialer.NewNameResolveCachingDialer(
708-
dialerRoot,
709-
nameResolver,
710-
args.dnsCacheTTL,
711-
args.dnsCacheNegTTL,
712-
args.dnsCacheTimeout,
713-
)
714-
} else {
715-
dialerRoot = dialer.NewNameResolvingDialer(dialerRoot, nameResolver)
716-
}
716+
dialerRoot = dialer.NewFilterDialer(filterRoot.Access, dialerRoot)
717+
// this resolving dialer resolves dials conditionally (unless upstream dialer tells not to)
718+
// for sake of access filtering by destination address
719+
dialerRoot = dialer.NewNameResolvingDialer(dialerRoot, nameResolver)
717720

718721
// unholy plug
719722
if args.tt {

0 commit comments

Comments
 (0)