Skip to content

add Digest authentication for http proxy server #79

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: dev
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 93 additions & 1 deletion common/auth/auth.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,26 @@
package auth

import "github.com/sagernet/sing/common"
import (
"crypto/md5"
"crypto/sha256"
"encoding/hex"
"fmt"

"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/param"
)

const Realm = "sing-box"

type Challenge struct {
Username string
Nonce string
Algorithm string
Uri string
CNonce string
Nc string
Response string
}

type User struct {
Username string
@@ -28,3 +48,75 @@ func (au *Authenticator) Verify(username string, password string) bool {
passwordList, ok := au.userMap[username]
return ok && common.Contains(passwordList, password)
}

func (au *Authenticator) VerifyDigest(method string, uri string, s string) (string, bool) {
c, err := ParseChallenge(s)
if err != nil {
return "", false
}
if c.Username == "" || c.Nonce == "" || c.Nc == "" || c.CNonce == "" || c.Response == "" {
return "", false
}
if c.Uri != "" {
uri = c.Uri
}
passwordList, ok := au.userMap[c.Username]
if ok {
for _, password := range passwordList {
resp := ""
if c.Algorithm == "SHA-256" {
ha1 := sha256str(c.Username + ":" + Realm + ":" + password)
ha2 := sha256str(method + ":" + uri)
resp = sha256str(ha1 + ":" + c.Nonce + ":" + c.Nc + ":" + c.CNonce + ":auth:" + ha2)
} else {
ha1 := md5str(c.Username + ":" + Realm + ":" + password)
ha2 := md5str(method + ":" + uri)
resp = md5str(ha1 + ":" + c.Nonce + ":" + c.Nc + ":" + c.CNonce + ":auth:" + ha2)
}
if resp != "" && resp == c.Response {
return c.Username, true
}
}
}
return "", false
}

func ParseChallenge(s string) (*Challenge, error) {
pp, err := param.Parse(s)
if err != nil {
return nil, fmt.Errorf("digest: invalid challenge: %w", err)
}
var c Challenge

for _, p := range pp {
switch p.Key {
case "username":
c.Username = p.Value
case "nonce":
c.Nonce = p.Value
case "algorithm":
c.Algorithm = p.Value
case "uri":
c.Uri = p.Value
case "cnonce":
c.CNonce = p.Value
case "nc":
c.Nc = p.Value
case "response":
c.Response = p.Value
}
}
return &c, nil
}

func md5str(str string) string {
h := md5.New()
h.Write([]byte(str))
return hex.EncodeToString(h.Sum(nil))
}

func sha256str(str string) string {
h := sha256.New()
h.Write([]byte(str))
return hex.EncodeToString(h.Sum(nil))
}
10 changes: 4 additions & 6 deletions common/json/badjson/merge_objects.go
Original file line number Diff line number Diff line change
@@ -2,9 +2,11 @@ package badjson

import (
"context"
"reflect"

E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
cJSON "github.com/sagernet/sing/common/json/internal/contextjson"
)

func MarshallObjects(objects ...any) ([]byte, error) {
@@ -31,16 +33,12 @@ func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error
}

func UnmarshallExcludedContext(ctx context.Context, inputContent []byte, parentObject any, object any) error {
parentContent, err := newJSONObject(ctx, parentObject)
if err != nil {
return err
}
var content JSONObject
err = content.UnmarshalJSONContext(ctx, inputContent)
err := content.UnmarshalJSONContext(ctx, inputContent)
if err != nil {
return err
}
for _, key := range parentContent.Keys() {
for _, key := range cJSON.ObjectKeys(reflect.TypeOf(parentObject)) {
content.Remove(key)
}
if object == nil {
20 changes: 20 additions & 0 deletions common/json/internal/contextjson/keys.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package json

import (
"reflect"

"github.com/sagernet/sing/common"
)

func ObjectKeys(object reflect.Type) []string {
switch object.Kind() {
case reflect.Pointer:
return ObjectKeys(object.Elem())
case reflect.Struct:
default:
panic("invalid non-struct input")
}
return common.Map(cachedTypeFields(object).list, func(field field) string {
return field.name
})
}
25 changes: 25 additions & 0 deletions common/json/internal/contextjson/keys_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package json_test

import (
"reflect"
"testing"

json "github.com/sagernet/sing/common/json/internal/contextjson"

"github.com/stretchr/testify/require"
)

type MyObject struct {
Hello string `json:"hello,omitempty"`
MyWorld
MyWorld2 string `json:"-"`
}

type MyWorld struct {
World string `json:"world,omitempty"`
}

func TestObjectKeys(t *testing.T) {
keys := json.ObjectKeys(reflect.TypeOf(&MyObject{}))
require.Equal(t, []string{"hello", "world"}, keys)
}
189 changes: 189 additions & 0 deletions common/param/param.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
package param

// code retrieve from https://github.com/icholy/digest/tree/master/internal/param

import (
"bufio"
"fmt"
"io"
"strconv"
"strings"
)

// Param is a key/value header parameter
type Param struct {
Key string
Value string
Quote bool
}

// String returns the formatted parameter
func (p Param) String() string {
if p.Quote {
return p.Key + "=" + strconv.Quote(p.Value)
}
return p.Key + "=" + p.Value
}

// Format formats the parameters to be included in the header
func Format(pp ...Param) string {
var b strings.Builder
for i, p := range pp {
if i > 0 {
b.WriteString(", ")
}
b.WriteString(p.String())
}
return b.String()
}

// Parse parses the header parameters
func Parse(s string) ([]Param, error) {
var pp []Param
br := bufio.NewReader(strings.NewReader(s))
for i := 0; true; i++ {
// skip whitespace
if err := skipWhite(br); err != nil {
return nil, err
}
// see if there's more to read
if _, err := br.Peek(1); err == io.EOF {
break
}
// read key/value pair
p, err := parseParam(br, i == 0)
if err != nil {
return nil, fmt.Errorf("param: %w", err)
}
pp = append(pp, p)
}
return pp, nil
}

func parseIdent(br *bufio.Reader) (string, error) {
var ident []byte
for {
b, err := br.ReadByte()
if err == io.EOF {
break
}
if err != nil {
return "", err
}
if !(('a' <= b && b <= 'z') || ('A' <= b && b <= 'Z') || '0' <= b && b <= '9' || b == '-') {
if err := br.UnreadByte(); err != nil {
return "", err
}
break
}
ident = append(ident, b)
}
return string(ident), nil
}

func parseByte(br *bufio.Reader, expect byte) error {
b, err := br.ReadByte()
if err != nil {
if err == io.EOF {
return fmt.Errorf("expected '%c', got EOF", expect)
}
return err
}
if b != expect {
return fmt.Errorf("expected '%c', got '%c'", expect, b)
}
return nil
}

func parseString(br *bufio.Reader) (string, error) {
var s []rune
// read the open quote
if err := parseByte(br, '"'); err != nil {
return "", err
}
// read the string
var escaped bool
for {
r, _, err := br.ReadRune()
if err != nil {
return "", err
}
if escaped {
s = append(s, r)
escaped = false
continue
}
if r == '\\' {
escaped = true
continue
}
// closing quote
if r == '"' {
break
}
s = append(s, r)
}
return string(s), nil
}

func skipWhite(br *bufio.Reader) error {
for {
b, err := br.ReadByte()
if err != nil {
if err == io.EOF {
return nil
}
return err
}
if b != ' ' {
return br.UnreadByte()
}
}
}

func parseParam(br *bufio.Reader, first bool) (Param, error) {
// skip whitespace
if err := skipWhite(br); err != nil {
return Param{}, err
}
if !first {
// read the comma separator
if err := parseByte(br, ','); err != nil {
return Param{}, err
}
// skip whitespace
if err := skipWhite(br); err != nil {
return Param{}, err
}
}
// read the key
key, err := parseIdent(br)
if err != nil {
return Param{}, err
}
// skip whitespace
if err := skipWhite(br); err != nil {
return Param{}, err
}
// read the equals sign
if err := parseByte(br, '='); err != nil {
return Param{}, err
}
// skip whitespace
if err := skipWhite(br); err != nil {
return Param{}, err
}
// read the value
var value string
var quote bool
if b, _ := br.Peek(1); len(b) == 1 && b[0] == '"' {
quote = true
value, err = parseString(br)
} else {
value, err = parseIdent(br)
}
if err != nil {
return Param{}, err
}
return Param{Key: key, Value: value, Quote: quote}, nil
}
6 changes: 2 additions & 4 deletions common/windnsapi/dnsapi_test.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
//go:build windows

package windnsapi

import (
"runtime"
"testing"

"github.com/stretchr/testify/require"
)

func TestDNSAPI(t *testing.T) {
if runtime.GOOS != "windows" {
t.SkipNow()
}
t.Parallel()
require.NoError(t, FlushResolverCache())
}
217 changes: 217 additions & 0 deletions common/winiphlpapi/helper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
//go:build windows

package winiphlpapi

import (
"context"
"encoding/binary"
M "github.com/sagernet/sing/common/metadata"
"net"
"net/netip"
"os"
"time"
"unsafe"

E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
)

func LoadEStats() error {
err := modiphlpapi.Load()
if err != nil {
return err
}
err = procGetTcpTable.Find()
if err != nil {
return err
}
err = procGetTcp6Table.Find()
if err != nil {
return err
}
err = procGetPerTcp6ConnectionEStats.Find()
if err != nil {
return err
}
err = procGetPerTcp6ConnectionEStats.Find()
if err != nil {
return err
}
err = procSetPerTcpConnectionEStats.Find()
if err != nil {
return err
}
err = procSetPerTcp6ConnectionEStats.Find()
if err != nil {
return err
}
return nil
}

func LoadExtendedTable() error {
err := modiphlpapi.Load()
if err != nil {
return err
}
err = procGetExtendedTcpTable.Find()
if err != nil {
return err
}
err = procGetExtendedUdpTable.Find()
if err != nil {
return err
}
return nil
}

func FindPid(network string, source netip.AddrPort) (uint32, error) {
switch N.NetworkName(network) {
case N.NetworkTCP:
if source.Addr().Is4() {
tcpTable, err := GetExtendedTcpTable()
if err != nil {
return 0, err
}
for _, row := range tcpTable {
if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) {
return row.DwOwningPid, nil
}
}
} else {
tcpTable, err := GetExtendedTcp6Table()
if err != nil {
return 0, err
}
for _, row := range tcpTable {
if source == netip.AddrPortFrom(netip.AddrFrom16(row.UcLocalAddr), DwordToPort(row.DwLocalPort)) {
return row.DwOwningPid, nil
}
}
}
case N.NetworkUDP:
if source.Addr().Is4() {
udpTable, err := GetExtendedUdpTable()
if err != nil {
return 0, err
}
for _, row := range udpTable {
if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) {
return row.DwOwningPid, nil
}
}
} else {
udpTable, err := GetExtendedUdp6Table()
if err != nil {
return 0, err
}
for _, row := range udpTable {
if source == netip.AddrPortFrom(netip.AddrFrom16(row.UcLocalAddr), DwordToPort(row.DwLocalPort)) {
return row.DwOwningPid, nil
}
}
}
}
return 0, E.New("process not found for ", source)
}

func WriteAndWaitAck(ctx context.Context, conn net.Conn, payload []byte) error {
source := M.AddrPortFromNet(conn.LocalAddr())
destination := M.AddrPortFromNet(conn.RemoteAddr())
if source.Addr().Is4() {
tcpTable, err := GetTcpTable()
if err != nil {
return err
}
var tcpRow *MibTcpRow
for _, row := range tcpTable {
if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) ||
destination == netip.AddrPortFrom(DwordToAddr(row.DwRemoteAddr), DwordToPort(row.DwRemotePort)) {
tcpRow = &row
break
}
}
if tcpRow == nil {
return E.New("row not found for: ", source)
}
err = SetPerTcpConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
EnableCollection: true,
})
if err != nil {
return os.NewSyscallError("SetPerTcpConnectionEStatsSendBufferV0", err)
}
defer SetPerTcpConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
EnableCollection: false,
})
_, err = conn.Write(payload)
if err != nil {
return err
}
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
eStstsSendBuffer, err := GetPerTcpConnectionEStatsSendBuffer(tcpRow)
if err != nil {
return err
}
if eStstsSendBuffer.CurRetxQueue == 0 {
return nil
}
time.Sleep(10 * time.Millisecond)
}
} else {
tcpTable, err := GetTcp6Table()
if err != nil {
return err
}
var tcpRow *MibTcp6Row
for _, row := range tcpTable {
if source == netip.AddrPortFrom(netip.AddrFrom16(row.LocalAddr), DwordToPort(row.LocalPort)) ||
destination == netip.AddrPortFrom(netip.AddrFrom16(row.RemoteAddr), DwordToPort(row.RemotePort)) {
tcpRow = &row
break
}
}
if tcpRow == nil {
return E.New("row not found for: ", source)
}
err = SetPerTcp6ConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
EnableCollection: true,
})
if err != nil {
return os.NewSyscallError("SetPerTcpConnectionEStatsSendBufferV0", err)
}
defer SetPerTcp6ConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
EnableCollection: false,
})
_, err = conn.Write(payload)
if err != nil {
return err
}
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
eStstsSendBuffer, err := GetPerTcp6ConnectionEStatsSendBuffer(tcpRow)
if err != nil {
return err
}
if eStstsSendBuffer.CurRetxQueue == 0 {
return nil
}
time.Sleep(10 * time.Millisecond)
}
}
}

