From 8dff604468ff8a26def622ac531e78210d8f9418 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 29 Jan 2025 19:59:59 +0800 Subject: [PATCH 1/5] Add winiphlpapi --- common/windnsapi/dnsapi_test.go | 6 +- common/winiphlpapi/helper.go | 217 +++++++++++++++++ common/winiphlpapi/iphlpapi.go | 313 +++++++++++++++++++++++++ common/winiphlpapi/iphlpapi_test.go | 90 +++++++ common/winiphlpapi/syscall_windows.go | 27 +++ common/winiphlpapi/zsyscall_windows.go | 131 +++++++++++ 6 files changed, 780 insertions(+), 4 deletions(-) create mode 100644 common/winiphlpapi/helper.go create mode 100644 common/winiphlpapi/iphlpapi.go create mode 100644 common/winiphlpapi/iphlpapi_test.go create mode 100644 common/winiphlpapi/syscall_windows.go create mode 100644 common/winiphlpapi/zsyscall_windows.go diff --git a/common/windnsapi/dnsapi_test.go b/common/windnsapi/dnsapi_test.go index adf582d0..c5ea8310 100644 --- a/common/windnsapi/dnsapi_test.go +++ b/common/windnsapi/dnsapi_test.go @@ -1,16 +1,14 @@ +//go:build windows + package windnsapi import ( - "runtime" "testing" "github.com/stretchr/testify/require" ) func TestDNSAPI(t *testing.T) { - if runtime.GOOS != "windows" { - t.SkipNow() - } t.Parallel() require.NoError(t, FlushResolverCache()) } diff --git a/common/winiphlpapi/helper.go b/common/winiphlpapi/helper.go new file mode 100644 index 00000000..aace6c8d --- /dev/null +++ b/common/winiphlpapi/helper.go @@ -0,0 +1,217 @@ +//go:build windows + +package winiphlpapi + +import ( + "context" + "encoding/binary" + M "github.com/sagernet/sing/common/metadata" + "net" + "net/netip" + "os" + "time" + "unsafe" + + E "github.com/sagernet/sing/common/exceptions" + N "github.com/sagernet/sing/common/network" +) + +func LoadEStats() error { + err := modiphlpapi.Load() + if err != nil { + return err + } + err = procGetTcpTable.Find() + if err != nil { + return err + } + err = procGetTcp6Table.Find() + if err != nil { + return err + } + err = procGetPerTcp6ConnectionEStats.Find() + if err != nil { + return err + } + err = procGetPerTcp6ConnectionEStats.Find() + if err != nil { + return err + } + err = procSetPerTcpConnectionEStats.Find() + if err != nil { + return err + } + err = procSetPerTcp6ConnectionEStats.Find() + if err != nil { + return err + } + return nil +} + +func LoadExtendedTable() error { + err := modiphlpapi.Load() + if err != nil { + return err + } + err = procGetExtendedTcpTable.Find() + if err != nil { + return err + } + err = procGetExtendedUdpTable.Find() + if err != nil { + return err + } + return nil +} + +func FindPid(network string, source netip.AddrPort) (uint32, error) { + switch N.NetworkName(network) { + case N.NetworkTCP: + if source.Addr().Is4() { + tcpTable, err := GetExtendedTcpTable() + if err != nil { + return 0, err + } + for _, row := range tcpTable { + if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) { + return row.DwOwningPid, nil + } + } + } else { + tcpTable, err := GetExtendedTcp6Table() + if err != nil { + return 0, err + } + for _, row := range tcpTable { + if source == netip.AddrPortFrom(netip.AddrFrom16(row.UcLocalAddr), DwordToPort(row.DwLocalPort)) { + return row.DwOwningPid, nil + } + } + } + case N.NetworkUDP: + if source.Addr().Is4() { + udpTable, err := GetExtendedUdpTable() + if err != nil { + return 0, err + } + for _, row := range udpTable { + if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) { + return row.DwOwningPid, nil + } + } + } else { + udpTable, err := GetExtendedUdp6Table() + if err != nil { + return 0, err + } + for _, row := range udpTable { + if source == netip.AddrPortFrom(netip.AddrFrom16(row.UcLocalAddr), DwordToPort(row.DwLocalPort)) { + return row.DwOwningPid, nil + } + } + } + } + return 0, E.New("process not found for ", source) +} + +func WriteAndWaitAck(ctx context.Context, conn net.Conn, payload []byte) error { + source := M.AddrPortFromNet(conn.LocalAddr()) + destination := M.AddrPortFromNet(conn.RemoteAddr()) + if source.Addr().Is4() { + tcpTable, err := GetTcpTable() + if err != nil { + return err + } + var tcpRow *MibTcpRow + for _, row := range tcpTable { + if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) || + destination == netip.AddrPortFrom(DwordToAddr(row.DwRemoteAddr), DwordToPort(row.DwRemotePort)) { + tcpRow = &row + break + } + } + if tcpRow == nil { + return E.New("row not found for: ", source) + } + err = SetPerTcpConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{ + EnableCollection: true, + }) + if err != nil { + return os.NewSyscallError("SetPerTcpConnectionEStatsSendBufferV0", err) + } + defer SetPerTcpConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{ + EnableCollection: false, + }) + _, err = conn.Write(payload) + if err != nil { + return err + } + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + eStstsSendBuffer, err := GetPerTcpConnectionEStatsSendBuffer(tcpRow) + if err != nil { + return err + } + if eStstsSendBuffer.CurRetxQueue == 0 { + return nil + } + time.Sleep(10 * time.Millisecond) + } + } else { + tcpTable, err := GetTcp6Table() + if err != nil { + return err + } + var tcpRow *MibTcp6Row + for _, row := range tcpTable { + if source == netip.AddrPortFrom(netip.AddrFrom16(row.LocalAddr), DwordToPort(row.LocalPort)) || + destination == netip.AddrPortFrom(netip.AddrFrom16(row.RemoteAddr), DwordToPort(row.RemotePort)) { + tcpRow = &row + break + } + } + if tcpRow == nil { + return E.New("row not found for: ", source) + } + err = SetPerTcp6ConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{ + EnableCollection: true, + }) + if err != nil { + return os.NewSyscallError("SetPerTcpConnectionEStatsSendBufferV0", err) + } + defer SetPerTcp6ConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{ + EnableCollection: false, + }) + _, err = conn.Write(payload) + if err != nil { + return err + } + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + eStstsSendBuffer, err := GetPerTcp6ConnectionEStatsSendBuffer(tcpRow) + if err != nil { + return err + } + if eStstsSendBuffer.CurRetxQueue == 0 { + return nil + } + time.Sleep(10 * time.Millisecond) + } + } +} + +func DwordToAddr(addr uint32) netip.Addr { + return netip.AddrFrom4(*(*[4]byte)(unsafe.Pointer(&addr))) +} + +func DwordToPort(dword uint32) uint16 { + return binary.BigEndian.Uint16((*[4]byte)(unsafe.Pointer(&dword))[:]) +} diff --git a/common/winiphlpapi/iphlpapi.go b/common/winiphlpapi/iphlpapi.go new file mode 100644 index 00000000..74e5b90e --- /dev/null +++ b/common/winiphlpapi/iphlpapi.go @@ -0,0 +1,313 @@ +//go:build windows + +package winiphlpapi + +import ( + "errors" + "os" + "unsafe" + + "golang.org/x/sys/windows" +) + +const ( + TcpTableBasicListener uint32 = iota + TcpTableBasicConnections + TcpTableBasicAll + TcpTableOwnerPidListener + TcpTableOwnerPidConnections + TcpTableOwnerPidAll + TcpTableOwnerModuleListener + TcpTableOwnerModuleConnections + TcpTableOwnerModuleAll +) + +const ( + UdpTableBasic uint32 = iota + UdpTableOwnerPid + UdpTableOwnerModule +) + +const ( + TcpConnectionEstatsSynOpts uint32 = iota + TcpConnectionEstatsData + TcpConnectionEstatsSndCong + TcpConnectionEstatsPath + TcpConnectionEstatsSendBuff + TcpConnectionEstatsRec + TcpConnectionEstatsObsRec + TcpConnectionEstatsBandwidth + TcpConnectionEstatsFineRtt + TcpConnectionEstatsMaximum +) + +type MibTcpTable struct { + DwNumEntries uint32 + Table [1]MibTcpRow +} + +type MibTcpRow struct { + DwState uint32 + DwLocalAddr uint32 + DwLocalPort uint32 + DwRemoteAddr uint32 + DwRemotePort uint32 +} + +type MibTcp6Table struct { + DwNumEntries uint32 + Table [1]MibTcp6Row +} + +type MibTcp6Row struct { + State uint32 + LocalAddr [16]byte + LocalScopeId uint32 + LocalPort uint32 + RemoteAddr [16]byte + RemoteScopeId uint32 + RemotePort uint32 +} + +type MibTcpTableOwnerPid struct { + DwNumEntries uint32 + Table [1]MibTcpRowOwnerPid +} + +type MibTcpRowOwnerPid struct { + DwState uint32 + DwLocalAddr uint32 + DwLocalPort uint32 + DwRemoteAddr uint32 + DwRemotePort uint32 + DwOwningPid uint32 +} + +type MibTcp6TableOwnerPid struct { + DwNumEntries uint32 + Table [1]MibTcp6RowOwnerPid +} + +type MibTcp6RowOwnerPid struct { + UcLocalAddr [16]byte + DwLocalScopeId uint32 + DwLocalPort uint32 + UcRemoteAddr [16]byte + DwRemoteScopeId uint32 + DwRemotePort uint32 + DwState uint32 + DwOwningPid uint32 +} + +type MibUdpTableOwnerPid struct { + DwNumEntries uint32 + Table [1]MibUdpRowOwnerPid +} + +type MibUdpRowOwnerPid struct { + DwLocalAddr uint32 + DwLocalPort uint32 + DwOwningPid uint32 +} + +type MibUdp6TableOwnerPid struct { + DwNumEntries uint32 + Table [1]MibUdp6RowOwnerPid +} + +type MibUdp6RowOwnerPid struct { + UcLocalAddr [16]byte + DwLocalScopeId uint32 + DwLocalPort uint32 + DwOwningPid uint32 +} + +type TcpEstatsSendBufferRodV0 struct { + CurRetxQueue uint64 + MaxRetxQueue uint64 + CurAppWQueue uint64 + MaxAppWQueue uint64 +} + +type TcpEstatsSendBuffRwV0 struct { + EnableCollection bool +} + +const ( + offsetOfMibTcpTable = unsafe.Offsetof(MibTcpTable{}.Table) + offsetOfMibTcp6Table = unsafe.Offsetof(MibTcp6Table{}.Table) + offsetOfMibTcpTableOwnerPid = unsafe.Offsetof(MibTcpTableOwnerPid{}.Table) + offsetOfMibTcp6TableOwnerPid = unsafe.Offsetof(MibTcpTableOwnerPid{}.Table) + offsetOfMibUdpTableOwnerPid = unsafe.Offsetof(MibUdpTableOwnerPid{}.Table) + offsetOfMibUdp6TableOwnerPid = unsafe.Offsetof(MibUdp6TableOwnerPid{}.Table) + sizeOfTcpEstatsSendBuffRwV0 = unsafe.Sizeof(TcpEstatsSendBuffRwV0{}) + sizeOfTcpEstatsSendBufferRodV0 = unsafe.Sizeof(TcpEstatsSendBufferRodV0{}) +) + +func GetTcpTable() ([]MibTcpRow, error) { + var size uint32 + err := getTcpTable(nil, &size, false) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, err + } + for { + table := make([]byte, size) + err = getTcpTable(&table[0], &size, false) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, err + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibTcpRow)(unsafe.Pointer(&table[offsetOfMibTcpTable])), dwNumEntries), nil + } +} + +func GetTcp6Table() ([]MibTcp6Row, error) { + var size uint32 + err := getTcp6Table(nil, &size, false) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, err + } + for { + table := make([]byte, size) + err = getTcp6Table(&table[0], &size, false) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, err + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibTcp6Row)(unsafe.Pointer(&table[offsetOfMibTcp6Table])), dwNumEntries), nil + } +} + +func GetExtendedTcpTable() ([]MibTcpRowOwnerPid, error) { + var size uint32 + err := getExtendedTcpTable(nil, &size, false, windows.AF_INET, TcpTableOwnerPidConnections, 0) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, os.NewSyscallError("GetExtendedTcpTable", err) + } + for { + table := make([]byte, size) + err = getExtendedTcpTable(&table[0], &size, false, windows.AF_INET, TcpTableOwnerPidConnections, 0) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, os.NewSyscallError("GetExtendedTcpTable", err) + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibTcpRowOwnerPid)(unsafe.Pointer(&table[offsetOfMibTcpTableOwnerPid])), dwNumEntries), nil + } +} + +func GetExtendedTcp6Table() ([]MibTcp6RowOwnerPid, error) { + var size uint32 + err := getExtendedTcpTable(nil, &size, false, windows.AF_INET6, TcpTableOwnerPidConnections, 0) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, os.NewSyscallError("GetExtendedTcpTable", err) + } + for { + table := make([]byte, size) + err = getExtendedTcpTable(&table[0], &size, false, windows.AF_INET6, TcpTableOwnerPidConnections, 0) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, os.NewSyscallError("GetExtendedTcpTable", err) + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibTcp6RowOwnerPid)(unsafe.Pointer(&table[offsetOfMibTcp6TableOwnerPid])), dwNumEntries), nil + } +} + +func GetExtendedUdpTable() ([]MibUdpRowOwnerPid, error) { + var size uint32 + err := getExtendedUdpTable(nil, &size, false, windows.AF_INET, UdpTableOwnerPid, 0) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, os.NewSyscallError("GetExtendedUdpTable", err) + } + for { + table := make([]byte, size) + err = getExtendedUdpTable(&table[0], &size, false, windows.AF_INET, UdpTableOwnerPid, 0) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, os.NewSyscallError("GetExtendedUdpTable", err) + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibUdpRowOwnerPid)(unsafe.Pointer(&table[offsetOfMibUdpTableOwnerPid])), dwNumEntries), nil + } +} + +func GetExtendedUdp6Table() ([]MibUdp6RowOwnerPid, error) { + var size uint32 + err := getExtendedUdpTable(nil, &size, false, windows.AF_INET6, UdpTableOwnerPid, 0) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, os.NewSyscallError("GetExtendedUdpTable", err) + } + for { + table := make([]byte, size) + err = getExtendedUdpTable(&table[0], &size, false, windows.AF_INET6, UdpTableOwnerPid, 0) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, os.NewSyscallError("GetExtendedUdpTable", err) + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibUdp6RowOwnerPid)(unsafe.Pointer(&table[offsetOfMibUdp6TableOwnerPid])), dwNumEntries), nil + } +} + +func GetPerTcpConnectionEStatsSendBuffer(row *MibTcpRow) (*TcpEstatsSendBufferRodV0, error) { + var rod TcpEstatsSendBufferRodV0 + err := getPerTcpConnectionEStats(row, + TcpConnectionEstatsSendBuff, + 0, + 0, + 0, + 0, + 0, + 0, + uintptr(unsafe.Pointer(&rod)), + 0, + uint64(sizeOfTcpEstatsSendBufferRodV0), + ) + if err != nil { + return nil, err + } + return &rod, nil +} + +func GetPerTcp6ConnectionEStatsSendBuffer(row *MibTcp6Row) (*TcpEstatsSendBufferRodV0, error) { + var rod TcpEstatsSendBufferRodV0 + err := getPerTcp6ConnectionEStats(row, + TcpConnectionEstatsSendBuff, + 0, + 0, + 0, + 0, + 0, + 0, + uintptr(unsafe.Pointer(&rod)), + 0, + uint64(sizeOfTcpEstatsSendBufferRodV0), + ) + if err != nil { + return nil, err + } + return &rod, nil +} + +func SetPerTcpConnectionEStatsSendBuffer(row *MibTcpRow, rw *TcpEstatsSendBuffRwV0) error { + return setPerTcpConnectionEStats(row, TcpConnectionEstatsSendBuff, uintptr(unsafe.Pointer(&rw)), 0, uint64(sizeOfTcpEstatsSendBuffRwV0), 0) +} + +func SetPerTcp6ConnectionEStatsSendBuffer(row *MibTcp6Row, rw *TcpEstatsSendBuffRwV0) error { + return setPerTcp6ConnectionEStats(row, TcpConnectionEstatsSendBuff, uintptr(unsafe.Pointer(&rw)), 0, uint64(sizeOfTcpEstatsSendBuffRwV0), 0) +} diff --git a/common/winiphlpapi/iphlpapi_test.go b/common/winiphlpapi/iphlpapi_test.go new file mode 100644 index 00000000..5fc3b741 --- /dev/null +++ b/common/winiphlpapi/iphlpapi_test.go @@ -0,0 +1,90 @@ +//go:build windows + +package winiphlpapi_test + +import ( + "context" + "net" + "syscall" + "testing" + + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/winiphlpapi" + + "github.com/stretchr/testify/require" +) + +func TestFindPidTcp4(t *testing.T) { + t.Parallel() + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + go listener.Accept() + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer conn.Close() + pid, err := winiphlpapi.FindPid(N.NetworkTCP, M.AddrPortFromNet(conn.LocalAddr())) + require.NoError(t, err) + require.Equal(t, uint32(syscall.Getpid()), pid) +} + +func TestFindPidTcp6(t *testing.T) { + t.Parallel() + listener, err := net.Listen("tcp", "[::1]:0") + require.NoError(t, err) + defer listener.Close() + go listener.Accept() + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer conn.Close() + pid, err := winiphlpapi.FindPid(N.NetworkTCP, M.AddrPortFromNet(conn.LocalAddr())) + require.NoError(t, err) + require.Equal(t, uint32(syscall.Getpid()), pid) +} + +func TestFindPidUdp4(t *testing.T) { + t.Parallel() + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + pid, err := winiphlpapi.FindPid(N.NetworkUDP, M.AddrPortFromNet(conn.LocalAddr())) + require.NoError(t, err) + require.Equal(t, uint32(syscall.Getpid()), pid) +} + +func TestFindPidUdp6(t *testing.T) { + t.Parallel() + conn, err := net.ListenPacket("udp", "[::1]:0") + require.NoError(t, err) + defer conn.Close() + pid, err := winiphlpapi.FindPid(N.NetworkUDP, M.AddrPortFromNet(conn.LocalAddr())) + require.NoError(t, err) + require.Equal(t, uint32(syscall.Getpid()), pid) +} + +func TestWaitAck4(t *testing.T) { + t.Parallel() + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + go listener.Accept() + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer conn.Close() + err = winiphlpapi.WriteAndWaitAck(context.Background(), conn, []byte("hello")) + require.NoError(t, err) +} + +func TestWaitAck6(t *testing.T) { + t.Parallel() + listener, err := net.Listen("tcp", "[::1]:0") + require.NoError(t, err) + defer listener.Close() + go listener.Accept() + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer conn.Close() + err = winiphlpapi.WriteAndWaitAck(context.Background(), conn, []byte("hello")) + require.NoError(t, err) +} diff --git a/common/winiphlpapi/syscall_windows.go b/common/winiphlpapi/syscall_windows.go new file mode 100644 index 00000000..f6aab14c --- /dev/null +++ b/common/winiphlpapi/syscall_windows.go @@ -0,0 +1,27 @@ +package winiphlpapi + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-gettcptable +//sys getTcpTable(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) = iphlpapi.GetTcpTable + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-gettcp6table +//sys getTcp6Table(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) = iphlpapi.GetTcp6Table + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getpertcpconnectionestats +//sys getPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) = iphlpapi.GetPerTcpConnectionEStats + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getpertcp6connectionestats +//sys getPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) = iphlpapi.GetPerTcp6ConnectionEStats + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-setpertcpconnectionestats +//sys setPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) = iphlpapi.SetPerTcpConnectionEStats + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-setpertcp6connectionestats +//sys setPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) = iphlpapi.SetPerTcp6ConnectionEStats + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedtcptable +//sys getExtendedTcpTable(pTcpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) = (errcode error) = iphlpapi.GetExtendedTcpTable + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedudptable +//sys getExtendedUdpTable(pUdpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) = (errcode error) = iphlpapi.GetExtendedUdpTable diff --git a/common/winiphlpapi/zsyscall_windows.go b/common/winiphlpapi/zsyscall_windows.go new file mode 100644 index 00000000..e5e93088 --- /dev/null +++ b/common/winiphlpapi/zsyscall_windows.go @@ -0,0 +1,131 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package winiphlpapi + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll") + + procGetExtendedTcpTable = modiphlpapi.NewProc("GetExtendedTcpTable") + procGetExtendedUdpTable = modiphlpapi.NewProc("GetExtendedUdpTable") + procGetPerTcp6ConnectionEStats = modiphlpapi.NewProc("GetPerTcp6ConnectionEStats") + procGetPerTcpConnectionEStats = modiphlpapi.NewProc("GetPerTcpConnectionEStats") + procGetTcp6Table = modiphlpapi.NewProc("GetTcp6Table") + procGetTcpTable = modiphlpapi.NewProc("GetTcpTable") + procSetPerTcp6ConnectionEStats = modiphlpapi.NewProc("SetPerTcp6ConnectionEStats") + procSetPerTcpConnectionEStats = modiphlpapi.NewProc("SetPerTcpConnectionEStats") +) + +func getExtendedTcpTable(pTcpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) (errcode error) { + var _p0 uint32 + if bOrder { + _p0 = 1 + } + r0, _, _ := syscall.Syscall6(procGetExtendedTcpTable.Addr(), 6, uintptr(unsafe.Pointer(pTcpTable)), uintptr(unsafe.Pointer(pdwSize)), uintptr(_p0), uintptr(ulAf), uintptr(tableClass), uintptr(reserved)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func getExtendedUdpTable(pUdpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) (errcode error) { + var _p0 uint32 + if bOrder { + _p0 = 1 + } + r0, _, _ := syscall.Syscall6(procGetExtendedUdpTable.Addr(), 6, uintptr(unsafe.Pointer(pUdpTable)), uintptr(unsafe.Pointer(pdwSize)), uintptr(_p0), uintptr(ulAf), uintptr(tableClass), uintptr(reserved)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func getPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) { + r0, _, _ := syscall.Syscall12(procGetPerTcp6ConnectionEStats.Addr(), 11, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(ros), uintptr(rosVersion), uintptr(rosSize), uintptr(rod), uintptr(rodVersion), uintptr(rodSize), 0) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func getPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) { + r0, _, _ := syscall.Syscall12(procGetPerTcpConnectionEStats.Addr(), 11, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(ros), uintptr(rosVersion), uintptr(rosSize), uintptr(rod), uintptr(rodVersion), uintptr(rodSize), 0) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func getTcp6Table(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) { + var _p0 uint32 + if order { + _p0 = 1 + } + r0, _, _ := syscall.Syscall(procGetTcp6Table.Addr(), 3, uintptr(unsafe.Pointer(tcpTable)), uintptr(unsafe.Pointer(sizePointer)), uintptr(_p0)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func getTcpTable(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) { + var _p0 uint32 + if order { + _p0 = 1 + } + r0, _, _ := syscall.Syscall(procGetTcpTable.Addr(), 3, uintptr(unsafe.Pointer(tcpTable)), uintptr(unsafe.Pointer(sizePointer)), uintptr(_p0)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func setPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) { + r0, _, _ := syscall.Syscall6(procSetPerTcp6ConnectionEStats.Addr(), 6, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(offset)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func setPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) { + r0, _, _ := syscall.Syscall6(procSetPerTcpConnectionEStats.Addr(), 6, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(offset)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} From 3464ed3babc051e8982eedf6b895b5bf24ac4896 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 10 Feb 2025 18:59:17 +0800 Subject: [PATCH 2/5] Fix merge objects --- common/json/badjson/merge_objects.go | 10 +++----- common/json/internal/contextjson/keys.go | 20 +++++++++++++++ common/json/internal/contextjson/keys_test.go | 25 +++++++++++++++++++ 3 files changed, 49 insertions(+), 6 deletions(-) create mode 100644 common/json/internal/contextjson/keys.go create mode 100644 common/json/internal/contextjson/keys_test.go diff --git a/common/json/badjson/merge_objects.go b/common/json/badjson/merge_objects.go index fa6c2d42..5b232097 100644 --- a/common/json/badjson/merge_objects.go +++ b/common/json/badjson/merge_objects.go @@ -2,9 +2,11 @@ package badjson import ( "context" + "reflect" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/json" + cJSON "github.com/sagernet/sing/common/json/internal/contextjson" ) func MarshallObjects(objects ...any) ([]byte, error) { @@ -31,16 +33,12 @@ func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error } func UnmarshallExcludedContext(ctx context.Context, inputContent []byte, parentObject any, object any) error { - parentContent, err := newJSONObject(ctx, parentObject) - if err != nil { - return err - } var content JSONObject - err = content.UnmarshalJSONContext(ctx, inputContent) + err := content.UnmarshalJSONContext(ctx, inputContent) if err != nil { return err } - for _, key := range parentContent.Keys() { + for _, key := range cJSON.ObjectKeys(reflect.TypeOf(parentObject)) { content.Remove(key) } if object == nil { diff --git a/common/json/internal/contextjson/keys.go b/common/json/internal/contextjson/keys.go new file mode 100644 index 00000000..589007f6 --- /dev/null +++ b/common/json/internal/contextjson/keys.go @@ -0,0 +1,20 @@ +package json + +import ( + "reflect" + + "github.com/sagernet/sing/common" +) + +func ObjectKeys(object reflect.Type) []string { + switch object.Kind() { + case reflect.Pointer: + return ObjectKeys(object.Elem()) + case reflect.Struct: + default: + panic("invalid non-struct input") + } + return common.Map(cachedTypeFields(object).list, func(field field) string { + return field.name + }) +} diff --git a/common/json/internal/contextjson/keys_test.go b/common/json/internal/contextjson/keys_test.go new file mode 100644 index 00000000..5de4dc57 --- /dev/null +++ b/common/json/internal/contextjson/keys_test.go @@ -0,0 +1,25 @@ +package json_test + +import ( + "reflect" + "testing" + + json "github.com/sagernet/sing/common/json/internal/contextjson" + + "github.com/stretchr/testify/require" +) + +type MyObject struct { + Hello string `json:"hello,omitempty"` + MyWorld + MyWorld2 string `json:"-"` +} + +type MyWorld struct { + World string `json:"world,omitempty"` +} + +func TestObjectKeys(t *testing.T) { + keys := json.ObjectKeys(reflect.TypeOf(&MyObject{})) + require.Equal(t, []string{"hello", "world"}, keys) +} From 0a3c811e429ca9ddc5d7d070439e64d818ca9992 Mon Sep 17 00:00:00 2001 From: Tommy Wu Date: Thu, 6 Feb 2025 10:45:27 +0800 Subject: [PATCH 3/5] add Digest authentication for http proxy server https://datatracker.ietf.org/doc/html/rfc2617 server will send both Basic and Digest header to client client can use either Basic or Digest for authentication Change-Id: Iaa6629c143551770c836af3ead823bd148b244c6 --- common/auth/auth.go | 71 +++++++++++++- common/param/param.go | 189 +++++++++++++++++++++++++++++++++++++ protocol/http/handshake.go | 68 ++++++++++++- 3 files changed, 322 insertions(+), 6 deletions(-) create mode 100644 common/param/param.go diff --git a/common/auth/auth.go b/common/auth/auth.go index b1be60d1..81597b6a 100644 --- a/common/auth/auth.go +++ b/common/auth/auth.go @@ -1,6 +1,23 @@ package auth -import "github.com/sagernet/sing/common" +import ( + "crypto/md5" + "encoding/hex" + "fmt" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/param" +) + +const Realm = "sing-box" + +type Challenge struct { + Username string + Nonce string + CNonce string + Nc string + Response string +} type User struct { Username string @@ -28,3 +45,55 @@ func (au *Authenticator) Verify(username string, password string) bool { passwordList, ok := au.userMap[username] return ok && common.Contains(passwordList, password) } + +func (au *Authenticator) VerifyDigest(method string, uri string, s string) (string, bool) { + c, err := ParseChallenge(s) + if err != nil { + return "", false + } + if c.Username == "" || c.Nonce == "" || c.Nc == "" || c.CNonce == "" || c.Response == "" { + return "", false + } + passwordList, ok := au.userMap[c.Username] + if ok { + for _, password := range passwordList { + ha1 := md5str(c.Username + ":" + Realm + ":" + password) + ha2 := md5str(method + ":" + uri) + resp := md5str(ha1 + ":" + c.Nonce + ":" + c.Nc + ":" + c.CNonce + ":auth:" + ha2) + if resp == c.Response { + return c.Username, true + } + } + } + return "", false +} + +func ParseChallenge(s string) (*Challenge, error) { + pp, err := param.Parse(s) + if err != nil { + return nil, fmt.Errorf("digest: invalid challenge: %w", err) + } + var c Challenge + + for _, p := range pp { + switch p.Key { + case "username": + c.Username = p.Value + case "nonce": + c.Nonce = p.Value + case "cnonce": + c.CNonce = p.Value + case "nc": + c.Nc = p.Value + case "response": + c.Response = p.Value + } + } + return &c, nil +} + +func md5str(str string) string { + h := md5.New() + h.Write([]byte(str)) + return hex.EncodeToString(h.Sum(nil)) +} diff --git a/common/param/param.go b/common/param/param.go new file mode 100644 index 00000000..7fbd4d8a --- /dev/null +++ b/common/param/param.go @@ -0,0 +1,189 @@ +package param + +// code retrieve from https://github.com/icholy/digest/tree/master/internal/param + +import ( + "bufio" + "fmt" + "io" + "strconv" + "strings" +) + +// Param is a key/value header parameter +type Param struct { + Key string + Value string + Quote bool +} + +// String returns the formatted parameter +func (p Param) String() string { + if p.Quote { + return p.Key + "=" + strconv.Quote(p.Value) + } + return p.Key + "=" + p.Value +} + +// Format formats the parameters to be included in the header +func Format(pp ...Param) string { + var b strings.Builder + for i, p := range pp { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(p.String()) + } + return b.String() +} + +// Parse parses the header parameters +func Parse(s string) ([]Param, error) { + var pp []Param + br := bufio.NewReader(strings.NewReader(s)) + for i := 0; true; i++ { + // skip whitespace + if err := skipWhite(br); err != nil { + return nil, err + } + // see if there's more to read + if _, err := br.Peek(1); err == io.EOF { + break + } + // read key/value pair + p, err := parseParam(br, i == 0) + if err != nil { + return nil, fmt.Errorf("param: %w", err) + } + pp = append(pp, p) + } + return pp, nil +} + +func parseIdent(br *bufio.Reader) (string, error) { + var ident []byte + for { + b, err := br.ReadByte() + if err == io.EOF { + break + } + if err != nil { + return "", err + } + if !(('a' <= b && b <= 'z') || ('A' <= b && b <= 'Z') || '0' <= b && b <= '9' || b == '-') { + if err := br.UnreadByte(); err != nil { + return "", err + } + break + } + ident = append(ident, b) + } + return string(ident), nil +} + +func parseByte(br *bufio.Reader, expect byte) error { + b, err := br.ReadByte() + if err != nil { + if err == io.EOF { + return fmt.Errorf("expected '%c', got EOF", expect) + } + return err + } + if b != expect { + return fmt.Errorf("expected '%c', got '%c'", expect, b) + } + return nil +} + +func parseString(br *bufio.Reader) (string, error) { + var s []rune + // read the open quote + if err := parseByte(br, '"'); err != nil { + return "", err + } + // read the string + var escaped bool + for { + r, _, err := br.ReadRune() + if err != nil { + return "", err + } + if escaped { + s = append(s, r) + escaped = false + continue + } + if r == '\\' { + escaped = true + continue + } + // closing quote + if r == '"' { + break + } + s = append(s, r) + } + return string(s), nil +} + +func skipWhite(br *bufio.Reader) error { + for { + b, err := br.ReadByte() + if err != nil { + if err == io.EOF { + return nil + } + return err + } + if b != ' ' { + return br.UnreadByte() + } + } +} + +func parseParam(br *bufio.Reader, first bool) (Param, error) { + // skip whitespace + if err := skipWhite(br); err != nil { + return Param{}, err + } + if !first { + // read the comma separator + if err := parseByte(br, ','); err != nil { + return Param{}, err + } + // skip whitespace + if err := skipWhite(br); err != nil { + return Param{}, err + } + } + // read the key + key, err := parseIdent(br) + if err != nil { + return Param{}, err + } + // skip whitespace + if err := skipWhite(br); err != nil { + return Param{}, err + } + // read the equals sign + if err := parseByte(br, '='); err != nil { + return Param{}, err + } + // skip whitespace + if err := skipWhite(br); err != nil { + return Param{}, err + } + // read the value + var value string + var quote bool + if b, _ := br.Peek(1); len(b) == 1 && b[0] == '"' { + quote = true + value, err = parseString(br) + } else { + value, err = parseIdent(br) + } + if err != nil { + return Param{}, err + } + return Param{Key: key, Value: value, Quote: quote}, nil +} diff --git a/protocol/http/handshake.go b/protocol/http/handshake.go index fd5817b7..86249395 100644 --- a/protocol/http/handshake.go +++ b/protocol/http/handshake.go @@ -3,7 +3,9 @@ package http import ( std_bufio "bufio" "context" + "crypto/rand" "encoding/base64" + "encoding/hex" "io" "net" "net/http" @@ -42,6 +44,12 @@ func HandleConnectionEx( authOk bool ) authorization := request.Header.Get("Proxy-Authorization") + if strings.HasPrefix(authorization, "Digest ") { + username, authOk = authenticator.VerifyDigest(request.Method, request.RequestURI, authorization[7:]) + if authOk { + ctx = auth.ContextWithUser(ctx, username) + } + } if strings.HasPrefix(authorization, "Basic ") { userPassword, _ := base64.URLEncoding.DecodeString(authorization[6:]) userPswdArr := strings.SplitN(string(userPassword), ":", 2) @@ -56,10 +64,31 @@ func HandleConnectionEx( } if !authOk { // Since no one else is using the library, use a fixed realm until rewritten - err = responseWith( - request, http.StatusProxyAuthRequired, - "Proxy-Authenticate", `Basic realm="sing-box" charset="UTF-8"`, - ).Write(conn) + // define realm in common/auth package, still "sing-box" now + nonce := ""; + randomBytes := make([]byte, 16) + _, err = rand.Read(randomBytes) + if err == nil { + nonce = hex.EncodeToString(randomBytes) + } + if nonce == "" { + err = responseWithBody( + request, http.StatusProxyAuthRequired, + "Proxy authentication required", + "Content-Type", "text/plain; charset=utf-8", + "Proxy-Authenticate", "Basic realm=\"" + auth.Realm + "\"", + "Connection", "close", + ).Write(conn) + } else { + err = responseWithBody( + request, http.StatusProxyAuthRequired, + "Proxy authentication required", + "Content-Type", "text/plain; charset=utf-8", + "Proxy-Authenticate", "Basic realm=\"" + auth.Realm + "\"", + "Proxy-Authenticate", "Digest realm=\"" + auth.Realm + "\", nonce=\"" + nonce + "\", qop=\"auth\", stale=false", + "Connection", "close", + ).Write(conn) + } if err != nil { return err } @@ -68,7 +97,8 @@ func HandleConnectionEx( } else if authorization != "" { return E.New("http: authentication failed, Proxy-Authorization=", authorization) } else { - return E.New("http: authentication failed, no Proxy-Authorization header") + //return E.New("http: authentication failed, no Proxy-Authorization header") + continue } } } @@ -270,3 +300,31 @@ func responseWith(request *http.Request, statusCode int, headers ...string) *htt Header: header, } } + +func responseWithBody(request *http.Request, statusCode int, body string, headers ...string) *http.Response { + var header http.Header + if len(headers) > 0 { + header = make(http.Header) + for i := 0; i < len(headers); i += 2 { + header.Add(headers[i], headers[i+1]) + } + } + var bodyReadCloser io.ReadCloser + var bodyContentLength = int64(0) + if body != "" { + bodyReadCloser = io.NopCloser(strings.NewReader(body)) + bodyContentLength = int64(len(body)) + } + return &http.Response{ + StatusCode: statusCode, + Status: http.StatusText(statusCode), + Proto: request.Proto, + ProtoMajor: request.ProtoMajor, + ProtoMinor: request.ProtoMinor, + Header: header, + Body: bodyReadCloser, + ContentLength: bodyContentLength, + Close: true, + } +} + From 456ad80670240cb508da372fb503ba128568b852 Mon Sep 17 00:00:00 2001 From: Tommy Wu Date: Sun, 9 Feb 2025 07:48:36 +0800 Subject: [PATCH 4/5] if digest auth passed, skip basic auth check Change-Id: If0287548bdb9fa9ca7685f1f7f09653b68f22da5 --- protocol/http/handshake.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/protocol/http/handshake.go b/protocol/http/handshake.go index 86249395..bc64a39f 100644 --- a/protocol/http/handshake.go +++ b/protocol/http/handshake.go @@ -50,7 +50,7 @@ func HandleConnectionEx( ctx = auth.ContextWithUser(ctx, username) } } - if strings.HasPrefix(authorization, "Basic ") { + if !authOk && strings.HasPrefix(authorization, "Basic ") { userPassword, _ := base64.URLEncoding.DecodeString(authorization[6:]) userPswdArr := strings.SplitN(string(userPassword), ":", 2) if len(userPswdArr) == 2 { From f6bfc05058a7eb58f384813b3501216a47f68828 Mon Sep 17 00:00:00 2001 From: Tommy Wu Date: Wed, 12 Feb 2025 10:49:20 +0800 Subject: [PATCH 5/5] add sha256 support Change-Id: I3885b2c616b2bcdeef4127e92747d9a87a6621eb --- common/auth/auth.go | 41 +++++++++++++++++++++++++++++--------- protocol/http/handshake.go | 3 ++- 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/common/auth/auth.go b/common/auth/auth.go index 81597b6a..c1f03c95 100644 --- a/common/auth/auth.go +++ b/common/auth/auth.go @@ -2,6 +2,7 @@ package auth import ( "crypto/md5" + "crypto/sha256" "encoding/hex" "fmt" @@ -12,11 +13,13 @@ import ( const Realm = "sing-box" type Challenge struct { - Username string - Nonce string - CNonce string - Nc string - Response string + Username string + Nonce string + Algorithm string + Uri string + CNonce string + Nc string + Response string } type User struct { @@ -54,13 +57,23 @@ func (au *Authenticator) VerifyDigest(method string, uri string, s string) (stri if c.Username == "" || c.Nonce == "" || c.Nc == "" || c.CNonce == "" || c.Response == "" { return "", false } + if c.Uri != "" { + uri = c.Uri + } passwordList, ok := au.userMap[c.Username] if ok { for _, password := range passwordList { - ha1 := md5str(c.Username + ":" + Realm + ":" + password) - ha2 := md5str(method + ":" + uri) - resp := md5str(ha1 + ":" + c.Nonce + ":" + c.Nc + ":" + c.CNonce + ":auth:" + ha2) - if resp == c.Response { + resp := "" + if c.Algorithm == "SHA-256" { + ha1 := sha256str(c.Username + ":" + Realm + ":" + password) + ha2 := sha256str(method + ":" + uri) + resp = sha256str(ha1 + ":" + c.Nonce + ":" + c.Nc + ":" + c.CNonce + ":auth:" + ha2) + } else { + ha1 := md5str(c.Username + ":" + Realm + ":" + password) + ha2 := md5str(method + ":" + uri) + resp = md5str(ha1 + ":" + c.Nonce + ":" + c.Nc + ":" + c.CNonce + ":auth:" + ha2) + } + if resp != "" && resp == c.Response { return c.Username, true } } @@ -81,6 +94,10 @@ func ParseChallenge(s string) (*Challenge, error) { c.Username = p.Value case "nonce": c.Nonce = p.Value + case "algorithm": + c.Algorithm = p.Value + case "uri": + c.Uri = p.Value case "cnonce": c.CNonce = p.Value case "nc": @@ -97,3 +114,9 @@ func md5str(str string) string { h.Write([]byte(str)) return hex.EncodeToString(h.Sum(nil)) } + +func sha256str(str string) string { + h := sha256.New() + h.Write([]byte(str)) + return hex.EncodeToString(h.Sum(nil)) +} diff --git a/protocol/http/handshake.go b/protocol/http/handshake.go index bc64a39f..cb13c0b5 100644 --- a/protocol/http/handshake.go +++ b/protocol/http/handshake.go @@ -85,7 +85,8 @@ func HandleConnectionEx( "Proxy authentication required", "Content-Type", "text/plain; charset=utf-8", "Proxy-Authenticate", "Basic realm=\"" + auth.Realm + "\"", - "Proxy-Authenticate", "Digest realm=\"" + auth.Realm + "\", nonce=\"" + nonce + "\", qop=\"auth\", stale=false", + "Proxy-Authenticate", "Digest realm=\"" + auth.Realm + "\", nonce=\"" + nonce + "\", qop=\"auth\", algorithm=SHA-256, stale=false", + "Proxy-Authenticate", "Digest realm=\"" + auth.Realm + "\", nonce=\"" + nonce + "\", qop=\"auth\", algorithm=MD5, stale=false", "Connection", "close", ).Write(conn) }