diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4e3fdc1a0..1e93758d6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -67,6 +67,10 @@ jobs: -v "$(pwd)/coverage.txt:/tmp/gobuild/coverage.txt" \ test-container + - name: Run integration tests in test container + run: | + docker run --rm --entrypoint go test-container test -tags=integration ./internal/restrictednet + - name: Verify dev cross platform compatibility run: docker build --target xcompile . diff --git a/AGENTS.md b/AGENTS.md index 0e9902334..fb3f7d8ea 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -50,6 +50,7 @@ Guidance for coding agents working in this repository. - Prefer splitting a code line only when it triggers the `lll` linter, do not split a command or arguments list for each element - Use `netip` types instead of `net` types whenever possible - Use constants instead of variables whenever possible, especially function-local inline constants. +- Prefer using pure functions over methods when possible. Especially if the method does not need any fields from the receiving struct, it should be a pure function. - Do not use `time.Sleep`, prefer using a `time.Timer` with a `select` statement also listening on a context cancelation - `panic`: - should only be used when a programming error is encountered and you should NOT return errors for programming errors (such as passing nil objects) @@ -115,6 +116,7 @@ Mocking works with the `go.uber.org/mock` library, and the `mockgen` tool. - **Never** use `.AnyTimes()` on mocks. Always define the number of times a certain mock call should be called, with `.Times(3)` for example. - **Always** set the `.Return(...)` on the mock if the function returns something. - Avoid using **mock helpers** functions, prefer a bit of repetition than tight coupling and dependency + - Always define the gomock controller `ctrl` in the subtest and not in the parent test, or a subtest mock failing will crash all the other subtests. ### main.go @@ -127,6 +129,7 @@ The Go formatter used is gofumpt. ### Errors - Always prefer wrapping errors with some context with `fmt.Errorf("doing this: %w", err)` +- Use `errors.New("error message")` when creating a 'bottom' constant string error without additional context, instead of `fmt.Errorf` - In rare cases, you can just use `return err` notably: - If the function is called **recursively**, since we don't wrap the wrapping multiple times for each recursion - If the current function only statement is the call to another function, for example: @@ -179,6 +182,8 @@ The Go formatter used is gofumpt. - Do not use `http.DefaultClient`, use a custom `*http.Client` with a fixed timeout and share with dependency injections. - Do not check for injected dependencies being `nil`, prefer to just panic on a nil pointer. By default it's fine to panic if a developer injects a dependency `nil`. `nil` does not mean use a default. +- Prefer using a `switch { case ...}` statement over multiple consecutive `if` statements to have shorter code. +- Prefer using `[...]T` instead of `[]T` when the length is fixed and known at compile time, to avoid unnecessary allocations. ## Validation checklist diff --git a/internal/firewall/interfaces.go b/internal/firewall/interfaces.go index d9c830da2..064046c7b 100644 --- a/internal/firewall/interfaces.go +++ b/internal/firewall/interfaces.go @@ -28,6 +28,8 @@ type firewallImpl interface { //nolint:interfacebloat AcceptIpv6MulticastOutput(ctx context.Context, intf string) error AcceptOutput(ctx context.Context, protocol, intf string, ip netip.Addr, port uint16, remove bool) error + AcceptOutputFromIPPortToIPPort(ctx context.Context, protocol, intf string, + source, destination netip.AddrPort, remove bool) error AcceptOutputFromIPToSubnet(ctx context.Context, intf string, assignedIP netip.Addr, subnet netip.Prefix, remove bool) error AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error diff --git a/internal/firewall/iptables/iptables.go b/internal/firewall/iptables/iptables.go index c48879295..a6b40cf72 100644 --- a/internal/firewall/iptables/iptables.go +++ b/internal/firewall/iptables/iptables.go @@ -2,6 +2,7 @@ package iptables import ( "context" + "errors" "fmt" "io" "net/netip" @@ -177,6 +178,29 @@ func (c *Config) AcceptOutput(ctx context.Context, return c.runIP6tablesInstruction(ctx, instruction) } +func (c *Config) AcceptOutputFromIPPortToIPPort(ctx context.Context, + protocol, intf string, source, destination netip.AddrPort, remove bool, +) error { + if source.Addr().BitLen() != destination.Addr().BitLen() { + return errors.New("source and destination address families do not match") + } + + interfaceFlag := "-o " + intf + if intf == "*" { // all interfaces + interfaceFlag = "" + } + + instruction := fmt.Sprintf("%s OUTPUT %s -s %s -d %s -p %s -m %s --sport %d --dport %d -j ACCEPT", + appendOrDelete(remove), interfaceFlag, source.Addr(), destination.Addr(), + protocol, protocol, source.Port(), destination.Port()) + if destination.Addr().Is4() { + return c.runIptablesInstruction(ctx, instruction) + } else if c.ip6Tables == "" { + return fmt.Errorf("accept output from %s to %s: %s", source, destination, needIP6Tables) + } + return c.runIP6tablesInstruction(ctx, instruction) +} + // AcceptOutputFromIPToSubnet accepts outgoing traffic from sourceIP to destinationSubnet // on the interface intf. If intf is empty, it is set to "*" which means all interfaces. // If remove is true, the rule is removed instead of added. diff --git a/internal/firewall/wrappers.go b/internal/firewall/wrappers.go index 0167eba0a..435df9b18 100644 --- a/internal/firewall/wrappers.go +++ b/internal/firewall/wrappers.go @@ -25,3 +25,10 @@ func (c *Config) AcceptOutput(ctx context.Context, protocol, intf string, ) error { return c.impl.AcceptOutput(ctx, protocol, intf, ip, port, remove) } + +func (c *Config) AcceptOutputFromIPPortToIPPort(ctx context.Context, + protocol, intf string, source, destination netip.AddrPort, remove bool, +) error { + return c.impl.AcceptOutputFromIPPortToIPPort(ctx, protocol, intf, + source, destination, remove) +} diff --git a/internal/restrictednet/client.go b/internal/restrictednet/client.go new file mode 100644 index 000000000..040557583 --- /dev/null +++ b/internal/restrictednet/client.go @@ -0,0 +1,82 @@ +package restrictednet + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "net/netip" + "strconv" + + "github.com/qdm12/dns/v2/pkg/provider" +) + +// Client is a client for making restricted network requests, +// such as opening temporary firewall rules for HTTPS connections. +// It is not meant to be high performance, although it can be used for +// multiple requests and concurrently. +type Client struct { + outboundInterface string + ipv6Supported bool + firewall Firewall + dohServers []provider.DoHServer +} + +func New(settings Settings) *Client { + if err := settings.validate(); err != nil { + panic(fmt.Sprintf("invalid settings: %v", err)) // programming error + } + dohServers := make([]provider.DoHServer, len(settings.UpstreamResolvers)) + for i, upstreamResolver := range settings.UpstreamResolvers { + dohServers[i] = upstreamResolver.DoH + } + + return &Client{ + outboundInterface: settings.DefaultInterface, + ipv6Supported: *settings.IPv6Supported, + firewall: settings.Firewall, + dohServers: dohServers, + } +} + +// OpenHTTPSByHostname opens an https connection through the firewall, +// to the hostname which in the format `host:port`. The returned cleanup +// function must be called to remove the temporary firewall rule and close connections. +// It first resolves the domain in hostname using DNS over HTTPS and then opens +// the restricted HTTPS connection to the resolved IP. +func (c *Client) OpenHTTPSByHostname(ctx context.Context, hostname string) ( + httpClient *http.Client, cleanup func() error, err error, +) { + host, portStr, err := net.SplitHostPort(hostname) + if err != nil { + return nil, nil, fmt.Errorf("splitting host and port: %w", err) + } + resolvedIPs, err := c.ResolveName(ctx, host) + if err != nil { + return nil, nil, fmt.Errorf("resolving name: %w", err) + } else if len(resolvedIPs) == 0 { + return nil, nil, fmt.Errorf("no IP address found for name %q", host) + } + + portUint, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return nil, nil, fmt.Errorf("parsing port: %w", err) + } else if portUint == 0 { + return nil, nil, errors.New("destination port cannot be 0") + } + port := uint16(portUint) + + errs := make([]error, 0, len(resolvedIPs)) + for _, ip := range resolvedIPs { + addrPort := netip.AddrPortFrom(ip, port) + httpClient, cleanup, err := c.OpenHTTPS(ctx, host, addrPort) + if err != nil { + errs = append(errs, fmt.Errorf("for %s: %w", ip, err)) + continue + } + return httpClient, cleanup, nil + } + + return nil, nil, fmt.Errorf("opening HTTPS to %s: %w", hostname, errors.Join(errs...)) +} diff --git a/internal/restrictednet/helpers_test.go b/internal/restrictednet/helpers_test.go new file mode 100644 index 000000000..091aefd1c --- /dev/null +++ b/internal/restrictednet/helpers_test.go @@ -0,0 +1,7 @@ +//go:build integration + +package restrictednet + +func ptrTo[T any](value T) *T { + return &value +} diff --git a/internal/restrictednet/https.go b/internal/restrictednet/https.go new file mode 100644 index 000000000..1ad6d891b --- /dev/null +++ b/internal/restrictednet/https.go @@ -0,0 +1,202 @@ +package restrictednet + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net" + "net/http" + "net/netip" + "os" + "time" + + "github.com/jsimonetti/rtnetlink" + "github.com/qdm12/gluetun/internal/pmtud/constants" +) + +// OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination. +// The returned [*http.Client] must be used sequentially only, and each request must +// have its response body fully read/discarded and then closed. +// The returned cleanup function must be called to remove the temporary firewall rule and close connections. +func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, destinationAddrPort netip.AddrPort, +) (httpClient *http.Client, cleanup func() error, err error) { + fd, sourceAddrPort, err := bindSourceConnection(destinationAddrPort.Addr()) + if err != nil { + return nil, nil, fmt.Errorf("binding source port: %w", err) + } + + const remove = false + err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface, + sourceAddrPort, destinationAddrPort, remove) + if err != nil { + closeFD(fd) + return nil, nil, fmt.Errorf("allowing output traffic through firewall: %w", err) + } + + connection, err := connectSourceConnection(ctx, fd, destinationAddrPort) + if err != nil { + const remove = true + _ = c.firewall.AcceptOutputFromIPPortToIPPort(context.Background(), "tcp", c.outboundInterface, + sourceAddrPort, destinationAddrPort, remove) + return nil, nil, fmt.Errorf("connecting source socket: %w", err) + } + + dial := makeDial(connection, destinationTLSName) + httpClient = newHTTPSClient(destinationTLSName, dial) + cleanup = func() error { + var errs []error + httpClient.CloseIdleConnections() + err := connection.Close() + if err != nil && !errors.Is(err, net.ErrClosed) { + errs = append(errs, fmt.Errorf("closing connection: %w", err)) + } + const remove = true + err = c.firewall.AcceptOutputFromIPPortToIPPort(context.Background(), "tcp", c.outboundInterface, + sourceAddrPort, destinationAddrPort, remove) + if err != nil { + errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err)) + } + if len(errs) > 0 { + return errors.Join(errs...) + } + return nil + } + return httpClient, cleanup, nil +} + +type dialFunc func(ctx context.Context, network, address string) (net.Conn, error) + +func newHTTPSClient(destinationTLSName string, dial dialFunc) *http.Client { + const timeout = 5 * time.Second + transport := &http.Transport{ + MaxIdleConns: 1, + MaxIdleConnsPerHost: 1, + MaxConnsPerHost: 1, + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + ServerName: destinationTLSName, + }, + DialContext: dial, + } + return &http.Client{ + Timeout: timeout, + Transport: transport, + } +} + +func makeDial(connection net.Conn, tlsName string) dialFunc { + _, destinationPort, err := net.SplitHostPort(connection.RemoteAddr().String()) + if err != nil { + panic(err) // connection remote address should always be in the form "host:port" + } + expectedAddress := net.JoinHostPort(tlsName, destinationPort) + used := false + return func(_ context.Context, network, address string) (net.Conn, error) { + if used { + return nil, errors.New("dial function called more than once") + } + used = true + switch network { + case "tcp", "tcp4", "tcp6": + default: + return nil, fmt.Errorf("unexpected dial network %q", network) + } + if address != expectedAddress { + return nil, fmt.Errorf("unexpected dial address %q (expected %q)", address, expectedAddress) + } + return connection, nil + } +} + +func bindSourceConnection(destinationIP netip.Addr) (fd int, sourceAddr netip.AddrPort, err error) { + sourceIP, err := sourceIPForDestination(destinationIP) + if err != nil { + return 0, netip.AddrPort{}, fmt.Errorf("finding source IP: %w", err) + } + + family := constants.AF_INET + if sourceIP.Is6() { + family = constants.AF_INET6 + } + + fd, err = newTCPSockStream(family) + if err != nil { + return 0, netip.AddrPort{}, fmt.Errorf("creating socket: %w", err) + } + + bindAddrPort := netip.AddrPortFrom(sourceIP, 0) + err = bindFD(fd, bindAddrPort) + if err != nil { + closeFD(fd) + return 0, netip.AddrPort{}, fmt.Errorf("binding socket: %w", err) + } + + sourceAddr, err = fdToSourceAddr(fd) + if err != nil { + closeFD(fd) + return 0, netip.AddrPort{}, fmt.Errorf("getting source address: %w", err) + } + + return fd, sourceAddr, nil +} + +func connectSourceConnection(ctx context.Context, fd int, destinationAddrPort netip.AddrPort) ( + connection net.Conn, err error, +) { + err = connectFD(ctx, fd, destinationAddrPort) + if err != nil { + closeFD(fd) + return nil, fmt.Errorf("connecting socket: %w", err) + } + + file := os.NewFile(uintptr(fd), "") + if file == nil { + closeFD(fd) + return nil, fmt.Errorf("creating socket file") + } + defer file.Close() + + connection, err = net.FileConn(file) + if err != nil { + return nil, fmt.Errorf("wrapping socket connection: %w", err) + } + + return connection, nil +} + +func sourceIPForDestination(destinationIP netip.Addr) (srcIP netip.Addr, err error) { + conn, err := rtnetlink.Dial(nil) + if err != nil { + return netip.Addr{}, err + } + defer conn.Close() + + family := uint8(constants.AF_INET) + if destinationIP.Is6() { + family = constants.AF_INET6 + } + + requestMessage := &rtnetlink.RouteMessage{ + Family: family, + Attributes: rtnetlink.RouteAttributes{ + Dst: destinationIP.AsSlice(), + }, + } + messages, err := conn.Route.Get(requestMessage) + if err != nil { + return netip.Addr{}, fmt.Errorf("getting routes to %s: %w", destinationIP, err) + } + + for _, message := range messages { + if message.Attributes.Src == nil { + continue + } + if message.Attributes.Src.To4() == nil { + return netip.AddrFrom16([16]byte(message.Attributes.Src)), nil + } + return netip.AddrFrom4([4]byte(message.Attributes.Src)), nil + } + + return netip.Addr{}, fmt.Errorf("no route to %s", destinationIP) +} diff --git a/internal/restrictednet/https_integration_test.go b/internal/restrictednet/https_integration_test.go new file mode 100644 index 000000000..2a300cfcc --- /dev/null +++ b/internal/restrictednet/https_integration_test.go @@ -0,0 +1,117 @@ +//go:build integration + +package restrictednet + +import ( + "context" + "fmt" + "io" + "net/http" + "net/netip" + "testing" + + "github.com/golang/mock/gomock" + "github.com/qdm12/dns/v2/pkg/provider" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type listenAddrPortMatcher struct { + expected netip.AddrPort +} + +func (m listenAddrPortMatcher) Matches(x any) bool { + ip, ok := x.(netip.AddrPort) + if !ok { + return false + } + if m.expected.IsValid() { + return ip == m.expected + } + return ip.IsValid() && ip.Addr().IsValid() && ip.Port() > 0 +} + +func (m listenAddrPortMatcher) String() string { + if m.expected.IsValid() { + return "is the same as " + m.expected.String() + } + return "is a valid netip.AddrPort with a valid IP and non-zero port" +} + +type destinationAddrPortMatcher struct { + expected netip.AddrPort +} + +func (m destinationAddrPortMatcher) Matches(x any) bool { + ip, ok := x.(netip.AddrPort) + if !ok { + return false + } + if m.expected.IsValid() { + return ip == m.expected + } + return ip.IsValid() && ip.Port() == m.expected.Port() +} + +func (m destinationAddrPortMatcher) String() string { + if m.expected.IsValid() { + return "is the same as " + m.expected.String() + } + return "matches the port " + fmt.Sprint(m.expected.Port()) +} + +func Test_Client_OpenHTTPS(t *testing.T) { + t.Parallel() + ctx := t.Context() + ctrl := gomock.NewController(t) + + const destinationTLSName = "one.one.one.one" + destinationAddrPort := netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443) + + firewall := NewMockFirewall(ctrl) + sourceMatcher := listenAddrPortMatcher{} + firewall.EXPECT().AcceptOutputFromIPPortToIPPort( + ctx, "tcp", "eth0", sourceMatcher, destinationAddrPort, false, + ).DoAndReturn(func(_ context.Context, + _, _ string, source, _ netip.AddrPort, _ bool, + ) error { + sourceMatcher.expected = source + return nil + }) + firewall.EXPECT().AcceptOutputFromIPPortToIPPort( + context.Background(), "tcp", "eth0", sourceMatcher, destinationAddrPort, true, + ).Return(nil) + + const ipv6Supported = false + upstreamResolvers := []provider.Provider{provider.Google()} + settings := Settings{ + Firewall: firewall, + DefaultInterface: "eth0", + IPv6Supported: ptrTo(ipv6Supported), + UpstreamResolvers: upstreamResolvers, + } + client := New(settings) + + httpClient, cleanup, err := client.OpenHTTPS(ctx, destinationTLSName, destinationAddrPort) + require.NoError(t, err) + require.NotNil(t, httpClient) + require.NotNil(t, cleanup) + + const requests = 2 + + for range requests { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://"+destinationTLSName, nil) + require.NoError(t, err) + + response, err := httpClient.Do(request) + require.NoError(t, err) + _, err = io.Copy(io.Discard, response.Body) + require.NoError(t, err) + err = response.Body.Close() + require.NoError(t, err) + assert.Equal(t, http.StatusOK, response.StatusCode) + } + + err = cleanup() + require.NoError(t, err) +} diff --git a/internal/restrictednet/interfaces.go b/internal/restrictednet/interfaces.go new file mode 100644 index 000000000..205f78a28 --- /dev/null +++ b/internal/restrictednet/interfaces.go @@ -0,0 +1,12 @@ +package restrictednet + +import ( + "context" + "net/netip" +) + +type Firewall interface { + AcceptOutputFromIPPortToIPPort(ctx context.Context, + protocol, intf string, source, destination netip.AddrPort, remove bool, + ) error +} diff --git a/internal/restrictednet/mocks_generate_test.go b/internal/restrictednet/mocks_generate_test.go new file mode 100644 index 000000000..687eef332 --- /dev/null +++ b/internal/restrictednet/mocks_generate_test.go @@ -0,0 +1,3 @@ +package restrictednet + +//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Firewall diff --git a/internal/restrictednet/mocks_test.go b/internal/restrictednet/mocks_test.go new file mode 100644 index 000000000..f7c322269 --- /dev/null +++ b/internal/restrictednet/mocks_test.go @@ -0,0 +1,50 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/qdm12/gluetun/internal/restrictednet (interfaces: Firewall) + +// Package restrictednet is a generated GoMock package. +package restrictednet + +import ( + context "context" + netip "net/netip" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockFirewall is a mock of Firewall interface. +type MockFirewall struct { + ctrl *gomock.Controller + recorder *MockFirewallMockRecorder +} + +// MockFirewallMockRecorder is the mock recorder for MockFirewall. +type MockFirewallMockRecorder struct { + mock *MockFirewall +} + +// NewMockFirewall creates a new mock instance. +func NewMockFirewall(ctrl *gomock.Controller) *MockFirewall { + mock := &MockFirewall{ctrl: ctrl} + mock.recorder = &MockFirewallMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockFirewall) EXPECT() *MockFirewallMockRecorder { + return m.recorder +} + +// AcceptOutputFromIPPortToIPPort mocks base method. +func (m *MockFirewall) AcceptOutputFromIPPortToIPPort(arg0 context.Context, arg1, arg2 string, arg3, arg4 netip.AddrPort, arg5 bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AcceptOutputFromIPPortToIPPort", arg0, arg1, arg2, arg3, arg4, arg5) + ret0, _ := ret[0].(error) + return ret0 +} + +// AcceptOutputFromIPPortToIPPort indicates an expected call of AcceptOutputFromIPPortToIPPort. +func (mr *MockFirewallMockRecorder) AcceptOutputFromIPPortToIPPort(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptOutputFromIPPortToIPPort", reflect.TypeOf((*MockFirewall)(nil).AcceptOutputFromIPPortToIPPort), arg0, arg1, arg2, arg3, arg4, arg5) +} diff --git a/internal/restrictednet/resolve.go b/internal/restrictednet/resolve.go new file mode 100644 index 000000000..ab71a8e7c --- /dev/null +++ b/internal/restrictednet/resolve.go @@ -0,0 +1,205 @@ +package restrictednet + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net/http" + "net/netip" + "net/url" + "strconv" + + "github.com/miekg/dns" +) + +// ResolveName resolves the given host name to IP addresses using DoH servers, +// while opening temporary restrictive firewall rules for HTTPS traffic to DoH servers. +// The host must be a single well-formed domain name, without port or path. +func (c *Client) ResolveName(ctx context.Context, host string) ( + resolvedAddresses []netip.Addr, err error, +) { + const maxTypes = 2 + questionTypes := make([]uint16, 0, maxTypes) + if c.ipv6Supported { + questionTypes = append(questionTypes, dns.TypeAAAA) + } + questionTypes = append(questionTypes, dns.TypeA) + + var addresses []netip.Addr + errs := make([]error, 0, len(questionTypes)) + for _, questionType := range questionTypes { + answerAddresses, err := c.resolveOneQuestionType(ctx, host, questionType) + if err != nil { + errs = append(errs, err) + continue + } + addresses = append(addresses, answerAddresses...) + } + + switch { + case len(addresses) > 0: + return addresses, nil + case len(errs) == 0: + return nil, nil // no address found + default: // errors + return nil, fmt.Errorf("resolving host %q: %w", host, errors.Join(errs...)) + } +} + +func (c *Client) resolveOneQuestionType(ctx context.Context, + host string, questionType uint16, +) (addresses []netip.Addr, err error) { + queryMessage := &dns.Msg{} + queryMessage.SetQuestion(dns.Fqdn(host), questionType) + queryWire, err := queryMessage.Pack() + if err != nil { + return nil, fmt.Errorf("packing DNS query: %w", err) + } + + // Try every DoH server and every of each of their IP until we get a non-empty + // successful response. + errs := make([]error, 0) + for _, dohServer := range c.dohServers { + dohURL, err := url.Parse(dohServer.URL) + if err != nil { + errs = append(errs, + fmt.Errorf("parsing DoH server URL %s: %w", dohServer.URL, err)) + continue + } + + dohServerIPs := make([]netip.Addr, 0, len(dohServer.IPv4)+len(dohServer.IPv6)) + if c.ipv6Supported { + // Prefer IPv6 addresses if IPv6 is supported + dohServerIPs = append(dohServerIPs, dohServer.IPv6...) + } + dohServerIPs = append(dohServerIPs, dohServer.IPv4...) + + for _, dohServerIP := range dohServerIPs { + const defaultDoHPort uint16 = 443 + port := defaultDoHPort + if portStr := dohURL.Port(); portStr != "" { + port, err = parseDestinationPort(portStr) + if err != nil { + errs = append(errs, fmt.Errorf("parsing DoH server port: %w", err)) + continue + } + } + dohServerAddrPort := netip.AddrPortFrom(dohServerIP, port) + responseMessage, err := c.doHQuery(ctx, queryWire, dohURL, dohServerAddrPort) + switch { + case err != nil: + errs = append(errs, fmt.Errorf("querying DoH server %q (%s): %w", + dohServer.URL, dohServerAddrPort, err)) + continue + case responseMessage.Rcode != dns.RcodeSuccess: + errs = append(errs, fmt.Errorf("querying DoH server %q (%s): DNS rcode %s", + dohServer.URL, dohServerAddrPort, dns.RcodeToString[responseMessage.Rcode])) + continue + } + addresses := answersToNetipAddrs(responseMessage) + if len(addresses) == 0 { + continue + } + return addresses, nil + } + } + + if len(errs) == 0 { + return nil, nil + } + + return nil, fmt.Errorf("resolving %s %s: %w", + dns.TypeToString[questionType], host, errors.Join(errs...)) +} + +func (c *Client) doHQuery(ctx context.Context, queryWire []byte, + dohURL *url.URL, dohServerAddrPort netip.AddrPort, +) (responseMessage *dns.Msg, err error) { + httpClient, cleanup, err := c.OpenHTTPS(ctx, dohURL.Hostname(), dohServerAddrPort) + if err != nil { + return nil, fmt.Errorf("opening https connection: %w", err) + } + defer func() { + closeErr := cleanup() + if err == nil && closeErr != nil { + err = fmt.Errorf("cleaning up https connection: %w", closeErr) + } + }() + + requestBody := bytes.NewReader(queryWire) + request, err := http.NewRequestWithContext(ctx, http.MethodPost, dohURL.String(), requestBody) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + request.Header.Set("Content-Type", "application/dns-message") + request.Header.Set("Accept", "application/dns-message") + + response, err := httpClient.Do(request) + if err != nil { + return nil, err + } + + responseData, err := io.ReadAll(response.Body) + if err != nil { + _ = response.Body.Close() + return nil, fmt.Errorf("reading response body: %w", err) + } + + err = response.Body.Close() + if err != nil { + return nil, fmt.Errorf("closing response body: %w", err) + } + + if response.StatusCode != http.StatusOK { + return nil, fmt.Errorf("response status code is %s (data length %d)", + response.Status, len(responseData)) + } + + responseMessage = new(dns.Msg) + err = responseMessage.Unpack(responseData) + if err != nil { + return nil, fmt.Errorf("parsing DoH response: %w", err) + } + + return responseMessage, nil +} + +func answersToNetipAddrs(message *dns.Msg) (addresses []netip.Addr) { + if message == nil { + return nil + } + addresses = make([]netip.Addr, 0, len(message.Answer)) + for _, answer := range message.Answer { + switch record := answer.(type) { + case *dns.A: + address, ok := netip.AddrFromSlice(record.A) + if ok { + addresses = append(addresses, address.Unmap()) + } + case *dns.AAAA: + address, ok := netip.AddrFromSlice(record.AAAA) + if ok { + addresses = append(addresses, address) + } + } + } + return addresses +} + +func parseDestinationPort(portStr string) (port uint16, err error) { + portUint, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return 0, err + } + + const maxPortUint = 65535 + switch { + case portUint == 0: + return 0, errors.New("port cannot be 0") + case portUint > maxPortUint: + return 0, fmt.Errorf("port cannot be greater than %d", maxPortUint) + } + return uint16(portUint), nil +} diff --git a/internal/restrictednet/resolve_integration_test.go b/internal/restrictednet/resolve_integration_test.go new file mode 100644 index 000000000..0fe602c23 --- /dev/null +++ b/internal/restrictednet/resolve_integration_test.go @@ -0,0 +1,110 @@ +//go:build integration + +package restrictednet + +import ( + "context" + "net" + "net/netip" + "testing" + + "github.com/golang/mock/gomock" + "github.com/miekg/dns" + "github.com/qdm12/dns/v2/pkg/provider" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_Client_ResolveName(t *testing.T) { + t.Parallel() + ctx := t.Context() + ctrl := gomock.NewController(t) + + firewall := NewMockFirewall(ctrl) + sourceMatcher := listenAddrPortMatcher{} + destinationMatcher := destinationAddrPortMatcher{ + expected: netip.AddrPortFrom(netip.Addr{}, 443), + } + + // Add rule + firstCall := firewall.EXPECT().AcceptOutputFromIPPortToIPPort( + ctx, "tcp", "eth0", sourceMatcher, destinationMatcher, false, + ).DoAndReturn(func( + _ context.Context, _, _ string, source, destination netip.AddrPort, _ bool, + ) error { + sourceMatcher.expected = source + destinationMatcher.expected = destination + return nil + }) + + // Removal rule + firewall.EXPECT().AcceptOutputFromIPPortToIPPort( + context.Background(), "tcp", "eth0", sourceMatcher, destinationMatcher, true, + ).Return(nil).After(firstCall) + + settings := Settings{ + DefaultInterface: "eth0", + IPv6Supported: ptrTo(false), + Firewall: firewall, + UpstreamResolvers: []provider.Provider{provider.Cloudflare()}, + } + client := New(settings) + + addresses, err := client.ResolveName(ctx, "github.com") + require.NoError(t, err) + assert.NotEmpty(t, addresses) +} + +func Test_answersToNetipAddrs(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + message *dns.Msg + expected []netip.Addr + }{ + "nil_message": {}, + "no_answers": { + message: &dns.Msg{}, + expected: []netip.Addr{}, + }, + "a_record": { + message: &dns.Msg{Answer: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, + A: net.IP{1, 1, 1, 1}, + }, + }}, + expected: []netip.Addr{netip.MustParseAddr("1.1.1.1")}, + }, + "aaaa_record": { + message: &dns.Msg{Answer: []dns.RR{ + &dns.AAAA{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET}, + AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88}, + }, + }}, + expected: []netip.Addr{netip.MustParseAddr("2001:4860:4860::8888")}, + }, + "mixed_records": { + message: &dns.Msg{Answer: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, + A: net.IP{1, 1, 1, 1}, + }, + &dns.AAAA{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET}, + AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88}, + }, + }}, + expected: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("2001:4860:4860::8888")}, + }, + } + + for testName, testCase := range testCases { + t.Run(testName, func(t *testing.T) { + t.Parallel() + addresses := answersToNetipAddrs(testCase.message) + assert.Equal(t, testCase.expected, addresses) + }) + } +} diff --git a/internal/restrictednet/settings.go b/internal/restrictednet/settings.go new file mode 100644 index 000000000..52c678c37 --- /dev/null +++ b/internal/restrictednet/settings.go @@ -0,0 +1,28 @@ +package restrictednet + +import ( + "errors" + + "github.com/qdm12/dns/v2/pkg/provider" +) + +type Settings struct { + DefaultInterface string + IPv6Supported *bool + Firewall Firewall + UpstreamResolvers []provider.Provider +} + +func (s *Settings) validate() error { + switch { + case s.DefaultInterface == "": + return errors.New("default interface is not set") + case s.IPv6Supported == nil: + return errors.New("IPv6 support field is not set") + case s.Firewall == nil: + return errors.New("firewall is not set") + case len(s.UpstreamResolvers) == 0: + return errors.New("no upstream resolvers provided") + } + return nil +} diff --git a/internal/restrictednet/unix.go b/internal/restrictednet/unix.go new file mode 100644 index 000000000..bb52bd310 --- /dev/null +++ b/internal/restrictednet/unix.go @@ -0,0 +1,121 @@ +//go:build !windows + +package restrictednet + +import ( + "context" + "errors" + "fmt" + "net/netip" + "time" + + "golang.org/x/sys/unix" +) + +func closeFD(fd int) { + unix.Close(fd) +} + +func newTCPSockStream(family int) (fd int, err error) { + fd, err = unix.Socket(family, unix.SOCK_STREAM, unix.IPPROTO_TCP) + if err != nil { + return 0, err + } + err = unix.SetNonblock(fd, true) + if err != nil { + _ = unix.Close(fd) + return 0, err + } + return fd, nil +} + +func bindFD(fd int, address netip.AddrPort) error { + bindAddr := makeSockAddr(address) + return unix.Bind(fd, bindAddr) +} + +func connectFD(ctx context.Context, fd int, destination netip.AddrPort) error { + err := unix.Connect(fd, makeSockAddr(destination)) + switch { + case err == nil: + return nil + case !errors.Is(err, unix.EINPROGRESS): + return err + } + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + bitsIndex := fd / 64 //nolint:mnd + if bitsIndex >= len(unix.FdSet{}.Bits) { + return fmt.Errorf("fd %d exceeds unix.Select FdSet capacity", fd) + } + wset := &unix.FdSet{} + wset.Bits[bitsIndex] |= 1 << (uint64(fd) % 64) //nolint:gosec,mnd + eset := &unix.FdSet{} + eset.Bits[bitsIndex] |= 1 << (uint64(fd) % 64) //nolint:gosec,mnd + const selectTimeout = 50 * time.Millisecond + timeval := unix.NsecToTimeval(int64(selectTimeout)) + + // Wait for the FD to become writable or hit an error state + n, err := unix.Select(fd+1, nil, wset, eset, &timeval) + if err != nil { + if errors.Is(err, unix.EINTR) { + continue // Syscall interrupted, try again + } + return fmt.Errorf("select error: %w", err) + } else if n == 0 { + continue // no status change yet + } + + // Check if the socket encountered an error + n, err = unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_ERROR) + if err != nil { + return fmt.Errorf("getsockopt error: %w", err) + } else if n != 0 { + return fmt.Errorf("connect failed asynchronously: %w", unix.Errno(n)) + } + + return nil + } + } +} + +func fdToSourceAddr(fd int) (sourceAddrPort netip.AddrPort, err error) { + sockAddr, err := unix.Getsockname(fd) + if err != nil { + return netip.AddrPort{}, fmt.Errorf("getting sockname: %w", err) + } + + sourceAddrPort, err = sockAddrToAddrPort(sockAddr) + if err != nil { + return netip.AddrPort{}, err + } + return sourceAddrPort, nil +} + +func makeSockAddr(addressPort netip.AddrPort) unix.Sockaddr { + if addressPort.Addr().Is4() { + return &unix.SockaddrInet4{ + Port: int(addressPort.Port()), + Addr: addressPort.Addr().As4(), + } + } + return &unix.SockaddrInet6{ + Port: int(addressPort.Port()), + Addr: addressPort.Addr().As16(), + } +} + +func sockAddrToAddrPort(sockAddr unix.Sockaddr) (addrPort netip.AddrPort, err error) { + switch typedSockAddr := sockAddr.(type) { + case *unix.SockaddrInet4: + return netip.AddrPortFrom(netip.AddrFrom4(typedSockAddr.Addr), uint16(typedSockAddr.Port)), nil //nolint:gosec + case *unix.SockaddrInet6: + return netip.AddrPortFrom(netip.AddrFrom16(typedSockAddr.Addr), uint16(typedSockAddr.Port)), nil //nolint:gosec + default: + return netip.AddrPort{}, fmt.Errorf("unexpected socket address type %T", typedSockAddr) + } +} diff --git a/internal/restrictednet/windows.go b/internal/restrictednet/windows.go new file mode 100644 index 000000000..454fc2c62 --- /dev/null +++ b/internal/restrictednet/windows.go @@ -0,0 +1,28 @@ +//go:build windows + +package restrictednet + +import ( + "context" + "net/netip" +) + +func closeFD(fd int) { + panic("not implemented") +} + +func newTCPSockStream(family int) (fd int, err error) { + panic("not implemented") +} + +func bindFD(fd int, address netip.AddrPort) error { + panic("not implemented") +} + +func connectFD(ctx context.Context, fd int, destination netip.AddrPort) error { + panic("not implemented") +} + +func fdToSourceAddr(fd int) (sourceAddrPort netip.AddrPort, err error) { + panic("not implemented") +}