func DwordToAddr(addr uint32) netip.Addr {
return netip.AddrFrom4(*(*[4]byte)(unsafe.Pointer(&addr)))
}

func DwordToPort(dword uint32) uint16 {
return binary.BigEndian.Uint16((*[4]byte)(unsafe.Pointer(&dword))[:])
}
313 changes: 313 additions & 0 deletions common/winiphlpapi/iphlpapi.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,313 @@
//go:build windows

package winiphlpapi

import (
"errors"
"os"
"unsafe"

"golang.org/x/sys/windows"
)

const (
TcpTableBasicListener uint32 = iota
TcpTableBasicConnections
TcpTableBasicAll
TcpTableOwnerPidListener
TcpTableOwnerPidConnections
TcpTableOwnerPidAll
TcpTableOwnerModuleListener
TcpTableOwnerModuleConnections
TcpTableOwnerModuleAll
)

const (
UdpTableBasic uint32 = iota
UdpTableOwnerPid
UdpTableOwnerModule
)

const (
TcpConnectionEstatsSynOpts uint32 = iota
TcpConnectionEstatsData
TcpConnectionEstatsSndCong
TcpConnectionEstatsPath
TcpConnectionEstatsSendBuff
TcpConnectionEstatsRec
TcpConnectionEstatsObsRec
TcpConnectionEstatsBandwidth
TcpConnectionEstatsFineRtt
TcpConnectionEstatsMaximum
)

