Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 55 additions & 51 deletions dialer/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,81 +3,85 @@ package dialer
import (
"context"
"net"
"sync"
"syscall"

"go.uber.org/atomic"
)

// DefaultDialer is the default Dialer and is used by DialContext and ListenPacket.
var DefaultDialer = &Dialer{
InterfaceName: atomic.NewString(""),
InterfaceIndex: atomic.NewInt32(0),
RoutingMark: atomic.NewInt32(0),
}

type Dialer struct {
InterfaceName *atomic.String
InterfaceIndex *atomic.Int32
RoutingMark *atomic.Int32
}

type Options struct {
// InterfaceName is the name of interface/device to bind.
// If a socket is bound to an interface, only packets received
// from that particular interface are processed by the socket.
InterfaceName string
// DefaultDialer is the package-level default Dialer.
// It is used by DialContext and ListenPacket.
var DefaultDialer = &Dialer{}

// InterfaceIndex is the index of interface/device to bind.
// It is almost the same as InterfaceName except it uses the
// index of the interface instead of the name.
InterfaceIndex int

// RoutingMark is the mark for each packet sent through this
// socket. Changing the mark can be used for mark-based routing
// without netfilter or for packet filtering.
RoutingMark int
// RegisterSockOpt registers a socket option on the DefaultDialer.
func RegisterSockOpt(opt SocketOption) {
DefaultDialer.RegisterSockOpt(opt)
}

// DialContext is a wrapper around DefaultDialer.DialContext.
// DialContext dials using the DefaultDialer.
func DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return DefaultDialer.DialContext(ctx, network, address)
}

// ListenPacket is a wrapper around DefaultDialer.ListenPacket.
// ListenPacket listens using the DefaultDialer.
func ListenPacket(network, address string) (net.PacketConn, error) {
return DefaultDialer.ListenPacket(network, address)
}

func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return d.DialContextWithOptions(ctx, network, address, &Options{
InterfaceName: d.InterfaceName.Load(),
InterfaceIndex: int(d.InterfaceIndex.Load()),
RoutingMark: int(d.RoutingMark.Load()),
})
// Dialer applies registered SocketOptions to all dials/listens.
type Dialer struct {
optsMu sync.Mutex
atomicOpts atomic.Value
}

