Skip to content

Commit 23af22f

Browse files
committed
Add winiphlpapi
1 parent d9f6eb1 commit 23af22f

File tree

6 files changed

+780
-4
lines changed

6 files changed

+780
-4
lines changed

common/windnsapi/dnsapi_test.go

+2-4
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
1+
//go:build windows
2+
13
package windnsapi
24

35
import (
4-
"runtime"
56
"testing"
67

78
"github.com/stretchr/testify/require"
89
)
910

1011
func TestDNSAPI(t *testing.T) {
11-
if runtime.GOOS != "windows" {
12-
t.SkipNow()
13-
}
1412
t.Parallel()
1513
require.NoError(t, FlushResolverCache())
1614
}

common/winiphlpapi/helper.go

+217
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
//go:build windows
2+
3+
package winiphlpapi
4+
5+
import (
6+
"context"
7+
"encoding/binary"
8+
M "github.com/sagernet/sing/common/metadata"
9+
"net"
10+
"net/netip"
11+
"os"
12+
"time"
13+
"unsafe"
14+
15+
E "github.com/sagernet/sing/common/exceptions"
16+
N "github.com/sagernet/sing/common/network"
17+
)
18+
19+
func LoadEStats() error {
20+
err := modiphlpapi.Load()
21+
if err != nil {
22+
return err
23+
}
24+
err = procGetTcpTable.Find()
25+
if err != nil {
26+
return err
27+
}
28+
err = procGetTcp6Table.Find()
29+
if err != nil {
30+
return err
31+
}
32+
err = procGetPerTcp6ConnectionEStats.Find()
33+
if err != nil {
34+
return err
35+
}
36+
err = procGetPerTcp6ConnectionEStats.Find()
37+
if err != nil {
38+
return err
39+
}
40+
err = procSetPerTcpConnectionEStats.Find()
41+
if err != nil {
42+
return err
43+
}
44+
err = procSetPerTcp6ConnectionEStats.Find()
45+
if err != nil {
46+
return err
47+
}
48+
return nil
49+
}
50+
51+
func LoadExtendedTable() error {
52+
err := modiphlpapi.Load()
53+
if err != nil {
54+
return err
55+
}
56+
err = procGetExtendedTcpTable.Find()
57+
if err != nil {
58+
return err
59+
}
60+
err = procGetExtendedUdpTable.Find()
61+
if err != nil {
62+
return err
63+
}
64+
return nil
65+
}
66+
67+
func FindPid(network string, source netip.AddrPort) (uint32, error) {
68+
switch N.NetworkName(network) {
69+
case N.NetworkTCP:
70+
if source.Addr().Is4() {
71+
tcpTable, err := GetExtendedTcpTable()
72+
if err != nil {
73+
return 0, err
74+
}
75+
for _, row := range tcpTable {
76+
if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) {
77+
return row.DwOwningPid, nil
78+
}
79+
}
80+
} else {
81+
tcpTable, err := GetExtendedTcp6Table()
82+
if err != nil {
83+
return 0, err
84+
}
85+
for _, row := range tcpTable {
86+
if source == netip.AddrPortFrom(netip.AddrFrom16(row.UcLocalAddr), DwordToPort(row.DwLocalPort)) {
87+
return row.DwOwningPid, nil
88+
}
89+
}
90+
}
91+
case N.NetworkUDP:
92+
if source.Addr().Is4() {
93+
udpTable, err := GetExtendedUdpTable()
94+
if err != nil {
95+
return 0, err
96+
}
97+
for _, row := range udpTable {
98+
if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) {
99+
return row.DwOwningPid, nil
100+
}
101+
}
102+
} else {
103+
udpTable, err := GetExtendedUdp6Table()
104+
if err != nil {
105+
return 0, err
106+
}
107+
for _, row := range udpTable {
108+
if source == netip.AddrPortFrom(netip.AddrFrom16(row.UcLocalAddr), DwordToPort(row.DwLocalPort)) {
109+
return row.DwOwningPid, nil
110+
}
111+
}
112+
}
113+
}
114+
return 0, E.New("process not found for ", source)
115+
}
116+
117+
func WriteAndWaitAck(ctx context.Context, conn net.Conn, payload []byte) error {
118+
source := M.AddrPortFromNet(conn.LocalAddr())
119+
destination := M.AddrPortFromNet(conn.RemoteAddr())
120+
if source.Addr().Is4() {
121+
tcpTable, err := GetTcpTable()
122+
if err != nil {
123+
return err
124+
}
125+
var tcpRow *MibTcpRow
126+
for _, row := range tcpTable {
127+
if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) ||
128+
destination == netip.AddrPortFrom(DwordToAddr(row.DwRemoteAddr), DwordToPort(row.DwRemotePort)) {
129+
tcpRow = &row
130+
break
131+
}
132+
}
133+
if tcpRow == nil {
134+
return E.New("row not found for: ", source)
135+
}
136+
err = SetPerTcpConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
137+
EnableCollection: true,
138+
})
139+
if err != nil {
140+
return os.NewSyscallError("SetPerTcpConnectionEStatsSendBufferV0", err)
141+
}
142+
defer SetPerTcpConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
143+
EnableCollection: false,
144+
})
145+
_, err = conn.Write(payload)
146+
if err != nil {
147+
return err
148+
}
149+
for {
150+
select {
151+
case <-ctx.Done():
152+
return ctx.Err()
153+
default:
154+
}
155+
eStstsSendBuffer, err := GetPerTcpConnectionEStatsSendBuffer(tcpRow)
156+
if err != nil {
157+
return err
158+
}
159+
if eStstsSendBuffer.CurRetxQueue == 0 {
160+
return nil
161+
}
162+
time.Sleep(10 * time.Millisecond)
163+
}
164+
} else {
165+
tcpTable, err := GetTcp6Table()
166+
if err != nil {
167+
return err
168+
}
169+
var tcpRow *MibTcp6Row
170+
for _, row := range tcpTable {
171+
if source == netip.AddrPortFrom(netip.AddrFrom16(row.LocalAddr), DwordToPort(row.LocalPort)) ||
172+
destination == netip.AddrPortFrom(netip.AddrFrom16(row.RemoteAddr), DwordToPort(row.RemotePort)) {
173+
tcpRow = &row
174+
break
175+
}
176+
}
177+
if tcpRow == nil {
178+
return E.New("row not found for: ", source)
179+
}
180+
err = SetPerTcp6ConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
181+
EnableCollection: true,
182+
})
183+
if err != nil {
184+
return os.NewSyscallError("SetPerTcpConnectionEStatsSendBufferV0", err)
185+
}
186+
defer SetPerTcp6ConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
187+
EnableCollection: false,
188+
})
189+
_, err = conn.Write(payload)
190+
if err != nil {
191+
return err
192+
}
193+
for {
194+
select {
195+
case <-ctx.Done():
196+
return ctx.Err()
197+
default:
198+
}
199+
eStstsSendBuffer, err := GetPerTcp6ConnectionEStatsSendBuffer(tcpRow)
200+
if err != nil {
201+
return err
202+
}
203+
if eStstsSendBuffer.CurRetxQueue == 0 {
204+
return nil
205+
}
206+
time.Sleep(10 * time.Millisecond)
207+
}
208+
}
209+
}
210+
211+
func DwordToAddr(addr uint32) netip.Addr {
212+
return netip.AddrFrom4(*(*[4]byte)(unsafe.Pointer(&addr)))
213+
}
214+
215+
func DwordToPort(dword uint32) uint16 {
216+
return binary.BigEndian.Uint16((*[4]byte)(unsafe.Pointer(&dword))[:])
217+
}

0 commit comments

Comments
 (0)