type MibTcpTable struct {
DwNumEntries uint32
Table [1]MibTcpRow
}

type MibTcpRow struct {
DwState uint32
DwLocalAddr uint32
DwLocalPort uint32
DwRemoteAddr uint32
DwRemotePort uint32
}

type MibTcp6Table struct {
DwNumEntries uint32
Table [1]MibTcp6Row
}

type MibTcp6Row struct {
State uint32
LocalAddr [16]byte
LocalScopeId uint32
LocalPort uint32
RemoteAddr [16]byte
RemoteScopeId uint32
RemotePort uint32
}

type MibTcpTableOwnerPid struct {
DwNumEntries uint32
Table [1]MibTcpRowOwnerPid
}

type MibTcpRowOwnerPid struct {
DwState uint32
DwLocalAddr uint32
DwLocalPort uint32
DwRemoteAddr uint32
DwRemotePort uint32
DwOwningPid uint32
}

type MibTcp6TableOwnerPid struct {
DwNumEntries uint32
Table [1]MibTcp6RowOwnerPid
}

type MibTcp6RowOwnerPid struct {
UcLocalAddr [16]byte
DwLocalScopeId uint32
DwLocalPort uint32
UcRemoteAddr [16]byte
DwRemoteScopeId uint32
DwRemotePort uint32
DwState uint32
DwOwningPid uint32
}