func (*Dialer) DialContextWithOptions(ctx context.Context, network, address string, opts *Options) (net.Conn, error) {
d := &net.Dialer{
Control: func(network, address string, c syscall.RawConn) error {
return setSocketOptions(network, address, c, opts)
},
// New creates a new Dialer with the given initial socket options.
func New(opts ...SocketOption) *Dialer {
d := &Dialer{}
for _, opt := range opts {
d.RegisterSockOpt(opt)
}
return d.DialContext(ctx, network, address)
return d
}

func (d *Dialer) ListenPacket(network, address string) (net.PacketConn, error) {
return d.ListenPacketWithOptions(network, address, &Options{
InterfaceName: d.InterfaceName.Load(),
InterfaceIndex: int(d.InterfaceIndex.Load()),
RoutingMark: int(d.RoutingMark.Load()),
})
// RegisterSockOpt registers a socket option on the Dialer.
func (d *Dialer) RegisterSockOpt(opt SocketOption) {
d.optsMu.Lock()
opts, _ := d.atomicOpts.Load().([]SocketOption)
d.atomicOpts.Store(append(opts, opt))
d.optsMu.Unlock()
}

func (*Dialer) ListenPacketWithOptions(network, address string, opts *Options) (net.PacketConn, error) {
func (d *Dialer) applySockOpts(network string, address string, c syscall.RawConn) error {
opts, _ := d.atomicOpts.Load().([]SocketOption)
if len(opts) == 0 {
return nil
}
// Skip non-global-unicast IPs (e.g. loopback, link-local).
if host, _, err := net.SplitHostPort(address); err == nil {
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() {
return nil
}
}
for _, opt := range opts {
if err := opt.Apply(network, address, c); err != nil {
return err
}
}
return nil
}

// DialContext behaves like net.Dialer.DialContext, applying registered SocketOptions.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
nd := &net.Dialer{
Control: d.applySockOpts,
}
return nd.DialContext(ctx, network, address)
}

// ListenPacket behaves like net.ListenConfig.ListenPacket, applying registered SocketOptions.
func (d *Dialer) ListenPacket(network, address string) (net.PacketConn, error) {
lc := &net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
return setSocketOptions(network, address, c, opts)
},
Control: d.applySockOpts,
}
return lc.ListenPacket(context.Background(), network, address)
}
55 changes: 33 additions & 22 deletions dialer/sockopt.go
Original file line number Diff line number Diff line change
@@ -1,30 +1,41 @@
package dialer

func isTCPSocket(network string) bool {
switch network {
case "tcp", "tcp4", "tcp6":
return true
default:
return false
}
import (
"errors"
"syscall"
)

var _ SocketOption = SocketOptionFunc(nil)

// SocketOption applies a socket-level configuration to a network connection
// during dialing or listening, via syscall.RawConn.
type SocketOption interface {
Apply(network, address string, c syscall.RawConn) error
}

func isUDPSocket(network string) bool {
switch network {
case "udp", "udp4", "udp6":
return true
default:
return false
}
// SocketOptionFunc adapts a function to a SocketOption.
type SocketOptionFunc func(network, address string, c syscall.RawConn) error

func (f SocketOptionFunc) Apply(network, address string, c syscall.RawConn) error {
return f(network, address, c)
}

// UnsupportedSocketOption is a sentinel SocketOption that always reports
// ErrUnsupported when applied.
var UnsupportedSocketOption = SocketOptionFunc(unsupportedSocketOpt)

func unsupportedSocketOpt(_, _ string, _ syscall.RawConn) error {
return errors.ErrUnsupported
}

func isICMPSocket(network string) bool {
switch network {
case "ip:icmp", "ip4:icmp", "ip6:ipv6-icmp":
return true
case "ip4", "ip6":
return true
default:
return false
// rawConnControl runs f with the file descriptor obtained via RawConn.Control
// and correctly propagates errors returned from f.
func rawConnControl(c syscall.RawConn, f func(uintptr) error) error {
var innerErr error
if err := c.Control(func(fd uintptr) {
innerErr = f(fd)
}); err != nil {
return err
}
return innerErr
}
44 changes: 12 additions & 32 deletions dialer/sockopt_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,39 +7,19 @@ import (
"golang.org/x/sys/unix"
)

func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) {
if opts == nil || !isTCPSocket(network) && !isUDPSocket(network) && !isICMPSocket(network) {
return err
}

var innerErr error
err = c.Control(func(fd uintptr) {
host, _, _ := net.SplitHostPort(address)
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() {
return
}

if opts.InterfaceIndex == 0 && opts.InterfaceName != "" {
if iface, err := net.InterfaceByName(opts.InterfaceName); err == nil {
opts.InterfaceIndex = iface.Index
}
}

if opts.InterfaceIndex != 0 {
func WithBindToInterface(iface *net.Interface) SocketOption {
index := iface.Index
return SocketOptionFunc(func(network, _ string, c syscall.RawConn) error {
return rawConnControl(c, func(fd uintptr) error {
switch network {
case "tcp4", "udp4":
innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, opts.InterfaceIndex)
case "tcp6", "udp6":
innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, opts.InterfaceIndex)
case "ip4", "tcp4", "udp4":
return unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, index)
case "ip6", "tcp6", "udp6":
return unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, index)
}
if innerErr != nil {
return
}
}
return nil
})
})

if innerErr != nil {
err = innerErr
}
return err
}

func WithRoutingMark(_ int) SocketOption { return UnsupportedSocketOption }
27 changes: 6 additions & 21 deletions dialer/sockopt_freebsd.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,12 @@ import (
"golang.org/x/sys/unix"
)

func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) {
if opts == nil || !isTCPSocket(network) && !isUDPSocket(network) && !isICMPSocket(network) {
return err
}
func WithBindToInterface(_ *net.Interface) SocketOption { return UnsupportedSocketOption }

var innerErr error
err = c.Control(func(fd uintptr) {
host, _, _ := net.SplitHostPort(address)
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() {
return
}

if opts.RoutingMark != 0 {
if innerErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_USER_COOKIE, opts.RoutingMark); innerErr != nil {
return
}
}
func WithRoutingMark(mark int) SocketOption {
return SocketOptionFunc(func(_, _ string, c syscall.RawConn) error {
return rawConnControl(c, func(fd uintptr) error {
return unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_USER_COOKIE, mark)
})
})

if innerErr != nil {
err = innerErr
}
return err
}
45 changes: 13 additions & 32 deletions dialer/sockopt_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,19 @@ import (
"golang.org/x/sys/unix"
)

func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) {
if opts == nil || !isTCPSocket(network) && !isUDPSocket(network) && !isICMPSocket(network) {
return err
}

var innerErr error
err = c.Control(func(fd uintptr) {
host, _, _ := net.SplitHostPort(address)
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() {
return
}

if opts.InterfaceName == "" && opts.InterfaceIndex != 0 {
if iface, err := net.InterfaceByIndex(opts.InterfaceIndex); err == nil {
opts.InterfaceName = iface.Name
}
}

if opts.InterfaceName != "" {
if innerErr = unix.BindToDevice(int(fd), opts.InterfaceName); innerErr != nil {
return
}
}
if opts.RoutingMark != 0 {
if innerErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, opts.RoutingMark); innerErr != nil {
return
}
}
func WithBindToInterface(iface *net.Interface) SocketOption {
device := iface.Name
return SocketOptionFunc(func(_, _ string, c syscall.RawConn) error {
return rawConnControl(c, func(fd uintptr) error {
return unix.BindToDevice(int(fd), device)
})
})
}

if innerErr != nil {
err = innerErr
}
return err
func WithRoutingMark(mark int) SocketOption {
return SocketOptionFunc(func(_, _ string, c syscall.RawConn) error {
return rawConnControl(c, func(fd uintptr) error {
return unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, mark)
})
})
}
27 changes: 6 additions & 21 deletions dialer/sockopt_openbsd.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,12 @@ import (
"golang.org/x/sys/unix"
)

func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) {
if opts == nil || !isTCPSocket(network) && !isUDPSocket(network) && !isICMPSocket(network) {
return err
}
func WithBindToInterface(_ *net.Interface) SocketOption { return UnsupportedSocketOption }

var innerErr error
err = c.Control(func(fd uintptr) {
host, _, _ := net.SplitHostPort(address)
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() {
return
}

if opts.RoutingMark != 0 {
if innerErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RTABLE, opts.RoutingMark); innerErr != nil {
return
}
}
func WithRoutingMark(mark int) SocketOption {
return SocketOptionFunc(func(_, _ string, c syscall.RawConn) error {
return rawConnControl(c, func(fd uintptr) error {
return unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RTABLE, mark)
})
})

if innerErr != nil {
err = innerErr
}
return err
}
9 changes: 0 additions & 9 deletions dialer/sockopt_others.go

This file was deleted.

Loading
Loading