@@ -2,18 +2,12 @@ package dialer
22
33import (
44 "context"
5- "errors"
6- "fmt"
7- "net"
85 "net/netip"
96 "strings"
107 "time"
118
129 "codeberg.org/yarmak/secache"
13- "github.com/hashicorp/go-multierror"
1410 "golang.org/x/sync/singleflight"
15-
16- "github.com/SenseUnit/dumbproxy/dialer/dto"
1711)
1812
1913type resolverCacheKey struct {
@@ -27,21 +21,18 @@ type resolverCacheValue struct {
2721 err error
2822}
2923
30- type NameResolveCachingDialer struct {
31- resolver Resolver
32- cache secache.Cache [resolverCacheKey , * resolverCacheValue ]
33- sf singleflight.Group
34- posTTL time.Duration
35- negTTL time.Duration
36- timeout time.Duration
37- next Dialer
24+ type CachingResolver struct {
25+ next Resolver
26+ cache secache.Cache [resolverCacheKey , * resolverCacheValue ]
27+ sf singleflight.Group
28+ posTTL time.Duration
29+ negTTL time.Duration
30+ timeout time.Duration
3831}
3932
40- func NewNameResolveCachingDialer (next Dialer , resolver Resolver , posTTL , negTTL , timeout time.Duration ) * NameResolveCachingDialer {
41- // func(c *ttlcache.Cache[resolverCacheKey, resolverCacheValue], key resolverCacheKey) *ttlcache.Item[resolverCacheKey, resolverCacheValue] {
42- // },
43- return & NameResolveCachingDialer {
44- resolver : resolver ,
33+ func NewCachingResolver (next Resolver , posTTL , negTTL , timeout time.Duration ) * CachingResolver {
34+ return & CachingResolver {
35+ next : next ,
4536 cache : * (secache .New [resolverCacheKey , * resolverCacheValue ](
4637 3 ,
4738 func (key resolverCacheKey , item * resolverCacheValue ) bool {
@@ -51,62 +42,40 @@ func NewNameResolveCachingDialer(next Dialer, resolver Resolver, posTTL, negTTL,
5142 posTTL : posTTL ,
5243 negTTL : negTTL ,
5344 timeout : timeout ,
54- next : next ,
5545 }
5646}
5747
58- func (nrcd * NameResolveCachingDialer ) DialContext (ctx context.Context , network , address string ) (net.Conn , error ) {
59- if WantsHostname (ctx , network , address , nrcd .next ) {
60- return nrcd .next .DialContext (ctx , network , address )
61- }
62-
63- host , port , err := net .SplitHostPort (address )
64- if err != nil {
65- return nil , fmt .Errorf ("failed to extract host and port from %s: %w" , address , err )
66- }
67-
48+ func (r * CachingResolver ) LookupNetIP (ctx context.Context , network , host string ) ([]netip.Addr , error ) {
6849 if addr , err := netip .ParseAddr (host ); err == nil {
6950 // literal IP address, just do unmapping
70- return nrcd .next .DialContext (ctx , network , net .JoinHostPort (addr .Unmap ().String (), port ))
71- }
72-
73- var resolveNetwork string
74- switch network {
75- case "udp4" , "tcp4" , "ip4" :
76- resolveNetwork = "ip4"
77- case "udp6" , "tcp6" , "ip6" :
78- resolveNetwork = "ip6"
79- case "udp" , "tcp" , "ip" :
80- resolveNetwork = "ip"
81- default :
82- return nil , fmt .Errorf ("resolving dial %q: unsupported network %q" , address , network )
51+ return r .next .LookupNetIP (ctx , network , addr .Unmap ().String ())
8352 }
8453
8554 host = strings .ToLower (host )
8655 key := resolverCacheKey {
87- network : resolveNetwork ,
56+ network : network ,
8857 host : host ,
8958 }
9059
91- res , ok := nrcd .cache .GetValidOrDelete (key )
60+ res , ok := r .cache .GetValidOrDelete (key )
9261 if ! ok {
93- v , _ , _ := nrcd .sf .Do (key .network + ":" + key .host , func () (any , error ) {
94- ctx , cl := context .WithTimeout (context .Background (), nrcd .timeout )
62+ v , _ , _ := r .sf .Do (key .network + ":" + key .host , func () (any , error ) {
63+ ctx , cl := context .WithTimeout (context .Background (), r .timeout )
9564 defer cl ()
96- res , err := nrcd . resolver .LookupNetIP (ctx , key .network , key .host )
65+ res , err := r . next .LookupNetIP (ctx , key .network , key .host )
9766 for i := range res {
9867 res [i ] = res [i ].Unmap ()
9968 }
100- setTTL := nrcd .negTTL
69+ setTTL := r .negTTL
10170 if err == nil {
102- setTTL = nrcd .posTTL
71+ setTTL = r .posTTL
10372 }
10473 item := & resolverCacheValue {
10574 expires : time .Now ().Add (setTTL ),
10675 addrs : res ,
10776 err : err ,
10877 }
109- nrcd .cache .Set (key , item )
78+ r .cache .Set (key , item )
11079 return item , nil
11180 })
11281 res = v .(* resolverCacheValue )
@@ -116,33 +85,7 @@ func (nrcd *NameResolveCachingDialer) DialContext(ctx context.Context, network,
11685 return nil , res .err
11786 }
11887
119- ctx = dto .OrigDstToContext (ctx , address )
120-
121- var dialErr error
122- var conn net.Conn
123-
124- for _ , ip := range res .addrs {
125- conn , err = nrcd .next .DialContext (ctx , network , net .JoinHostPort (ip .String (), port ))
126- if err == nil {
127- return conn , nil
128- }
129- dialErr = multierror .Append (dialErr , err )
130- var sae dto.StopAddressIteration
131- if errors .As (err , & sae ) {
132- break
133- }
134- }
135-
136- return nil , fmt .Errorf ("failed to dial %s: %w" , address , dialErr )
137- }
138-
139- func (nrcd * NameResolveCachingDialer ) Dial (network , address string ) (net.Conn , error ) {
140- return nrcd .DialContext (context .Background (), network , address )
141- }
142-
143- func (nrcd * NameResolveCachingDialer ) WantsHostname (ctx context.Context , net , address string ) bool {
144- return WantsHostname (ctx , net , address , nrcd .next )
88+ return res .addrs , nil
14589}
14690
147- var _ Dialer = new (NameResolveCachingDialer )
148- var _ HostnameWanter = new (NameResolveCachingDialer )
91+ var _ Resolver = new (CachingResolver )
0 commit comments