Skip to content

Commit 383b4ac

Browse files
committed
Add winiphlpapi
1 parent d9f6eb1 commit 383b4ac

File tree

6 files changed

+781
-4
lines changed

6 files changed

+781
-4
lines changed

common/windnsapi/dnsapi_test.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
1+
//go:build windows
2+
13
package windnsapi
24

35
import (
4-
"runtime"
56
"testing"
67

78
"github.com/stretchr/testify/require"
89
)
910

1011
func TestDNSAPI(t *testing.T) {
11-
if runtime.GOOS != "windows" {
12-
t.SkipNow()
13-
}
1412
t.Parallel()
1513
require.NoError(t, FlushResolverCache())
1614
}

common/winiphlpapi/helper.go

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
//go:build windows
2+
3+
package winiphlpapi
4+
5+
import (
6+
"context"
7+
"encoding/binary"
8+
"io"
9+
"net/netip"
10+
"os"
11+
"time"
12+
"unsafe"
13+
14+
E "github.com/sagernet/sing/common/exceptions"
15+
N "github.com/sagernet/sing/common/network"
16+
)
17+
18+
func LoadEStats() error {
19+
err := modiphlpapi.Load()
20+
if err != nil {
21+
return err
22+
}
23+
err = procGetTcpTable.Find()
24+
if err != nil {
25+
return err
26+
}
27+
err = procGetTcp6Table.Find()
28+
if err != nil {
29+
return err
30+
}
31+
err = procGetPerTcp6ConnectionEStats.Find()
32+
if err != nil {
33+
return err
34+
}
35+
err = procGetPerTcp6ConnectionEStats.Find()
36+
if err != nil {
37+
return err
38+
}
39+
err = procSetPerTcpConnectionEStats.Find()
40+
if err != nil {
41+
return err
42+
}
43+
err = procSetPerTcp6ConnectionEStats.Find()
44+
if err != nil {
45+
return err
46+
}
47+
return nil
48+
}
49+
50+
func LoadExtendedTable() error {
51+
err := modiphlpapi.Load()
52+
if err != nil {
53+
return err
54+
}
55+
err = procGetExtendedTcpTable.Find()
56+
if err != nil {
57+
return err
58+
}
59+
err = procGetExtendedUdpTable.Find()
60+
if err != nil {
61+
return err
62+
}
63+
return nil
64+
}
65+
66+
func FindPid(network string, source netip.AddrPort) (uint32, error) {
67+
switch N.NetworkName(network) {
68+
case N.NetworkTCP:
69+
if source.Addr().Is4() {
70+
tcpTable, err := GetExtendedTcpTable()
71+
if err != nil {
72+
return 0, err
73+
}
74+
for _, row := range tcpTable {
75+
if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) {
76+
return row.DwOwningPid, nil
77+
}
78+
}
79+
} else {
80+
tcpTable, err := GetExtendedTcp6Table()
81+
if err != nil {
82+
return 0, err
83+
}
84+
for _, row := range tcpTable {
85+
if source == netip.AddrPortFrom(netip.AddrFrom16(row.UcLocalAddr), DwordToPort(row.DwLocalPort)) {
86+
return row.DwOwningPid, nil
87+
}
88+
}
89+
}
90+
case N.NetworkUDP:
91+
if source.Addr().Is4() {
92+
udpTable, err := GetExtendedUdpTable()
93+
if err != nil {
94+
return 0, err
95+
}
96+
for _, row := range udpTable {
97+
if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) {
98+
return row.DwOwningPid, nil
99+
}
100+
}
101+
} else {
102+
udpTable, err := GetExtendedUdp6Table()
103+
if err != nil {
104+
return 0, err
105+
}
106+
for _, row := range udpTable {
107+
if source == netip.AddrPortFrom(netip.AddrFrom16(row.UcLocalAddr), DwordToPort(row.DwLocalPort)) {
108+
return row.DwOwningPid, nil
109+
}
110+
}
111+
}
112+
}
113+
return 0, E.New("process not found for ", source)
114+
}
115+
116+
func WriteAndWaitAck(ctx context.Context, source netip.AddrPort, destination netip.AddrPort, writer io.Writer, payload []byte) error {
117+
if source.Addr().Is4() {
118+
tcpTable, err := GetTcpTable()
119+
if err != nil {
120+
return err
121+
}
122+
var tcpRow *MibTcpRow
123+
for _, row := range tcpTable {
124+
if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) ||
125+
destination == netip.AddrPortFrom(DwordToAddr(row.DwRemoteAddr), DwordToPort(row.DwRemotePort)) {
126+
tcpRow = &row
127+
break
128+
}
129+
}
130+
if tcpRow == nil {
131+
return E.New("row not found for: ", source)
132+
}
133+
err = SetPerTcpConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
134+
EnableCollection: true,
135+
})
136+
if err != nil {
137+
return os.NewSyscallError("SetPerTcpConnectionEStatsSendBufferV0", err)
138+
}
139+
defer SetPerTcpConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
140+
EnableCollection: false,
141+
})
142+
_, err = writer.Write(payload)
143+
if err != nil {
144+
return err
145+
}
146+
for {
147+
select {
148+
case <-ctx.Done():
149+
return ctx.Err()
150+
default:
151+
}
152+
eStstsSendBuffer, err := GetPerTcpConnectionEStatsSendBuffer(tcpRow)
153+
if err != nil {
154+
return err
155+
}
156+
if eStstsSendBuffer.CurRetxQueue == 0 {
157+
return nil
158+
}
159+
time.Sleep(10 * time.Millisecond)
160+
}
161+
} else {
162+
tcpTable, err := GetTcp6Table()
163+
if err != nil {
164+
return err
165+
}
166+
var tcpRow *MibTcp6Row
167+
for _, row := range tcpTable {
168+
if source == netip.AddrPortFrom(netip.AddrFrom16(row.LocalAddr), DwordToPort(row.LocalPort)) ||
169+
destination == netip.AddrPortFrom(netip.AddrFrom16(row.RemoteAddr), DwordToPort(row.RemotePort)) {
170+
tcpRow = &row
171+
break
172+
}
173+
}
174+
if tcpRow == nil {
175+
return E.New("row not found for: ", source)
176+
}
177+
err = SetPerTcp6ConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
178+
EnableCollection: true,
179+
})
180+
if err != nil {
181+
return os.NewSyscallError("SetPerTcpConnectionEStatsSendBufferV0", err)
182+
}
183+
defer SetPerTcp6ConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
184+
EnableCollection: false,
185+
})
186+
_, err = writer.Write(payload)
187+
if err != nil {
188+
return err
189+
}
190+
for {
191+
select {
192+
case <-ctx.Done():
193+
return ctx.Err()
194+
default:
195+
}
196+
eStstsSendBuffer, err := GetPerTcp6ConnectionEStatsSendBuffer(tcpRow)
197+
if err != nil {
198+
return err
199+
}
200+
if eStstsSendBuffer.CurRetxQueue == 0 {
201+
return nil
202+
}
203+
time.Sleep(10 * time.Millisecond)
204+
}
205+
}
206+
}
207+
208+
func DwordToAddr(addr uint32) netip.Addr {
209+
return netip.AddrFrom4(*(*[4]byte)(unsafe.Pointer(&addr)))
210+
}
211+
212+
func DwordToPort(dword uint32) uint16 {
213+
return binary.BigEndian.Uint16((*[4]byte)(unsafe.Pointer(&dword))[:])
214+
}

0 commit comments

Comments
 (0)