type MibUdpTableOwnerPid struct {
DwNumEntries uint32
Table [1]MibUdpRowOwnerPid
}

type MibUdpRowOwnerPid struct {
DwLocalAddr uint32
DwLocalPort uint32
DwOwningPid uint32
}

type MibUdp6TableOwnerPid struct {
DwNumEntries uint32
Table [1]MibUdp6RowOwnerPid
}

type MibUdp6RowOwnerPid struct {
UcLocalAddr [16]byte
DwLocalScopeId uint32
DwLocalPort uint32
DwOwningPid uint32
}

type TcpEstatsSendBufferRodV0 struct {
CurRetxQueue uint64
MaxRetxQueue uint64
CurAppWQueue uint64
MaxAppWQueue uint64
}

type TcpEstatsSendBuffRwV0 struct {
EnableCollection bool
}

const (
offsetOfMibTcpTable = unsafe.Offsetof(MibTcpTable{}.Table)
offsetOfMibTcp6Table = unsafe.Offsetof(MibTcp6Table{}.Table)
offsetOfMibTcpTableOwnerPid = unsafe.Offsetof(MibTcpTableOwnerPid{}.Table)
offsetOfMibTcp6TableOwnerPid = unsafe.Offsetof(MibTcpTableOwnerPid{}.Table)
offsetOfMibUdpTableOwnerPid = unsafe.Offsetof(MibUdpTableOwnerPid{}.Table)
offsetOfMibUdp6TableOwnerPid = unsafe.Offsetof(MibUdp6TableOwnerPid{}.Table)
sizeOfTcpEstatsSendBuffRwV0 = unsafe.Sizeof(TcpEstatsSendBuffRwV0{})
sizeOfTcpEstatsSendBufferRodV0 = unsafe.Sizeof(TcpEstatsSendBufferRodV0{})
)

