Skip to content

Commit 5954fef

Browse files
committed
refactor name caching from dialer to resolver level
1 parent 7ef0b39 commit 5954fef

2 files changed

Lines changed: 33 additions & 102 deletions

File tree

dialer/rescache.go

Lines changed: 22 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,12 @@ package dialer
22

33
import (
44
"context"
5-
"errors"
6-
"fmt"
7-
"net"
85
"net/netip"
96
"strings"
107
"time"
118

129
"codeberg.org/yarmak/secache"
13-
"github.com/hashicorp/go-multierror"
1410
"golang.org/x/sync/singleflight"
15-
16-
"github.com/SenseUnit/dumbproxy/dialer/dto"
1711
)
1812

1913
type resolverCacheKey struct {
@@ -27,21 +21,18 @@ type resolverCacheValue struct {
2721
err error
2822
}
2923

30-
type NameResolveCachingDialer struct {
31-
resolver Resolver
32-
preFilter bool
33-
cache secache.Cache[resolverCacheKey, *resolverCacheValue]
34-
sf singleflight.Group
35-
posTTL time.Duration
36-
negTTL time.Duration
37-
timeout time.Duration
38-
next Dialer
24+
type CachingResolver struct {
25+
next Resolver
26+
cache secache.Cache[resolverCacheKey, *resolverCacheValue]
27+
sf singleflight.Group
28+
posTTL time.Duration
29+
negTTL time.Duration
30+
timeout time.Duration
3931
}
4032

41-
func NewNameResolveCachingDialer(next Dialer, preFilter bool, resolver Resolver, posTTL, negTTL, timeout time.Duration) *NameResolveCachingDialer {
42-
return &NameResolveCachingDialer{
43-
resolver: resolver,
44-
preFilter: preFilter,
33+
func NewCachingResolver(next Resolver, posTTL, negTTL, timeout time.Duration) *CachingResolver {
34+
return &CachingResolver{
35+
next: next,
4536
cache: *(secache.New[resolverCacheKey, *resolverCacheValue](
4637
3,
4738
func(key resolverCacheKey, item *resolverCacheValue) bool {
@@ -51,62 +42,40 @@ func NewNameResolveCachingDialer(next Dialer, preFilter bool, resolver Resolver,
5142
posTTL: posTTL,
5243
negTTL: negTTL,
5344
timeout: timeout,
54-
next: next,
5545
}
5646
}
5747

58-
func (nrcd *NameResolveCachingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
59-
if nrcd.preFilter && WantsHostname(ctx, network, address, nrcd.next) {
60-
return nrcd.next.DialContext(ctx, network, address)
61-
}
62-
63-
host, port, err := net.SplitHostPort(address)
64-
if err != nil {
65-
return nil, fmt.Errorf("failed to extract host and port from %s: %w", address, err)
66-
}
67-
48+
func (r *CachingResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
6849
if addr, err := netip.ParseAddr(host); err == nil {
6950
// literal IP address, just do unmapping
70-
return nrcd.next.DialContext(ctx, network, net.JoinHostPort(addr.Unmap().String(), port))
71-
}
72-
73-
var resolveNetwork string
74-
switch network {
75-
case "udp4", "tcp4", "ip4":
76-
resolveNetwork = "ip4"
77-
case "udp6", "tcp6", "ip6":
78-
resolveNetwork = "ip6"
79-
case "udp", "tcp", "ip":
80-
resolveNetwork = "ip"
81-
default:
82-
return nil, fmt.Errorf("resolving dial %q: unsupported network %q", address, network)
51+
return r.next.LookupNetIP(ctx, network, addr.Unmap().String())
8352
}
8453

8554
host = strings.ToLower(host)
8655
key := resolverCacheKey{
87-
network: resolveNetwork,
56+
network: network,
8857
host: host,
8958
}
9059

91-
res, ok := nrcd.cache.GetValidOrDelete(key)
60+
res, ok := r.cache.GetValidOrDelete(key)
9261
if !ok {
93-
v, _, _ := nrcd.sf.Do(key.network+":"+key.host, func() (any, error) {
94-
ctx, cl := context.WithTimeout(context.Background(), nrcd.timeout)
62+
v, _, _ := r.sf.Do(key.network+":"+key.host, func() (any, error) {
63+
ctx, cl := context.WithTimeout(context.Background(), r.timeout)
9564
defer cl()
96-
res, err := nrcd.resolver.LookupNetIP(ctx, key.network, key.host)
65+
res, err := r.next.LookupNetIP(ctx, key.network, key.host)
9766
for i := range res {
9867
res[i] = res[i].Unmap()
9968
}
100-
setTTL := nrcd.negTTL
69+
setTTL := r.negTTL
10170
if err == nil {
102-
setTTL = nrcd.posTTL
71+
setTTL = r.posTTL
10372
}
10473
item := &resolverCacheValue{
10574
expires: time.Now().Add(setTTL),
10675
addrs: res,
10776
err: err,
10877
}
109-
nrcd.cache.Set(key, item)
78+
r.cache.Set(key, item)
11079
return item, nil
11180
})
11281
res = v.(*resolverCacheValue)
@@ -116,35 +85,7 @@ func (nrcd *NameResolveCachingDialer) DialContext(ctx context.Context, network,
11685
return nil, res.err
11786
}
11887

119-
if nrcd.preFilter {
120-
ctx = dto.OrigDstToContext(ctx, address)
121-
}
122-
123-
var dialErr error
124-
var conn net.Conn
125-
126-
for _, ip := range res.addrs {
127-
conn, err = nrcd.next.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
128-
if err == nil {
129-
return conn, nil
130-
}
131-
dialErr = multierror.Append(dialErr, err)
132-
var sae dto.StopAddressIteration
133-
if errors.As(err, &sae) {
134-
break
135-
}
136-
}
137-
138-
return nil, fmt.Errorf("failed to dial %s: %w", address, dialErr)
139-
}
140-
141-
func (nrcd *NameResolveCachingDialer) Dial(network, address string) (net.Conn, error) {
142-
return nrcd.DialContext(context.Background(), network, address)
143-
}
144-
145-
func (nrcd *NameResolveCachingDialer) WantsHostname(ctx context.Context, net, address string) bool {
146-
return WantsHostname(ctx, net, address, nrcd.next)
88+
return res.addrs, nil
14789
}
14890

149-
var _ Dialer = new(NameResolveCachingDialer)
150-
var _ HostnameWanter = new(NameResolveCachingDialer)
91+
var _ Resolver = new(CachingResolver)

main.go

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -669,20 +669,20 @@ func run() int {
669669
return 3
670670
}
671671
}
672-
nameResolver = resolver.Prefer(nameResolver, args.dnsPreferAddress.Value())
673-
674-
// construct dialers
675-
var dialerRoot dialer.Dialer = dialer.NewBoundDialer(new(net.Dialer), args.sourceIPHints)
676672
if args.dnsCacheTTL > 0 {
677-
dialerRoot = dialer.NewNameResolveCachingDialer(
678-
dialerRoot,
679-
false,
673+
nameResolver = dialer.NewCachingResolver(
680674
nameResolver,
681675
args.dnsCacheTTL,
682676
args.dnsCacheNegTTL,
683677
args.dnsCacheTimeout,
684678
)
685679
}
680+
nameResolver = resolver.Prefer(nameResolver, args.dnsPreferAddress.Value())
681+
682+
// construct dialers
683+
var dialerRoot dialer.Dialer = dialer.NewBoundDialer(new(net.Dialer), args.sourceIPHints)
684+
// this resolving dialer resolves dials unconditionally, for sake of cache or resolving privacy
685+
dialerRoot = dialer.NewNameResolvingDialer(dialerRoot, nameResolver)
686686
if len(args.proxy) > 0 {
687687
for _, proxy := range args.proxy {
688688
if proxy.literal {
@@ -713,20 +713,10 @@ func run() int {
713713
}
714714
}
715715

716-
dialerRoot = dialer.NewFilterDialer(filterRoot.Access, dialerRoot) // must follow after resolving in chain
717-
718-
if args.dnsCacheTTL > 0 {
719-
dialerRoot = dialer.NewNameResolveCachingDialer(
720-
dialerRoot,
721-
true,
722-
nameResolver,
723-
args.dnsCacheTTL,
724-
args.dnsCacheNegTTL,
725-
args.dnsCacheTimeout,
726-
)
727-
} else {
728-
dialerRoot = dialer.NewNameResolvingDialer(dialerRoot, nameResolver)
729-
}
716+
dialerRoot = dialer.NewFilterDialer(filterRoot.Access, dialerRoot)
717+
// this resolving dialer resolves dials conditionally (unless upstream dialer tells not to)
718+
// for sake of access filtering by destination address
719+
dialerRoot = dialer.NewNameResolvingDialer(dialerRoot, nameResolver)
730720

731721
// unholy plug
732722
if args.tt {

0 commit comments

Comments
 (0)