@@ -2,18 +2,12 @@ package dialer
22
33import (
44 "context"
5- "errors"
6- "fmt"
7- "net"
85 "net/netip"
96 "strings"
107 "time"
118
129 "codeberg.org/yarmak/secache"
13- "github.com/hashicorp/go-multierror"
1410 "golang.org/x/sync/singleflight"
15-
16- "github.com/SenseUnit/dumbproxy/dialer/dto"
1711)
1812
1913type resolverCacheKey struct {
@@ -27,21 +21,18 @@ type resolverCacheValue struct {
2721 err error
2822}
2923
30- type NameResolveCachingDialer struct {
31- resolver Resolver
32- preFilter bool
33- cache secache.Cache [resolverCacheKey , * resolverCacheValue ]
34- sf singleflight.Group
35- posTTL time.Duration
36- negTTL time.Duration
37- timeout time.Duration
38- next Dialer
24+ type CachingResolver struct {
25+ next Resolver
26+ cache secache.Cache [resolverCacheKey , * resolverCacheValue ]
27+ sf singleflight.Group
28+ posTTL time.Duration
29+ negTTL time.Duration
30+ timeout time.Duration
3931}
4032
41- func NewNameResolveCachingDialer (next Dialer , preFilter bool , resolver Resolver , posTTL , negTTL , timeout time.Duration ) * NameResolveCachingDialer {
42- return & NameResolveCachingDialer {
43- resolver : resolver ,
44- preFilter : preFilter ,
33+ func NewCachingResolver (next Resolver , posTTL , negTTL , timeout time.Duration ) * CachingResolver {
34+ return & CachingResolver {
35+ next : next ,
4536 cache : * (secache .New [resolverCacheKey , * resolverCacheValue ](
4637 3 ,
4738 func (key resolverCacheKey , item * resolverCacheValue ) bool {
@@ -51,62 +42,40 @@ func NewNameResolveCachingDialer(next Dialer, preFilter bool, resolver Resolver,
5142 posTTL : posTTL ,
5243 negTTL : negTTL ,
5344 timeout : timeout ,
54- next : next ,
5545 }
5646}
5747
58- func (nrcd * NameResolveCachingDialer ) DialContext (ctx context.Context , network , address string ) (net.Conn , error ) {
59- if nrcd .preFilter && WantsHostname (ctx , network , address , nrcd .next ) {
60- return nrcd .next .DialContext (ctx , network , address )
61- }
62-
63- host , port , err := net .SplitHostPort (address )
64- if err != nil {
65- return nil , fmt .Errorf ("failed to extract host and port from %s: %w" , address , err )
66- }
67-
48+ func (r * CachingResolver ) LookupNetIP (ctx context.Context , network , host string ) ([]netip.Addr , error ) {
6849 if addr , err := netip .ParseAddr (host ); err == nil {
6950 // literal IP address, just do unmapping
70- return nrcd .next .DialContext (ctx , network , net .JoinHostPort (addr .Unmap ().String (), port ))
71- }
72-
73- var resolveNetwork string
74- switch network {
75- case "udp4" , "tcp4" , "ip4" :
76- resolveNetwork = "ip4"
77- case "udp6" , "tcp6" , "ip6" :
78- resolveNetwork = "ip6"
79- case "udp" , "tcp" , "ip" :
80- resolveNetwork = "ip"
81- default :
82- return nil , fmt .Errorf ("resolving dial %q: unsupported network %q" , address , network )
51+ return r .next .LookupNetIP (ctx , network , addr .Unmap ().String ())
8352 }
8453
8554 host = strings .ToLower (host )
8655 key := resolverCacheKey {
87- network : resolveNetwork ,
56+ network : network ,
8857 host : host ,
8958 }
9059
91- res , ok := nrcd .cache .GetValidOrDelete (key )
60+ res , ok := r .cache .GetValidOrDelete (key )
9261 if ! ok {
93- v , _ , _ := nrcd .sf .Do (key .network + ":" + key .host , func () (any , error ) {
94- ctx , cl := context .WithTimeout (context .Background (), nrcd .timeout )
62+ v , _ , _ := r .sf .Do (key .network + ":" + key .host , func () (any , error ) {
63+ ctx , cl := context .WithTimeout (context .Background (), r .timeout )
9564 defer cl ()
96- res , err := nrcd . resolver .LookupNetIP (ctx , key .network , key .host )
65+ res , err := r . next .LookupNetIP (ctx , key .network , key .host )
9766 for i := range res {
9867 res [i ] = res [i ].Unmap ()
9968 }
100- setTTL := nrcd .negTTL
69+ setTTL := r .negTTL
10170 if err == nil {
102- setTTL = nrcd .posTTL
71+ setTTL = r .posTTL
10372 }
10473 item := & resolverCacheValue {
10574 expires : time .Now ().Add (setTTL ),
10675 addrs : res ,
10776 err : err ,
10877 }
109- nrcd .cache .Set (key , item )
78+ r .cache .Set (key , item )
11079 return item , nil
11180 })
11281 res = v .(* resolverCacheValue )
@@ -116,35 +85,7 @@ func (nrcd *NameResolveCachingDialer) DialContext(ctx context.Context, network,
11685 return nil , res .err
11786 }
11887
119- if nrcd .preFilter {
120- ctx = dto .OrigDstToContext (ctx , address )
121- }
122-
123- var dialErr error
124- var conn net.Conn
125-
126- for _ , ip := range res .addrs {
127- conn , err = nrcd .next .DialContext (ctx , network , net .JoinHostPort (ip .String (), port ))
128- if err == nil {
129- return conn , nil
130- }
131- dialErr = multierror .Append (dialErr , err )
132- var sae dto.StopAddressIteration
133- if errors .As (err , & sae ) {
134- break
135- }
136- }
137-
138- return nil , fmt .Errorf ("failed to dial %s: %w" , address , dialErr )
139- }
140-
141- func (nrcd * NameResolveCachingDialer ) Dial (network , address string ) (net.Conn , error ) {
142- return nrcd .DialContext (context .Background (), network , address )
143- }
144-
145- func (nrcd * NameResolveCachingDialer ) WantsHostname (ctx context.Context , net , address string ) bool {
146- return WantsHostname (ctx , net , address , nrcd .next )
88+ return res .addrs , nil
14789}
14890
149- var _ Dialer = new (NameResolveCachingDialer )
150- var _ HostnameWanter = new (NameResolveCachingDialer )
91+ var _ Resolver = new (CachingResolver )
0 commit comments