Skip to content

Commit

Permalink
UDS: Make all remote addr 0.0.0.0 (#4390)
Browse files Browse the repository at this point in the history
#4389 (comment)

---------

Co-authored-by: RPRX <[email protected]>
  • Loading branch information
Fangliding and RPRX authored Feb 13, 2025
1 parent 94c7970 commit 22c50a7
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 37 deletions.
14 changes: 1 addition & 13 deletions app/proxyman/inbound/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package inbound

import (
"context"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -464,19 +463,8 @@ func (w *dsWorker) callback(conn stat.Connection) {
WriteCounter: w.downlinkCounter,
}
}
// For most of time, unix obviously have no source addr. But if we leave it empty, it will cause panic.
// So we use gateway as source for log.
// However, there are some special situations where a valid source address might be available.
// Such as the source address parsed from X-Forwarded-For in websocket.
// In that case, we keep it.
var source net.Destination
if !strings.Contains(conn.RemoteAddr().String(), "unix") {
source = net.DestinationFromAddr(conn.RemoteAddr())
} else {
source = net.UnixDestination(w.address)
}
ctx = session.ContextWithInbound(ctx, &session.Inbound{
Source: source,
Source: net.DestinationFromAddr(conn.RemoteAddr()),
Gateway: net.UnixDestination(w.address),
Tag: w.tag,
Conn: conn,
Expand Down
10 changes: 4 additions & 6 deletions common/net/destination.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,10 @@ func UnixDestination(address Address) Destination {
// NetAddr returns the network address in this Destination in string form.
func (d Destination) NetAddr() string {
addr := ""
if d.Address != nil {
if d.Network == Network_TCP || d.Network == Network_UDP {
addr = d.Address.String() + ":" + d.Port.String()
} else if d.Network == Network_UNIX {
addr = d.Address.String()
}
if d.Network == Network_TCP || d.Network == Network_UDP {
addr = d.Address.String() + ":" + d.Port.String()
} else if d.Network == Network_UNIX {
addr = d.Address.String()
}
return addr
}
Expand Down
56 changes: 38 additions & 18 deletions transport/internet/system_listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,6 @@ type DefaultListener struct {
controllers []control.Func
}

type combinedListener struct {
net.Listener
locker *FileLocker // for unix domain socket
}

func (cl *combinedListener) Close() error {
if cl.locker != nil {
cl.locker.Release()
cl.locker = nil
}
return cl.Listener.Close()
}

func getControlFunc(ctx context.Context, sockopt *SocketConfig, controllers []control.Func) func(network, address string, c syscall.RawConn) error {
return func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
Expand All @@ -54,6 +41,40 @@ func getControlFunc(ctx context.Context, sockopt *SocketConfig, controllers []co
}
}

// For some reason, other component of ray will assume the listener is a TCP listener and have valid remote address.
// But in fact it doesn't. So we need to wrap the listener to make it return 0.0.0.0(unspecified) as remote address.
// If other issues encountered, we should able to fix it here.
type listenUDSWrapper struct {
net.Listener
locker *FileLocker
}

func (l *listenUDSWrapper) Accept() (net.Conn, error) {
conn, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return &listenUDSWrapperConn{Conn: conn}, nil
}

func (l *listenUDSWrapper) Close() error {
if l.locker != nil {
l.locker.Release()
l.locker = nil
}
return l.Listener.Close()
}

type listenUDSWrapperConn struct {
net.Conn
}

func (conn *listenUDSWrapperConn) RemoteAddr() net.Addr {
return &net.TCPAddr{
IP: []byte{0, 0, 0, 0},
}
}

func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (l net.Listener, err error) {
var lc net.ListenConfig
var network, address string
Expand Down Expand Up @@ -113,9 +134,9 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S
callback = func(l net.Listener, err error) (net.Listener, error) {
if err != nil {
locker.Release()
return l, err
return nil, err
}
l = &combinedListener{Listener: l, locker: locker}
l = &listenUDSWrapper{Listener: l, locker: locker}
if filePerm == nil {
return l, nil
}
Expand All @@ -129,9 +150,8 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S
}
}

l, err = lc.Listen(ctx, network, address)
l, err = callback(l, err)
if sockopt != nil && sockopt.AcceptProxyProtocol {
l, err = callback(lc.Listen(ctx, network, address))
if err == nil && sockopt != nil && sockopt.AcceptProxyProtocol {
policyFunc := func(upstream net.Addr) (proxyproto.Policy, error) { return proxyproto.REQUIRE, nil }
l = &proxyproto.Listener{Listener: l, Policy: policyFunc}
}
Expand Down

0 comments on commit 22c50a7

Please sign in to comment.