@@ -9,54 +9,95 @@ import (
9
9
N "github.com/sagernet/sing/common/network"
10
10
)
11
11
12
- type NATPacketConn struct {
12
+ type NATPacketConn interface {
13
13
N.NetPacketConn
14
- origin M.Socksaddr
15
- destination M.Socksaddr
14
+ UpdateDestination (destinationAddress netip.Addr )
16
15
}
17
16
18
- func NewNATPacketConn (conn N.NetPacketConn , origin M.Socksaddr , destination M.Socksaddr ) * NATPacketConn {
19
- return & NATPacketConn {
17
+ func NewUnidirectionalNATPacketConn (conn N.NetPacketConn , origin M.Socksaddr , destination M.Socksaddr ) NATPacketConn {
18
+ return & unidirectionalNATPacketConn {
20
19
NetPacketConn : conn ,
21
20
origin : origin ,
22
21
destination : destination ,
23
22
}
24
23
}
25
24
26
- func (c * NATPacketConn ) ReadFrom (p []byte ) (n int , addr net.Addr , err error ) {
25
+ func NewNATPacketConn (conn N.NetPacketConn , origin M.Socksaddr , destination M.Socksaddr ) NATPacketConn {
26
+ return & bidirectionalNATPacketConn {
27
+ NetPacketConn : conn ,
28
+ origin : origin ,
29
+ destination : destination ,
30
+ }
31
+ }
32
+
33
+ type unidirectionalNATPacketConn struct {
34
+ N.NetPacketConn
35
+ origin M.Socksaddr
36
+ destination M.Socksaddr
37
+ }
38
+
39
+ func (c * unidirectionalNATPacketConn ) WriteTo (p []byte , addr net.Addr ) (n int , err error ) {
40
+ if M .SocksaddrFromNet (addr ) == c .destination {
41
+ addr = c .origin .UDPAddr ()
42
+ }
43
+ return c .NetPacketConn .WriteTo (p , addr )
44
+ }
45
+
46
+ func (c * unidirectionalNATPacketConn ) WritePacket (buffer * buf.Buffer , destination M.Socksaddr ) error {
47
+ if destination == c .destination {
48
+ destination = c .origin
49
+ }
50
+ return c .NetPacketConn .WritePacket (buffer , destination )
51
+ }
52
+
53
+ func (c * unidirectionalNATPacketConn ) UpdateDestination (destinationAddress netip.Addr ) {
54
+ c .destination = M .SocksaddrFrom (destinationAddress , c .destination .Port )
55
+ }
56
+
57
+ func (c * unidirectionalNATPacketConn ) Upstream () any {
58
+ return c .NetPacketConn
59
+ }
60
+
61
+ type bidirectionalNATPacketConn struct {
62
+ N.NetPacketConn
63
+ origin M.Socksaddr
64
+ destination M.Socksaddr
65
+ }
66
+
67
+ func (c * bidirectionalNATPacketConn ) ReadFrom (p []byte ) (n int , addr net.Addr , err error ) {
27
68
n , addr , err = c .NetPacketConn .ReadFrom (p )
28
69
if err == nil && M .SocksaddrFromNet (addr ) == c .origin {
29
70
addr = c .destination .UDPAddr ()
30
71
}
31
72
return
32
73
}
33
74
34
- func (c * NATPacketConn ) WriteTo (p []byte , addr net.Addr ) (n int , err error ) {
75
+ func (c * bidirectionalNATPacketConn ) WriteTo (p []byte , addr net.Addr ) (n int , err error ) {
35
76
if M .SocksaddrFromNet (addr ) == c .destination {
36
77
addr = c .origin .UDPAddr ()
37
78
}
38
79
return c .NetPacketConn .WriteTo (p , addr )
39
80
}
40
81
41
- func (c * NATPacketConn ) ReadPacket (buffer * buf.Buffer ) (destination M.Socksaddr , err error ) {
82
+ func (c * bidirectionalNATPacketConn ) ReadPacket (buffer * buf.Buffer ) (destination M.Socksaddr , err error ) {
42
83
destination , err = c .NetPacketConn .ReadPacket (buffer )
43
84
if destination == c .origin {
44
85
destination = c .destination
45
86
}
46
87
return
47
88
}
48
89
49
- func (c * NATPacketConn ) WritePacket (buffer * buf.Buffer , destination M.Socksaddr ) error {
90
+ func (c * bidirectionalNATPacketConn ) WritePacket (buffer * buf.Buffer , destination M.Socksaddr ) error {
50
91
if destination == c .destination {
51
92
destination = c .origin
52
93
}
53
94
return c .NetPacketConn .WritePacket (buffer , destination )
54
95
}
55
96
56
- func (c * NATPacketConn ) UpdateDestination (destinationAddress netip.Addr ) {
97
+ func (c * bidirectionalNATPacketConn ) UpdateDestination (destinationAddress netip.Addr ) {
57
98
c .destination = M .SocksaddrFrom (destinationAddress , c .destination .Port )
58
99
}
59
100
60
- func (c * NATPacketConn ) Upstream () any {
101
+ func (c * bidirectionalNATPacketConn ) Upstream () any {
61
102
return c .NetPacketConn
62
103
}
0 commit comments