func GetTcpTable() ([]MibTcpRow, error) {
var size uint32
err := getTcpTable(nil, &size, false)
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
return nil, err
}
for {
table := make([]byte, size)
err = getTcpTable(&table[0], &size, false)
if err != nil {
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
continue
}
return nil, err
}
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
return unsafe.Slice((*MibTcpRow)(unsafe.Pointer(&table[offsetOfMibTcpTable])), dwNumEntries), nil
}
}

func GetTcp6Table() ([]MibTcp6Row, error) {
var size uint32
err := getTcp6Table(nil, &size, false)
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
return nil, err
}
for {
table := make([]byte, size)
err = getTcp6Table(&table[0], &size, false)
if err != nil {
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
continue
}
return nil, err
}
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
return unsafe.Slice((*MibTcp6Row)(unsafe.Pointer(&table[offsetOfMibTcp6Table])), dwNumEntries), nil
}
}

func GetExtendedTcpTable() ([]MibTcpRowOwnerPid, error) {
var size uint32
err := getExtendedTcpTable(nil, &size, false, windows.AF_INET, TcpTableOwnerPidConnections, 0)
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
return nil, os.NewSyscallError("GetExtendedTcpTable", err)
}
for {
table := make([]byte, size)
err = getExtendedTcpTable(&table[0], &size, false, windows.AF_INET, TcpTableOwnerPidConnections, 0)
if err != nil {
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
continue
}
return nil, os.NewSyscallError("GetExtendedTcpTable", err)
}
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
return unsafe.Slice((*MibTcpRowOwnerPid)(unsafe.Pointer(&table[offsetOfMibTcpTableOwnerPid])), dwNumEntries), nil
}
}

func GetExtendedTcp6Table() ([]MibTcp6RowOwnerPid, error) {
var size uint32
err := getExtendedTcpTable(nil, &size, false, windows.AF_INET6, TcpTableOwnerPidConnections, 0)
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
return nil, os.NewSyscallError("GetExtendedTcpTable", err)
}
for {
table := make([]byte, size)
err = getExtendedTcpTable(&table[0], &size, false, windows.AF_INET6, TcpTableOwnerPidConnections, 0)
if err != nil {
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
continue
}
return nil, os.NewSyscallError("GetExtendedTcpTable", err)
}
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
return unsafe.Slice((*MibTcp6RowOwnerPid)(unsafe.Pointer(&table[offsetOfMibTcp6TableOwnerPid])), dwNumEntries), nil
}
}

func GetExtendedUdpTable() ([]MibUdpRowOwnerPid, error) {
var size uint32
err := getExtendedUdpTable(nil, &size, false, windows.AF_INET, UdpTableOwnerPid, 0)
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
return nil, os.NewSyscallError("GetExtendedUdpTable", err)
}
for {
table := make([]byte, size)
err = getExtendedUdpTable(&table[0], &size, false, windows.AF_INET, UdpTableOwnerPid, 0)
if err != nil {
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
continue
}
return nil, os.NewSyscallError("GetExtendedUdpTable", err)
}
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
return unsafe.Slice((*MibUdpRowOwnerPid)(unsafe.Pointer(&table[offsetOfMibUdpTableOwnerPid])), dwNumEntries), nil
}
}

func GetExtendedUdp6Table() ([]MibUdp6RowOwnerPid, error) {
var size uint32
err := getExtendedUdpTable(nil, &size, false, windows.AF_INET6, UdpTableOwnerPid, 0)
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
return nil, os.NewSyscallError("GetExtendedUdpTable", err)
}
for {
table := make([]byte, size)
err = getExtendedUdpTable(&table[0], &size, false, windows.AF_INET6, UdpTableOwnerPid, 0)
if err != nil {
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
continue
}
return nil, os.NewSyscallError("GetExtendedUdpTable", err)
}
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
return unsafe.Slice((*MibUdp6RowOwnerPid)(unsafe.Pointer(&table[offsetOfMibUdp6TableOwnerPid])), dwNumEntries), nil
}
}

func GetPerTcpConnectionEStatsSendBuffer(row *MibTcpRow) (*TcpEstatsSendBufferRodV0, error) {
var rod TcpEstatsSendBufferRodV0
err := getPerTcpConnectionEStats(row,
TcpConnectionEstatsSendBuff,
0,
0,
0,
0,
0,
0,
uintptr(unsafe.Pointer(&rod)),
0,
uint64(sizeOfTcpEstatsSendBufferRodV0),
)
if err != nil {
return nil, err
}
return &rod, nil
}

func GetPerTcp6ConnectionEStatsSendBuffer(row *MibTcp6Row) (*TcpEstatsSendBufferRodV0, error) {
var rod TcpEstatsSendBufferRodV0
err := getPerTcp6ConnectionEStats(row,
TcpConnectionEstatsSendBuff,
0,
0,
0,
0,
0,
0,
uintptr(unsafe.Pointer(&rod)),
0,
uint64(sizeOfTcpEstatsSendBufferRodV0),
)
if err != nil {
return nil, err
}
return &rod, nil
}

func SetPerTcpConnectionEStatsSendBuffer(row *MibTcpRow, rw *TcpEstatsSendBuffRwV0) error {
return setPerTcpConnectionEStats(row, TcpConnectionEstatsSendBuff, uintptr(unsafe.Pointer(&rw)), 0, uint64(sizeOfTcpEstatsSendBuffRwV0), 0)
}

func SetPerTcp6ConnectionEStatsSendBuffer(row *MibTcp6Row, rw *TcpEstatsSendBuffRwV0) error {
return setPerTcp6ConnectionEStats(row, TcpConnectionEstatsSendBuff, uintptr(unsafe.Pointer(&rw)), 0, uint64(sizeOfTcpEstatsSendBuffRwV0), 0)
}
90 changes: 90 additions & 0 deletions common/winiphlpapi/iphlpapi_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
//go:build windows

package winiphlpapi_test

import (
"context"
"net"
"syscall"
"testing"

M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/winiphlpapi"

"github.com/stretchr/testify/require"
)

func TestFindPidTcp4(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer listener.Close()
go listener.Accept()
conn, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
defer conn.Close()
pid, err := winiphlpapi.FindPid(N.NetworkTCP, M.AddrPortFromNet(conn.LocalAddr()))
require.NoError(t, err)
require.Equal(t, uint32(syscall.Getpid()), pid)
}

func TestFindPidTcp6(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "[::1]:0")
require.NoError(t, err)
defer listener.Close()
go listener.Accept()
conn, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
defer conn.Close()
pid, err := winiphlpapi.FindPid(N.NetworkTCP, M.AddrPortFromNet(conn.LocalAddr()))
require.NoError(t, err)
require.Equal(t, uint32(syscall.Getpid()), pid)
}

func TestFindPidUdp4(t *testing.T) {
t.Parallel()
conn, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
defer conn.Close()
pid, err := winiphlpapi.FindPid(N.NetworkUDP, M.AddrPortFromNet(conn.LocalAddr()))
require.NoError(t, err)
require.Equal(t, uint32(syscall.Getpid()), pid)
}

func TestFindPidUdp6(t *testing.T) {
t.Parallel()
conn, err := net.ListenPacket("udp", "[::1]:0")
require.NoError(t, err)
defer conn.Close()
pid, err := winiphlpapi.FindPid(N.NetworkUDP, M.AddrPortFromNet(conn.LocalAddr()))
require.NoError(t, err)
require.Equal(t, uint32(syscall.Getpid()), pid)
}

func TestWaitAck4(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer listener.Close()
go listener.Accept()
conn, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
defer conn.Close()
err = winiphlpapi.WriteAndWaitAck(context.Background(), conn, []byte("hello"))
require.NoError(t, err)
}

func TestWaitAck6(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "[::1]:0")
require.NoError(t, err)
defer listener.Close()
go listener.Accept()
conn, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
defer conn.Close()
err = winiphlpapi.WriteAndWaitAck(context.Background(), conn, []byte("hello"))
require.NoError(t, err)
}
27 changes: 27 additions & 0 deletions common/winiphlpapi/syscall_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package winiphlpapi

//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go

// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-gettcptable
//sys getTcpTable(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) = iphlpapi.GetTcpTable

// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-gettcp6table
//sys getTcp6Table(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) = iphlpapi.GetTcp6Table

// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getpertcpconnectionestats
//sys getPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) = iphlpapi.GetPerTcpConnectionEStats

// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getpertcp6connectionestats
//sys getPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) = iphlpapi.GetPerTcp6ConnectionEStats

// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-setpertcpconnectionestats
//sys setPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) = iphlpapi.SetPerTcpConnectionEStats

// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-setpertcp6connectionestats
//sys setPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) = iphlpapi.SetPerTcp6ConnectionEStats

// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedtcptable
//sys getExtendedTcpTable(pTcpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) = (errcode error) = iphlpapi.GetExtendedTcpTable

// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedudptable
//sys getExtendedUdpTable(pUdpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) = (errcode error) = iphlpapi.GetExtendedUdpTable
131 changes: 131 additions & 0 deletions common/winiphlpapi/zsyscall_windows.go
71 changes: 65 additions & 6 deletions protocol/http/handshake.go
Original file line number Diff line number Diff line change
@@ -3,7 +3,9 @@ package http
import (
std_bufio "bufio"
"context"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"io"
"net"
"net/http"
@@ -42,7 +44,13 @@ func HandleConnectionEx(
authOk bool
)
authorization := request.Header.Get("Proxy-Authorization")
if strings.HasPrefix(authorization, "Basic ") {
if strings.HasPrefix(authorization, "Digest ") {
username, authOk = authenticator.VerifyDigest(request.Method, request.RequestURI, authorization[7:])
if authOk {
ctx = auth.ContextWithUser(ctx, username)
}
}
if !authOk && strings.HasPrefix(authorization, "Basic ") {
userPassword, _ := base64.URLEncoding.DecodeString(authorization[6:])
userPswdArr := strings.SplitN(string(userPassword), ":", 2)
if len(userPswdArr) == 2 {
@@ -56,10 +64,32 @@ func HandleConnectionEx(
}
if !authOk {
// Since no one else is using the library, use a fixed realm until rewritten
err = responseWith(
request, http.StatusProxyAuthRequired,
"Proxy-Authenticate", `Basic realm="sing-box" charset="UTF-8"`,
).Write(conn)
// define realm in common/auth package, still "sing-box" now
nonce := "";
randomBytes := make([]byte, 16)
_, err = rand.Read(randomBytes)
if err == nil {
nonce = hex.EncodeToString(randomBytes)
}
if nonce == "" {
err = responseWithBody(
request, http.StatusProxyAuthRequired,
"Proxy authentication required",
"Content-Type", "text/plain; charset=utf-8",
"Proxy-Authenticate", "Basic realm=\"" + auth.Realm + "\"",
"Connection", "close",
).Write(conn)
} else {
err = responseWithBody(
request, http.StatusProxyAuthRequired,
"Proxy authentication required",
"Content-Type", "text/plain; charset=utf-8",
"Proxy-Authenticate", "Basic realm=\"" + auth.Realm + "\"",
"Proxy-Authenticate", "Digest realm=\"" + auth.Realm + "\", nonce=\"" + nonce + "\", qop=\"auth\", algorithm=SHA-256, stale=false",
"Proxy-Authenticate", "Digest realm=\"" + auth.Realm + "\", nonce=\"" + nonce + "\", qop=\"auth\", algorithm=MD5, stale=false",
"Connection", "close",
).Write(conn)
}
if err != nil {
return err
}
@@ -68,7 +98,8 @@ func HandleConnectionEx(
} else if authorization != "" {
return E.New("http: authentication failed, Proxy-Authorization=", authorization)
} else {
return E.New("http: authentication failed, no Proxy-Authorization header")
//return E.New("http: authentication failed, no Proxy-Authorization header")
continue
}
}
}
@@ -270,3 +301,31 @@ func responseWith(request *http.Request, statusCode int, headers ...string) *htt
Header: header,
}
}

func responseWithBody(request *http.Request, statusCode int, body string, headers ...string) *http.Response {
var header http.Header
if len(headers) > 0 {
header = make(http.Header)
for i := 0; i < len(headers); i += 2 {
header.Add(headers[i], headers[i+1])
}
}
var bodyReadCloser io.ReadCloser
var bodyContentLength = int64(0)
if body != "" {
bodyReadCloser = io.NopCloser(strings.NewReader(body))
bodyContentLength = int64(len(body))
}
return &http.Response{
StatusCode: statusCode,
Status: http.StatusText(statusCode),
Proto: request.Proto,
ProtoMajor: request.ProtoMajor,
ProtoMinor: request.ProtoMinor,
Header: header,
Body: bodyReadCloser,
ContentLength: bodyContentLength,
Close: true,
}
}