From aa781c6cc5dcd3609ac254a2b89a33da6aa4ff72 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Fri, 5 Jun 2026 03:56:25 +0000 Subject: [PATCH 01/18] initial --- AGENTS.md | 4 + internal/firewall/interfaces.go | 2 + internal/firewall/iptables/iptables.go | 23 +++ internal/firewall/wrappers.go | 7 + internal/restrictednet/client.go | 56 ++++++ internal/restrictednet/client_test.go | 68 +++++++ internal/restrictednet/https.go | 115 ++++++++++++ internal/restrictednet/interfaces.go | 12 ++ internal/restrictednet/mocks_generate_test.go | 3 + internal/restrictednet/mocks_test.go | 50 +++++ internal/restrictednet/resolve.go | 177 ++++++++++++++++++ internal/restrictednet/resolve_test.go | 82 ++++++++ 12 files changed, 599 insertions(+) create mode 100644 internal/restrictednet/client.go create mode 100644 internal/restrictednet/client_test.go create mode 100644 internal/restrictednet/https.go create mode 100644 internal/restrictednet/interfaces.go create mode 100644 internal/restrictednet/mocks_generate_test.go create mode 100644 internal/restrictednet/mocks_test.go create mode 100644 internal/restrictednet/resolve.go create mode 100644 internal/restrictednet/resolve_test.go diff --git a/AGENTS.md b/AGENTS.md index 0e9902334..b7d0b3bb0 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -50,6 +50,7 @@ Guidance for coding agents working in this repository. - Prefer splitting a code line only when it triggers the `lll` linter, do not split a command or arguments list for each element - Use `netip` types instead of `net` types whenever possible - Use constants instead of variables whenever possible, especially function-local inline constants. +- Prefer using pure functions over methods when possible. Especially if the method does not need any fields from the receiving struct, it should be a pure function. - Do not use `time.Sleep`, prefer using a `time.Timer` with a `select` statement also listening on a context cancelation - `panic`: - should only be used when a programming error is encountered and you should NOT return errors for programming errors (such as passing nil objects) @@ -127,6 +128,7 @@ The Go formatter used is gofumpt. ### Errors - Always prefer wrapping errors with some context with `fmt.Errorf("doing this: %w", err)` +- Use `errors.New("error message")` when creating a 'bottom' constant string error without additional context, instead of `fmt.Errorf` - In rare cases, you can just use `return err` notably: - If the function is called **recursively**, since we don't wrap the wrapping multiple times for each recursion - If the current function only statement is the call to another function, for example: @@ -179,6 +181,8 @@ The Go formatter used is gofumpt. - Do not use `http.DefaultClient`, use a custom `*http.Client` with a fixed timeout and share with dependency injections. - Do not check for injected dependencies being `nil`, prefer to just panic on a nil pointer. By default it's fine to panic if a developer injects a dependency `nil`. `nil` does not mean use a default. +- Prefer using a `switch { case ...}` statement over multiple consecutive `if` statements to have shorter code. +- Prefer using `[...]T` instead of `[]T` when the length is fixed and known at compile time, to avoid unnecessary allocations. ## Validation checklist diff --git a/internal/firewall/interfaces.go b/internal/firewall/interfaces.go index d9c830da2..064046c7b 100644 --- a/internal/firewall/interfaces.go +++ b/internal/firewall/interfaces.go @@ -28,6 +28,8 @@ type firewallImpl interface { //nolint:interfacebloat AcceptIpv6MulticastOutput(ctx context.Context, intf string) error AcceptOutput(ctx context.Context, protocol, intf string, ip netip.Addr, port uint16, remove bool) error + AcceptOutputFromIPPortToIPPort(ctx context.Context, protocol, intf string, + source, destination netip.AddrPort, remove bool) error AcceptOutputFromIPToSubnet(ctx context.Context, intf string, assignedIP netip.Addr, subnet netip.Prefix, remove bool) error AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error diff --git a/internal/firewall/iptables/iptables.go b/internal/firewall/iptables/iptables.go index c48879295..b96b4f1d8 100644 --- a/internal/firewall/iptables/iptables.go +++ b/internal/firewall/iptables/iptables.go @@ -177,6 +177,29 @@ func (c *Config) AcceptOutput(ctx context.Context, return c.runIP6tablesInstruction(ctx, instruction) } +func (c *Config) AcceptOutputFromIPPortToIPPort(ctx context.Context, + protocol, intf string, source, destination netip.AddrPort, remove bool, +) error { + if source.Addr().BitLen() != destination.Addr().BitLen() { + return fmt.Errorf("source and destination address families do not match") + } + + interfaceFlag := "-o " + intf + if intf == "*" { // all interfaces + interfaceFlag = "" + } + + instruction := fmt.Sprintf("%s OUTPUT -s %s --sport %d -d %s %s -p %s -m %s --dport %d -j ACCEPT", + appendOrDelete(remove), source.Addr(), source.Port(), destination.Addr(), + interfaceFlag, protocol, protocol, destination.Port()) + if destination.Addr().Is4() { + return c.runIptablesInstruction(ctx, instruction) + } else if c.ip6Tables == "" { + return fmt.Errorf("accept output from %s to %s: %s", source, destination, needIP6Tables) + } + return c.runIP6tablesInstruction(ctx, instruction) +} + // AcceptOutputFromIPToSubnet accepts outgoing traffic from sourceIP to destinationSubnet // on the interface intf. If intf is empty, it is set to "*" which means all interfaces. // If remove is true, the rule is removed instead of added. diff --git a/internal/firewall/wrappers.go b/internal/firewall/wrappers.go index 0167eba0a..435df9b18 100644 --- a/internal/firewall/wrappers.go +++ b/internal/firewall/wrappers.go @@ -25,3 +25,10 @@ func (c *Config) AcceptOutput(ctx context.Context, protocol, intf string, ) error { return c.impl.AcceptOutput(ctx, protocol, intf, ip, port, remove) } + +func (c *Config) AcceptOutputFromIPPortToIPPort(ctx context.Context, + protocol, intf string, source, destination netip.AddrPort, remove bool, +) error { + return c.impl.AcceptOutputFromIPPortToIPPort(ctx, protocol, intf, + source, destination, remove) +} diff --git a/internal/restrictednet/client.go b/internal/restrictednet/client.go new file mode 100644 index 000000000..d8812a3fa --- /dev/null +++ b/internal/restrictednet/client.go @@ -0,0 +1,56 @@ +package restrictednet + +import ( + "context" + "fmt" + "net/http" + + "github.com/qdm12/dns/v2/pkg/provider" +) + +// Client is a client for making restricted network requests, +// such as opening temporary firewall rules for HTTPS connections. +// It is not meant to be high performance, although it can be used for +// multiple requests and concurrently. +type Client struct { + ipv6Supported bool + firewall Firewall + outboundInterface string + dohServers []provider.DoHServer +} + +func New(firewall Firewall, defaultInterface string, ipv6Supported bool, + upstreamResolvers []provider.Provider, +) (*Client, error) { + dohServers := make([]provider.DoHServer, len(upstreamResolvers)) + for i, upstreamResolver := range upstreamResolvers { + dohServers[i] = upstreamResolver.DoH + } + + return &Client{ + firewall: firewall, + outboundInterface: defaultInterface, + ipv6Supported: ipv6Supported, + dohServers: dohServers, + }, nil +} + +func (c *Client) OpenHTTPSByDomain(ctx context.Context, domain string) ( + httpClient *http.Client, cleanup func() error, err error, +) { + resolvedIPs, err := c.ResolveName(ctx, domain) + if err != nil { + return nil, nil, fmt.Errorf("resolving name: %w", err) + } else if len(resolvedIPs) == 0 { + return nil, nil, fmt.Errorf("no IP address found for name %q", domain) + } + + selectedIP := resolvedIPs[0] + + httpClient, cleanup, err = c.OpenHTTPS(domain, selectedIP) + if err != nil { + return nil, nil, fmt.Errorf("opening HTTPS: %w", err) + } + + return httpClient, cleanup, nil +} diff --git a/internal/restrictednet/client_test.go b/internal/restrictednet/client_test.go new file mode 100644 index 000000000..b3f5ba8d5 --- /dev/null +++ b/internal/restrictednet/client_test.go @@ -0,0 +1,68 @@ +package restrictednet + +import ( + "context" + "net/netip" + "testing" + + "github.com/golang/mock/gomock" + "github.com/qdm12/dns/v2/pkg/provider" + "github.com/stretchr/testify/require" +) + +type listenAddrPortMatcher struct { + expected netip.AddrPort +} + +func (m listenAddrPortMatcher) Matches(x any) bool { + ip, ok := x.(netip.AddrPort) + if !ok { + return false + } + if m.expected.IsValid() { + return ip == m.expected + } + return ip.IsValid() && ip.Addr().IsValid() && ip.Port() > 0 +} + +func (m listenAddrPortMatcher) String() string { + if m.expected.IsValid() { + return "is the same as " + m.expected.String() + } + return "is a valid netip.AddrPort with a valid IP and non-zero port" +} + +func Test_Client_OpenHTTPS(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + firewall := NewMockFirewall(ctrl) + + destination := netip.MustParseAddrPort("1.2.3.4:443") + backgroundContext := context.Background() + sourceMatcher := listenAddrPortMatcher{} + firewall.EXPECT().AcceptOutputFromIPPortToIPPort( + backgroundContext, "tcp", "eth0", sourceMatcher, destination, false, + ).DoAndReturn(func(_ context.Context, + _, _ string, source, _ netip.AddrPort, _ bool, + ) error { + sourceMatcher.expected = source + return nil + }) + firewall.EXPECT().AcceptOutputFromIPPortToIPPort( + backgroundContext, "tcp", "eth0", sourceMatcher, destination, true, + ) + + const ipv6Supported = false + upstreamResolvers := []provider.Provider{provider.Google()} + client, err := New(firewall, "eth0", ipv6Supported, upstreamResolvers) + require.NoError(t, err) + + httpClient, cleanup, err := client.OpenHTTPS("api.example.com", netip.MustParseAddr("1.2.3.4")) + require.NoError(t, err) + require.NotNil(t, httpClient) + require.NotNil(t, cleanup) + + err = cleanup() + require.NoError(t, err) +} diff --git a/internal/restrictednet/https.go b/internal/restrictednet/https.go new file mode 100644 index 000000000..462f69c25 --- /dev/null +++ b/internal/restrictednet/https.go @@ -0,0 +1,115 @@ +package restrictednet + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net" + "net/http" + "net/netip" + "time" +) + +// OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination. +// The returned cleanup function must be called to remove the temporary firewall rule and close connections. +func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr, +) (httpClient *http.Client, cleanup func() error, err error) { + listener, sourceAddrPort, err := bindSourcePort(destinationIP) + if err != nil { + return nil, nil, fmt.Errorf("binding source port: %w", err) + } + + const httpsPort = 443 + destinationAddrPort := netip.AddrPortFrom(destinationIP, httpsPort) + + const remove = false + ctx := context.Background() // it's a quick firewall change, worth not passing a context + err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface, + sourceAddrPort, destinationAddrPort, remove) + if err != nil { + _ = listener.Close() + return nil, nil, fmt.Errorf("allowing output traffic through firewall: %w", err) + } + + httpClient = newHTTPSClient(destinationTLSName, destinationIP, sourceAddrPort) + cleanup = func() error { + var errs []error + httpClient.CloseIdleConnections() + const remove = true + err := c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface, + sourceAddrPort, destinationAddrPort, remove) + if err != nil { + errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err)) + } + err = listener.Close() + if err != nil { + errs = append(errs, fmt.Errorf("closing listener: %w", err)) + } + if len(errs) > 0 { + return errors.Join(errs...) + } + return nil + } + return httpClient, cleanup, nil +} + +func newHTTPSClient(destinationTLSName string, + destinationIP netip.Addr, sourceAddress netip.AddrPort, +) *http.Client { + httpTransport := http.DefaultTransport.(*http.Transport).Clone() //nolint:forcetypeassert + httpTransport.Proxy = nil + httpTransport.MaxIdleConns = 1 + httpTransport.MaxIdleConnsPerHost = 1 + httpTransport.IdleConnTimeout = time.Second + httpTransport.TLSClientConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + ServerName: destinationTLSName, + } + httpTransport.DialContext = newBoundDialContext(destinationIP, sourceAddress) + + const timeout = 5 * time.Second + return &http.Client{ + Timeout: timeout, + Transport: httpTransport, + } +} + +func newBoundDialContext(destinationAddress netip.Addr, + sourceAddress netip.AddrPort, +) func(ctx context.Context, network, _ string) (net.Conn, error) { + const httpsPort = 443 + destinationAddrPort := netip.AddrPortFrom(destinationAddress, httpsPort).String() + return func(ctx context.Context, network, _ string) (net.Conn, error) { + const timeout = 2 * time.Second + dialer := &net.Dialer{Timeout: timeout} + dialer.LocalAddr = net.TCPAddrFromAddrPort(sourceAddress) + connection, err := dialer.DialContext(ctx, network, destinationAddrPort) + if err != nil { + return nil, fmt.Errorf("%s dialing %s: %w", network, destinationAddrPort, err) + } + return connection, nil + } +} + +func bindSourcePort(destinationIP netip.Addr) ( + listener net.Listener, sourceAddr netip.AddrPort, err error, +) { + var bindAddr netip.Addr + if destinationIP.Is4() { + bindAddr = netip.AddrFrom4([4]byte{}) + } else { + bindAddr = netip.AddrFrom16([16]byte{}) + } + + listener, err = net.ListenTCP("tcp", net.TCPAddrFromAddrPort( + netip.AddrPortFrom(bindAddr, 0))) + if err != nil { + return nil, netip.AddrPort{}, fmt.Errorf("binding TCP port: %w", err) + } + + tcpAddr := listener.Addr().(*net.TCPAddr) //nolint:forcetypeassert + sourceAddr = tcpAddr.AddrPort() + + return listener, sourceAddr, nil +} diff --git a/internal/restrictednet/interfaces.go b/internal/restrictednet/interfaces.go new file mode 100644 index 000000000..205f78a28 --- /dev/null +++ b/internal/restrictednet/interfaces.go @@ -0,0 +1,12 @@ +package restrictednet + +import ( + "context" + "net/netip" +) + +type Firewall interface { + AcceptOutputFromIPPortToIPPort(ctx context.Context, + protocol, intf string, source, destination netip.AddrPort, remove bool, + ) error +} diff --git a/internal/restrictednet/mocks_generate_test.go b/internal/restrictednet/mocks_generate_test.go new file mode 100644 index 000000000..687eef332 --- /dev/null +++ b/internal/restrictednet/mocks_generate_test.go @@ -0,0 +1,3 @@ +package restrictednet + +//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Firewall diff --git a/internal/restrictednet/mocks_test.go b/internal/restrictednet/mocks_test.go new file mode 100644 index 000000000..f7c322269 --- /dev/null +++ b/internal/restrictednet/mocks_test.go @@ -0,0 +1,50 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/qdm12/gluetun/internal/restrictednet (interfaces: Firewall) + +// Package restrictednet is a generated GoMock package. +package restrictednet + +import ( + context "context" + netip "net/netip" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockFirewall is a mock of Firewall interface. +type MockFirewall struct { + ctrl *gomock.Controller + recorder *MockFirewallMockRecorder +} + +// MockFirewallMockRecorder is the mock recorder for MockFirewall. +type MockFirewallMockRecorder struct { + mock *MockFirewall +} + +// NewMockFirewall creates a new mock instance. +func NewMockFirewall(ctrl *gomock.Controller) *MockFirewall { + mock := &MockFirewall{ctrl: ctrl} + mock.recorder = &MockFirewallMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockFirewall) EXPECT() *MockFirewallMockRecorder { + return m.recorder +} + +// AcceptOutputFromIPPortToIPPort mocks base method. +func (m *MockFirewall) AcceptOutputFromIPPortToIPPort(arg0 context.Context, arg1, arg2 string, arg3, arg4 netip.AddrPort, arg5 bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AcceptOutputFromIPPortToIPPort", arg0, arg1, arg2, arg3, arg4, arg5) + ret0, _ := ret[0].(error) + return ret0 +} + +// AcceptOutputFromIPPortToIPPort indicates an expected call of AcceptOutputFromIPPortToIPPort. +func (mr *MockFirewallMockRecorder) AcceptOutputFromIPPortToIPPort(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptOutputFromIPPortToIPPort", reflect.TypeOf((*MockFirewall)(nil).AcceptOutputFromIPPortToIPPort), arg0, arg1, arg2, arg3, arg4, arg5) +} diff --git a/internal/restrictednet/resolve.go b/internal/restrictednet/resolve.go new file mode 100644 index 000000000..8c95b61fa --- /dev/null +++ b/internal/restrictednet/resolve.go @@ -0,0 +1,177 @@ +package restrictednet + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net/http" + "net/netip" + "net/url" + + "github.com/miekg/dns" +) + +// ResolveName resolves the given host name to IP addresses using DoH servers, +// while opening temporary restrictive firewall rules for HTTPS traffic to DoH servers. +// The host must be a single well-formed domain name, without port or path. +func (c *Client) ResolveName(ctx context.Context, host string) ( + resolvedAddresses []netip.Addr, err error, +) { + questionTypes := make([]uint16, 0, 2) + if c.ipv6Supported { + questionTypes = append(questionTypes, dns.TypeAAAA) + } + questionTypes = append(questionTypes, dns.TypeA) + + var addresses []netip.Addr + errs := make([]error, 0, len(questionTypes)) + for _, questionType := range questionTypes { + answerAddresses, err := c.resolveOneQuestionType(ctx, host, questionType) + if err != nil { + errs = append(errs, err) + continue + } + addresses = append(addresses, answerAddresses...) + } + + switch { + case len(addresses) > 0: + return addresses, nil + case len(errs) == 0: + return nil, nil // no address found + default: // errors + return nil, fmt.Errorf("resolving host %q: %w", host, errors.Join(errs...)) + } +} + +func (c *Client) resolveOneQuestionType(ctx context.Context, + host string, questionType uint16, +) (addresses []netip.Addr, err error) { + queryMessage := &dns.Msg{} + queryMessage.SetQuestion(dns.Fqdn(host), questionType) + queryWire, err := queryMessage.Pack() + if err != nil { + return nil, fmt.Errorf("packing DNS query: %w", err) + } + + // Try every DoH server and every of each of their IP until we get a non-empty + // successful response. + errs := make([]error, 0) + for _, dohServer := range c.dohServers { + dohURL, err := url.Parse(dohServer.URL) + if err != nil { + errs = append(errs, + fmt.Errorf("parsing DoH server URL %s: %w", dohServer.URL, err)) + continue + } + + dohServerIPs := make([]netip.Addr, 0, len(dohServer.IPv4)+len(dohServer.IPv6)) + if c.ipv6Supported { + // Prefer IPv6 addresses if IPv6 is supported + dohServerIPs = append(dohServerIPs, dohServer.IPv6...) + } + dohServerIPs = append(dohServerIPs, dohServer.IPv4...) + + for _, dohServerIP := range dohServerIPs { + responseMessage, err := c.doHQuery(ctx, queryWire, dohURL, dohServerIP) + switch { + case err != nil: + errs = append(errs, fmt.Errorf("querying DoH server %q at %s: %w", + dohServer.URL, dohServerIP, err)) + continue + case responseMessage.Rcode != dns.RcodeSuccess: + errs = append(errs, fmt.Errorf("querying DoH server %q at %s: DNS rcode %s", + dohServer.URL, dohServerIP, dns.RcodeToString[responseMessage.Rcode])) + continue + } + addresses := answersToNetipAddrs(responseMessage) + if len(addresses) == 0 { + continue + } + return addresses, nil + } + } + + if len(errs) == 0 { + return nil, nil + } + + return nil, fmt.Errorf("resolving %s %s: %w", + dns.TypeToString[questionType], host, errors.Join(errs...)) +} + +func (c *Client) doHQuery(ctx context.Context, queryWire []byte, + dohURL *url.URL, dohServerIP netip.Addr, +) (responseMessage *dns.Msg, err error) { + httpClient, close, err := c.OpenHTTPS(dohURL.Hostname(), dohServerIP) + if err != nil { + return nil, fmt.Errorf("opening https connection: %w", err) + } + defer func() { + closeErr := close() + if err == nil && closeErr != nil { + err = fmt.Errorf("cleaning up https connection: %w", closeErr) + } + }() + + requestBody := bytes.NewReader(queryWire) + request, err := http.NewRequestWithContext(ctx, http.MethodPost, dohURL.String(), requestBody) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + request.Header.Set("Content-Type", "application/dns-message") + request.Header.Set("Accept", "application/dns-message") + + response, err := httpClient.Do(request) + if err != nil { + return nil, err + } + + responseData, err := io.ReadAll(response.Body) + if err != nil { + _ = response.Body.Close() + return nil, fmt.Errorf("reading response body: %w", err) + } + + err = response.Body.Close() + if err != nil { + return nil, fmt.Errorf("closing response body: %w", err) + } + + if response.StatusCode != http.StatusOK { + return nil, fmt.Errorf("response status code is %s, data: %s", + response.Status, responseData) + } + + responseMessage = new(dns.Msg) + err = responseMessage.Unpack(responseData) + if err != nil { + return nil, fmt.Errorf("parsing DoH response: %w", err) + } + + return responseMessage, nil +} + +func answersToNetipAddrs(message *dns.Msg) (addresses []netip.Addr) { + if message == nil { + return nil + } + addresses = make([]netip.Addr, 0, len(message.Answer)) + for _, answer := range message.Answer { + switch record := answer.(type) { + case *dns.A: + address, ok := netip.AddrFromSlice(record.A) + if ok { + addresses = append(addresses, address.Unmap()) + } + case *dns.AAAA: + address, ok := netip.AddrFromSlice(record.AAAA) + if ok { + addresses = append(addresses, address) + } + } + } + return addresses +} diff --git a/internal/restrictednet/resolve_test.go b/internal/restrictednet/resolve_test.go new file mode 100644 index 000000000..a0e50b42b --- /dev/null +++ b/internal/restrictednet/resolve_test.go @@ -0,0 +1,82 @@ +package restrictednet + +import ( + "net" + "net/netip" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" +) + +func Test_answersToNetipAddrs(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + message *dns.Msg + expected []netip.Addr + errorIsNil bool + }{ + "nil_message": { + message: nil, + expected: nil, + errorIsNil: true, + }, + "no_answers": { + message: &dns.Msg{}, + expected: []netip.Addr{}, + errorIsNil: true, + }, + "a_record": { + message: &dns.Msg{ + Answer: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, + A: net.IP{1, 1, 1, 1}, + }, + }, + }, + expected: []netip.Addr{netip.MustParseAddr("1.1.1.1")}, + errorIsNil: true, + }, + "aaaa_record": { + message: &dns.Msg{ + Answer: []dns.RR{ + &dns.AAAA{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET}, + AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88}, + }, + }, + }, + expected: []netip.Addr{netip.MustParseAddr("2001:4860:4860::8888")}, + errorIsNil: true, + }, + "mixed_records": { + message: &dns.Msg{ + Answer: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, + A: net.IP{1, 1, 1, 1}, + }, + &dns.AAAA{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET}, + AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88}, + }, + }, + }, + expected: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("2001:4860:4860::8888")}, + errorIsNil: true, + }, + } + + for testName, testCase := range testCases { + testCase := testCase + t.Run(testName, func(t *testing.T) { + t.Parallel() + + addresses := answersToNetipAddrs(testCase.message) + + assert.Equal(t, testCase.expected, addresses) + }) + } +} From fad8c9889a9ea2a8eeb473c3b0712a1d2cc216f8 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Fri, 5 Jun 2026 04:21:53 +0000 Subject: [PATCH 02/18] Minor fixes --- internal/firewall/iptables/iptables.go | 6 +++--- internal/restrictednet/resolve.go | 7 ++++--- internal/restrictednet/resolve_test.go | 1 - 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/internal/firewall/iptables/iptables.go b/internal/firewall/iptables/iptables.go index b96b4f1d8..44d9ab1f5 100644 --- a/internal/firewall/iptables/iptables.go +++ b/internal/firewall/iptables/iptables.go @@ -189,9 +189,9 @@ func (c *Config) AcceptOutputFromIPPortToIPPort(ctx context.Context, interfaceFlag = "" } - instruction := fmt.Sprintf("%s OUTPUT -s %s --sport %d -d %s %s -p %s -m %s --dport %d -j ACCEPT", - appendOrDelete(remove), source.Addr(), source.Port(), destination.Addr(), - interfaceFlag, protocol, protocol, destination.Port()) + instruction := fmt.Sprintf("%s OUTPUT %s -s %s -d %s -p %s -m %s --sport %d --dport %d -j ACCEPT", + appendOrDelete(remove), interfaceFlag, source.Addr(), destination.Addr(), + protocol, protocol, source.Port(), destination.Port()) if destination.Addr().Is4() { return c.runIptablesInstruction(ctx, instruction) } else if c.ip6Tables == "" { diff --git a/internal/restrictednet/resolve.go b/internal/restrictednet/resolve.go index 8c95b61fa..b5b789c7f 100644 --- a/internal/restrictednet/resolve.go +++ b/internal/restrictednet/resolve.go @@ -19,7 +19,8 @@ import ( func (c *Client) ResolveName(ctx context.Context, host string) ( resolvedAddresses []netip.Addr, err error, ) { - questionTypes := make([]uint16, 0, 2) + const maxTypes = 2 + questionTypes := make([]uint16, 0, maxTypes) if c.ipv6Supported { questionTypes = append(questionTypes, dns.TypeAAAA) } @@ -105,12 +106,12 @@ func (c *Client) resolveOneQuestionType(ctx context.Context, func (c *Client) doHQuery(ctx context.Context, queryWire []byte, dohURL *url.URL, dohServerIP netip.Addr, ) (responseMessage *dns.Msg, err error) { - httpClient, close, err := c.OpenHTTPS(dohURL.Hostname(), dohServerIP) + httpClient, cleanup, err := c.OpenHTTPS(dohURL.Hostname(), dohServerIP) if err != nil { return nil, fmt.Errorf("opening https connection: %w", err) } defer func() { - closeErr := close() + closeErr := cleanup() if err == nil && closeErr != nil { err = fmt.Errorf("cleaning up https connection: %w", closeErr) } diff --git a/internal/restrictednet/resolve_test.go b/internal/restrictednet/resolve_test.go index a0e50b42b..51762778b 100644 --- a/internal/restrictednet/resolve_test.go +++ b/internal/restrictednet/resolve_test.go @@ -70,7 +70,6 @@ func Test_answersToNetipAddrs(t *testing.T) { } for testName, testCase := range testCases { - testCase := testCase t.Run(testName, func(t *testing.T) { t.Parallel() From a9a36644ecdf3c2ffef0e1a31a1559e9a9b1c41e Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Fri, 5 Jun 2026 04:46:16 +0000 Subject: [PATCH 03/18] imporatnt fix 1 --- internal/restrictednet/client.go | 2 +- internal/restrictednet/client_test.go | 24 +++++++++++--- internal/restrictednet/https.go | 47 ++++++++++++--------------- internal/restrictednet/resolve.go | 2 +- 4 files changed, 41 insertions(+), 34 deletions(-) diff --git a/internal/restrictednet/client.go b/internal/restrictednet/client.go index d8812a3fa..292f3e3d0 100644 --- a/internal/restrictednet/client.go +++ b/internal/restrictednet/client.go @@ -47,7 +47,7 @@ func (c *Client) OpenHTTPSByDomain(ctx context.Context, domain string) ( selectedIP := resolvedIPs[0] - httpClient, cleanup, err = c.OpenHTTPS(domain, selectedIP) + httpClient, cleanup, err = c.OpenHTTPS(ctx, domain, selectedIP) if err != nil { return nil, nil, fmt.Errorf("opening HTTPS: %w", err) } diff --git a/internal/restrictednet/client_test.go b/internal/restrictednet/client_test.go index b3f5ba8d5..ff10e822f 100644 --- a/internal/restrictednet/client_test.go +++ b/internal/restrictednet/client_test.go @@ -2,6 +2,7 @@ package restrictednet import ( "context" + "net" "net/netip" "testing" @@ -34,15 +35,28 @@ func (m listenAddrPortMatcher) String() string { func Test_Client_OpenHTTPS(t *testing.T) { t.Parallel() + ctx := t.Context() + + netConfig := net.ListenConfig{} + listener, err := netConfig.Listen(ctx, "tcp", "127.0.0.1:443") + require.NoError(t, err) + t.Cleanup(func() { + _ = listener.Close() + }) + go func() { + connection, acceptErr := listener.Accept() + if acceptErr == nil { + _ = connection.Close() + } + }() ctrl := gomock.NewController(t) firewall := NewMockFirewall(ctrl) - destination := netip.MustParseAddrPort("1.2.3.4:443") - backgroundContext := context.Background() + destination := netip.MustParseAddrPort("127.0.0.1:443") sourceMatcher := listenAddrPortMatcher{} firewall.EXPECT().AcceptOutputFromIPPortToIPPort( - backgroundContext, "tcp", "eth0", sourceMatcher, destination, false, + ctx, "tcp", "eth0", sourceMatcher, destination, false, ).DoAndReturn(func(_ context.Context, _, _ string, source, _ netip.AddrPort, _ bool, ) error { @@ -50,7 +64,7 @@ func Test_Client_OpenHTTPS(t *testing.T) { return nil }) firewall.EXPECT().AcceptOutputFromIPPortToIPPort( - backgroundContext, "tcp", "eth0", sourceMatcher, destination, true, + ctx, "tcp", "eth0", sourceMatcher, destination, true, ) const ipv6Supported = false @@ -58,7 +72,7 @@ func Test_Client_OpenHTTPS(t *testing.T) { client, err := New(firewall, "eth0", ipv6Supported, upstreamResolvers) require.NoError(t, err) - httpClient, cleanup, err := client.OpenHTTPS("api.example.com", netip.MustParseAddr("1.2.3.4")) + httpClient, cleanup, err := client.OpenHTTPS(ctx, "api.example.com", netip.MustParseAddr("127.0.0.1")) require.NoError(t, err) require.NotNil(t, httpClient) require.NotNil(t, cleanup) diff --git a/internal/restrictednet/https.go b/internal/restrictednet/https.go index 462f69c25..767d95e2c 100644 --- a/internal/restrictednet/https.go +++ b/internal/restrictednet/https.go @@ -13,9 +13,9 @@ import ( // OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination. // The returned cleanup function must be called to remove the temporary firewall rule and close connections. -func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr, +func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, destinationIP netip.Addr, ) (httpClient *http.Client, cleanup func() error, err error) { - listener, sourceAddrPort, err := bindSourcePort(destinationIP) + connection, sourceAddrPort, err := bindSourceConnection(ctx, destinationIP) if err != nil { return nil, nil, fmt.Errorf("binding source port: %w", err) } @@ -24,15 +24,14 @@ func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr, destinationAddrPort := netip.AddrPortFrom(destinationIP, httpsPort) const remove = false - ctx := context.Background() // it's a quick firewall change, worth not passing a context err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface, sourceAddrPort, destinationAddrPort, remove) if err != nil { - _ = listener.Close() + _ = connection.Close() return nil, nil, fmt.Errorf("allowing output traffic through firewall: %w", err) } - httpClient = newHTTPSClient(destinationTLSName, destinationIP, sourceAddrPort) + httpClient = newHTTPSClient(destinationTLSName, connection) cleanup = func() error { var errs []error httpClient.CloseIdleConnections() @@ -42,9 +41,9 @@ func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr, if err != nil { errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err)) } - err = listener.Close() + err = connection.Close() if err != nil { - errs = append(errs, fmt.Errorf("closing listener: %w", err)) + errs = append(errs, fmt.Errorf("closing connection: %w", err)) } if len(errs) > 0 { return errors.Join(errs...) @@ -55,7 +54,7 @@ func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr, } func newHTTPSClient(destinationTLSName string, - destinationIP netip.Addr, sourceAddress netip.AddrPort, + connection net.Conn, ) *http.Client { httpTransport := http.DefaultTransport.(*http.Transport).Clone() //nolint:forcetypeassert httpTransport.Proxy = nil @@ -66,7 +65,7 @@ func newHTTPSClient(destinationTLSName string, MinVersion: tls.VersionTLS12, ServerName: destinationTLSName, } - httpTransport.DialContext = newBoundDialContext(destinationIP, sourceAddress) + httpTransport.DialContext = newConnectionDialContext(connection) const timeout = 5 * time.Second return &http.Client{ @@ -75,25 +74,14 @@ func newHTTPSClient(destinationTLSName string, } } -func newBoundDialContext(destinationAddress netip.Addr, - sourceAddress netip.AddrPort, -) func(ctx context.Context, network, _ string) (net.Conn, error) { - const httpsPort = 443 - destinationAddrPort := netip.AddrPortFrom(destinationAddress, httpsPort).String() +func newConnectionDialContext(connection net.Conn) func(ctx context.Context, network, _ string) (net.Conn, error) { return func(ctx context.Context, network, _ string) (net.Conn, error) { - const timeout = 2 * time.Second - dialer := &net.Dialer{Timeout: timeout} - dialer.LocalAddr = net.TCPAddrFromAddrPort(sourceAddress) - connection, err := dialer.DialContext(ctx, network, destinationAddrPort) - if err != nil { - return nil, fmt.Errorf("%s dialing %s: %w", network, destinationAddrPort, err) - } return connection, nil } } -func bindSourcePort(destinationIP netip.Addr) ( - listener net.Listener, sourceAddr netip.AddrPort, err error, +func bindSourceConnection(ctx context.Context, destinationIP netip.Addr) ( + connection net.Conn, sourceAddr netip.AddrPort, err error, ) { var bindAddr netip.Addr if destinationIP.Is4() { @@ -102,14 +90,19 @@ func bindSourcePort(destinationIP netip.Addr) ( bindAddr = netip.AddrFrom16([16]byte{}) } - listener, err = net.ListenTCP("tcp", net.TCPAddrFromAddrPort( - netip.AddrPortFrom(bindAddr, 0))) + const httpsPort = 443 + destinationAddrPort := netip.AddrPortFrom(destinationIP, httpsPort) + dialer := &net.Dialer{ + Timeout: time.Second, + LocalAddr: net.TCPAddrFromAddrPort(netip.AddrPortFrom(bindAddr, 0)), + } + connection, err = dialer.DialContext(ctx, "tcp", destinationAddrPort.String()) if err != nil { return nil, netip.AddrPort{}, fmt.Errorf("binding TCP port: %w", err) } - tcpAddr := listener.Addr().(*net.TCPAddr) //nolint:forcetypeassert + tcpAddr := connection.LocalAddr().(*net.TCPAddr) //nolint:forcetypeassert sourceAddr = tcpAddr.AddrPort() - return listener, sourceAddr, nil + return connection, sourceAddr, nil } diff --git a/internal/restrictednet/resolve.go b/internal/restrictednet/resolve.go index b5b789c7f..e14e5c9b6 100644 --- a/internal/restrictednet/resolve.go +++ b/internal/restrictednet/resolve.go @@ -106,7 +106,7 @@ func (c *Client) resolveOneQuestionType(ctx context.Context, func (c *Client) doHQuery(ctx context.Context, queryWire []byte, dohURL *url.URL, dohServerIP netip.Addr, ) (responseMessage *dns.Msg, err error) { - httpClient, cleanup, err := c.OpenHTTPS(dohURL.Hostname(), dohServerIP) + httpClient, cleanup, err := c.OpenHTTPS(ctx, dohURL.Hostname(), dohServerIP) if err != nil { return nil, fmt.Errorf("opening https connection: %w", err) } From 820689cc238b48062fde13827d2ec4c2cf06d7aa Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Fri, 5 Jun 2026 04:46:20 +0000 Subject: [PATCH 04/18] imporatnt fix 2 --- internal/restrictednet/https.go | 123 +++++++++++++++++++++++------- internal/restrictednet/unix.go | 64 ++++++++++++++++ internal/restrictednet/windows.go | 27 +++++++ 3 files changed, 187 insertions(+), 27 deletions(-) create mode 100644 internal/restrictednet/unix.go create mode 100644 internal/restrictednet/windows.go diff --git a/internal/restrictednet/https.go b/internal/restrictednet/https.go index 767d95e2c..9444ab7ad 100644 --- a/internal/restrictednet/https.go +++ b/internal/restrictednet/https.go @@ -8,14 +8,18 @@ import ( "net" "net/http" "net/netip" + "os" "time" + + "github.com/jsimonetti/rtnetlink" + "github.com/qdm12/gluetun/internal/pmtud/constants" ) // OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination. // The returned cleanup function must be called to remove the temporary firewall rule and close connections. func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, destinationIP netip.Addr, ) (httpClient *http.Client, cleanup func() error, err error) { - connection, sourceAddrPort, err := bindSourceConnection(ctx, destinationIP) + fd, sourceAddrPort, err := bindSourceConnection(destinationIP) if err != nil { return nil, nil, fmt.Errorf("binding source port: %w", err) } @@ -27,10 +31,18 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface, sourceAddrPort, destinationAddrPort, remove) if err != nil { - _ = connection.Close() + closeFD(fd) return nil, nil, fmt.Errorf("allowing output traffic through firewall: %w", err) } + connection, err := connectSourceConnection(fd, destinationAddrPort) + if err != nil { + const remove = true + _ = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface, + sourceAddrPort, destinationAddrPort, remove) + return nil, nil, fmt.Errorf("connecting source socket: %w", err) + } + httpClient = newHTTPSClient(destinationTLSName, connection) cleanup = func() error { var errs []error @@ -53,9 +65,7 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti return httpClient, cleanup, nil } -func newHTTPSClient(destinationTLSName string, - connection net.Conn, -) *http.Client { +func newHTTPSClient(destinationTLSName string, connection net.Conn) *http.Client { httpTransport := http.DefaultTransport.(*http.Transport).Clone() //nolint:forcetypeassert httpTransport.Proxy = nil httpTransport.MaxIdleConns = 1 @@ -65,7 +75,9 @@ func newHTTPSClient(destinationTLSName string, MinVersion: tls.VersionTLS12, ServerName: destinationTLSName, } - httpTransport.DialContext = newConnectionDialContext(connection) + httpTransport.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) { + return connection, nil + } const timeout = 5 * time.Second return &http.Client{ @@ -74,35 +86,92 @@ func newHTTPSClient(destinationTLSName string, } } -func newConnectionDialContext(connection net.Conn) func(ctx context.Context, network, _ string) (net.Conn, error) { - return func(ctx context.Context, network, _ string) (net.Conn, error) { - return connection, nil +func bindSourceConnection(destinationIP netip.Addr) (fd int, sourceAddr netip.AddrPort, err error) { + sourceIP, err := sourceIPForDestination(destinationIP) + if err != nil { + return 0, netip.AddrPort{}, fmt.Errorf("finding source IP: %w", err) + } + + family := constants.AF_INET + if sourceIP.Is6() { + family = constants.AF_INET6 + } + + fd, err = newTCPSockStream(family) + if err != nil { + return 0, netip.AddrPort{}, fmt.Errorf("creating socket: %w", err) + } + + bindAddrPort := netip.AddrPortFrom(sourceIP, 0) + err = bindFD(fd, bindAddrPort) + if err != nil { + closeFD(fd) + return 0, netip.AddrPort{}, fmt.Errorf("binding socket: %w", err) + } + + sourceAddr, err = fdToSourceAddr(fd) + if err != nil { + closeFD(fd) + return 0, netip.AddrPort{}, fmt.Errorf("getting source address: %w", err) } + + return fd, sourceAddr, nil } -func bindSourceConnection(ctx context.Context, destinationIP netip.Addr) ( - connection net.Conn, sourceAddr netip.AddrPort, err error, -) { - var bindAddr netip.Addr - if destinationIP.Is4() { - bindAddr = netip.AddrFrom4([4]byte{}) - } else { - bindAddr = netip.AddrFrom16([16]byte{}) +func connectSourceConnection(fd int, destinationAddrPort netip.AddrPort) (connection net.Conn, err error) { + err = connectFD(fd, destinationAddrPort) + if err != nil { + closeFD(fd) + return nil, fmt.Errorf("connecting socket: %w", err) } - const httpsPort = 443 - destinationAddrPort := netip.AddrPortFrom(destinationIP, httpsPort) - dialer := &net.Dialer{ - Timeout: time.Second, - LocalAddr: net.TCPAddrFromAddrPort(netip.AddrPortFrom(bindAddr, 0)), + file := os.NewFile(uintptr(fd), "") + if file == nil { + closeFD(fd) + return nil, fmt.Errorf("creating socket file") + } + defer file.Close() + + connection, err = net.FileConn(file) + if err != nil { + return nil, fmt.Errorf("wrapping socket connection: %w", err) + } + + return connection, nil +} + +func sourceIPForDestination(destinationIP netip.Addr) (srcIP netip.Addr, err error) { + conn, err := rtnetlink.Dial(nil) + if err != nil { + return netip.Addr{}, err + } + defer conn.Close() + + family := uint8(constants.AF_INET) + if destinationIP.Is6() { + family = constants.AF_INET6 } - connection, err = dialer.DialContext(ctx, "tcp", destinationAddrPort.String()) + + requestMessage := &rtnetlink.RouteMessage{ + Family: family, + Attributes: rtnetlink.RouteAttributes{ + Dst: destinationIP.AsSlice(), + }, + } + messages, err := conn.Route.Get(requestMessage) if err != nil { - return nil, netip.AddrPort{}, fmt.Errorf("binding TCP port: %w", err) + return netip.Addr{}, fmt.Errorf("getting routes to %s: %w", destinationIP, err) } - tcpAddr := connection.LocalAddr().(*net.TCPAddr) //nolint:forcetypeassert - sourceAddr = tcpAddr.AddrPort() + for _, message := range messages { + if message.Attributes.Src == nil { + continue + } + if message.Attributes.Src.To4() == nil { + return netip.AddrFrom16([16]byte(message.Attributes.Src)), nil + } + return netip.AddrFrom4([4]byte(message.Attributes.Src)), nil + } - return connection, sourceAddr, nil + return netip.Addr{}, fmt.Errorf("no route to %s", destinationIP) } diff --git a/internal/restrictednet/unix.go b/internal/restrictednet/unix.go new file mode 100644 index 000000000..76895943e --- /dev/null +++ b/internal/restrictednet/unix.go @@ -0,0 +1,64 @@ +//go:build unix + +package restrictednet + +import ( + "fmt" + "net/netip" + + "golang.org/x/sys/unix" +) + +func closeFD(fd int) { + unix.Close(fd) +} + +func newTCPSockStream(family int) (fd int, err error) { + return unix.Socket(family, unix.SOCK_STREAM, unix.IPPROTO_TCP) +} + +func bindFD(fd int, address netip.AddrPort) error { + bindAddr := makeSockAddr(address) + return unix.Bind(fd, bindAddr) +} + +func connectFD(fd int, destination netip.AddrPort) error { + return unix.Connect(fd, makeSockAddr(destination)) +} + +func fdToSourceAddr(fd int) (sourceAddrPort netip.AddrPort, err error) { + sockAddr, err := unix.Getsockname(fd) + if err != nil { + return netip.AddrPort{}, fmt.Errorf("getting sockname: %w", err) + } + + sourceAddrPort, err = sockAddrToAddrPort(sockAddr) + if err != nil { + return netip.AddrPort{}, err + } + return sourceAddrPort, nil +} + +func makeSockAddr(addressPort netip.AddrPort) unix.Sockaddr { + if addressPort.Addr().Is4() { + return &unix.SockaddrInet4{ + Port: int(addressPort.Port()), + Addr: addressPort.Addr().As4(), + } + } + return &unix.SockaddrInet6{ + Port: int(addressPort.Port()), + Addr: addressPort.Addr().As16(), + } +} + +func sockAddrToAddrPort(sockAddr unix.Sockaddr) (addrPort netip.AddrPort, err error) { + switch typedSockAddr := sockAddr.(type) { + case *unix.SockaddrInet4: + return netip.AddrPortFrom(netip.AddrFrom4(typedSockAddr.Addr), uint16(typedSockAddr.Port)), nil //nolint:gosec + case *unix.SockaddrInet6: + return netip.AddrPortFrom(netip.AddrFrom16(typedSockAddr.Addr), uint16(typedSockAddr.Port)), nil //nolint:gosec + default: + return netip.AddrPort{}, fmt.Errorf("unexpected socket address type %T", typedSockAddr) + } +} diff --git a/internal/restrictednet/windows.go b/internal/restrictednet/windows.go new file mode 100644 index 000000000..e1b88453a --- /dev/null +++ b/internal/restrictednet/windows.go @@ -0,0 +1,27 @@ +//go:build windows + +package restrictednet + +import ( + "net/netip" +) + +func closeFD(fd int) { + panic("not implemented") +} + +func newTCPSockStream(family int) (fd int, err error) { + panic("not implemented") +} + +func bindFD(fd int, address netip.AddrPort) error { + panic("not implemented") +} + +func connectFD(fd int, destination netip.AddrPort) error { + panic("not implemented") +} + +func fdToSourceAddr(fd int) (sourceAddrPort netip.AddrPort, err error) { + panic("not implemented") +} From c18c54c3b7a0dd9ce1b7b4d02828648b98f83fa5 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Fri, 5 Jun 2026 04:58:47 +0000 Subject: [PATCH 05/18] Fix test to use a random port and not 443 --- internal/restrictednet/client.go | 3 +++ internal/restrictednet/client_test.go | 6 ++++-- internal/restrictednet/https.go | 3 +-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/internal/restrictednet/client.go b/internal/restrictednet/client.go index 292f3e3d0..9e20b9394 100644 --- a/internal/restrictednet/client.go +++ b/internal/restrictednet/client.go @@ -17,6 +17,7 @@ type Client struct { firewall Firewall outboundInterface string dohServers []provider.DoHServer + httpsPort uint16 } func New(firewall Firewall, defaultInterface string, ipv6Supported bool, @@ -27,11 +28,13 @@ func New(firewall Firewall, defaultInterface string, ipv6Supported bool, dohServers[i] = upstreamResolver.DoH } + const defaultHTTPSPort = 443 return &Client{ firewall: firewall, outboundInterface: defaultInterface, ipv6Supported: ipv6Supported, dohServers: dohServers, + httpsPort: defaultHTTPSPort, }, nil } diff --git a/internal/restrictednet/client_test.go b/internal/restrictednet/client_test.go index ff10e822f..65504f62f 100644 --- a/internal/restrictednet/client_test.go +++ b/internal/restrictednet/client_test.go @@ -38,11 +38,12 @@ func Test_Client_OpenHTTPS(t *testing.T) { ctx := t.Context() netConfig := net.ListenConfig{} - listener, err := netConfig.Listen(ctx, "tcp", "127.0.0.1:443") + listener, err := netConfig.Listen(ctx, "tcp", "127.0.0.1:0") require.NoError(t, err) t.Cleanup(func() { _ = listener.Close() }) + listeningPort := uint16(listener.Addr().(*net.TCPAddr).Port) //nolint:gosec,forcetypeassert go func() { connection, acceptErr := listener.Accept() if acceptErr == nil { @@ -53,7 +54,7 @@ func Test_Client_OpenHTTPS(t *testing.T) { ctrl := gomock.NewController(t) firewall := NewMockFirewall(ctrl) - destination := netip.MustParseAddrPort("127.0.0.1:443") + destination := netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), listeningPort) sourceMatcher := listenAddrPortMatcher{} firewall.EXPECT().AcceptOutputFromIPPortToIPPort( ctx, "tcp", "eth0", sourceMatcher, destination, false, @@ -71,6 +72,7 @@ func Test_Client_OpenHTTPS(t *testing.T) { upstreamResolvers := []provider.Provider{provider.Google()} client, err := New(firewall, "eth0", ipv6Supported, upstreamResolvers) require.NoError(t, err) + client.httpsPort = listeningPort httpClient, cleanup, err := client.OpenHTTPS(ctx, "api.example.com", netip.MustParseAddr("127.0.0.1")) require.NoError(t, err) diff --git a/internal/restrictednet/https.go b/internal/restrictednet/https.go index 9444ab7ad..02863455c 100644 --- a/internal/restrictednet/https.go +++ b/internal/restrictednet/https.go @@ -24,8 +24,7 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti return nil, nil, fmt.Errorf("binding source port: %w", err) } - const httpsPort = 443 - destinationAddrPort := netip.AddrPortFrom(destinationIP, httpsPort) + destinationAddrPort := netip.AddrPortFrom(destinationIP, c.httpsPort) const remove = false err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface, From b48ba8cb0abb29556d635fb2b427b3934da30c98 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Fri, 5 Jun 2026 05:01:18 +0000 Subject: [PATCH 06/18] review feedback --- internal/firewall/iptables/iptables.go | 3 ++- internal/restrictednet/https.go | 13 ++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/internal/firewall/iptables/iptables.go b/internal/firewall/iptables/iptables.go index 44d9ab1f5..a6b40cf72 100644 --- a/internal/firewall/iptables/iptables.go +++ b/internal/firewall/iptables/iptables.go @@ -2,6 +2,7 @@ package iptables import ( "context" + "errors" "fmt" "io" "net/netip" @@ -181,7 +182,7 @@ func (c *Config) AcceptOutputFromIPPortToIPPort(ctx context.Context, protocol, intf string, source, destination netip.AddrPort, remove bool, ) error { if source.Addr().BitLen() != destination.Addr().BitLen() { - return fmt.Errorf("source and destination address families do not match") + return errors.New("source and destination address families do not match") } interfaceFlag := "-o " + intf diff --git a/internal/restrictednet/https.go b/internal/restrictednet/https.go index 02863455c..f3b71a43a 100644 --- a/internal/restrictednet/https.go +++ b/internal/restrictednet/https.go @@ -69,12 +69,23 @@ func newHTTPSClient(destinationTLSName string, connection net.Conn) *http.Client httpTransport.Proxy = nil httpTransport.MaxIdleConns = 1 httpTransport.MaxIdleConnsPerHost = 1 + httpTransport.MaxConnsPerHost = 1 httpTransport.IdleConnTimeout = time.Second httpTransport.TLSClientConfig = &tls.Config{ MinVersion: tls.VersionTLS12, ServerName: destinationTLSName, } - httpTransport.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) { + + expectedAddress := net.JoinHostPort(destinationTLSName, "443") + httpTransport.DialContext = func(_ context.Context, network, address string) (net.Conn, error) { + switch network { + case "tcp", "tcp4", "tcp6": + default: + return nil, fmt.Errorf("unexpected dial network %q", network) + } + if address != expectedAddress { + return nil, fmt.Errorf("unexpected dial address %q (expected %q)", address, expectedAddress) + } return connection, nil } From 2d2c3713032c560a1d675b56e0bdcc22fb93633d Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Fri, 5 Jun 2026 15:25:44 +0000 Subject: [PATCH 07/18] pr review fixes --- internal/restrictednet/client.go | 16 ++++++++++------ internal/restrictednet/https.go | 5 +++-- .../{client_test.go => https_test.go} | 2 +- internal/restrictednet/resolve.go | 4 ++-- 4 files changed, 16 insertions(+), 11 deletions(-) rename internal/restrictednet/{client_test.go => https_test.go} (96%) diff --git a/internal/restrictednet/client.go b/internal/restrictednet/client.go index 9e20b9394..cdcd9472c 100644 --- a/internal/restrictednet/client.go +++ b/internal/restrictednet/client.go @@ -2,6 +2,7 @@ package restrictednet import ( "context" + "errors" "fmt" "net/http" @@ -48,12 +49,15 @@ func (c *Client) OpenHTTPSByDomain(ctx context.Context, domain string) ( return nil, nil, fmt.Errorf("no IP address found for name %q", domain) } - selectedIP := resolvedIPs[0] - - httpClient, cleanup, err = c.OpenHTTPS(ctx, domain, selectedIP) - if err != nil { - return nil, nil, fmt.Errorf("opening HTTPS: %w", err) + errs := make([]error, 0, len(resolvedIPs)) + for _, ip := range resolvedIPs { + httpClient, cleanup, err := c.OpenHTTPS(ctx, domain, ip) + if err != nil { + errs = append(errs, fmt.Errorf("for %s: %w", ip, err)) + continue + } + return httpClient, cleanup, nil } - return httpClient, cleanup, nil + return nil, nil, fmt.Errorf("opening HTTPS to %s: %w", domain, errors.Join(errs...)) } diff --git a/internal/restrictednet/https.go b/internal/restrictednet/https.go index f3b71a43a..1bb5bb485 100644 --- a/internal/restrictednet/https.go +++ b/internal/restrictednet/https.go @@ -47,7 +47,7 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti var errs []error httpClient.CloseIdleConnections() const remove = true - err := c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface, + err := c.firewall.AcceptOutputFromIPPortToIPPort(context.Background(), "tcp", c.outboundInterface, sourceAddrPort, destinationAddrPort, remove) if err != nil { errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err)) @@ -76,7 +76,8 @@ func newHTTPSClient(destinationTLSName string, connection net.Conn) *http.Client ServerName: destinationTLSName, } - expectedAddress := net.JoinHostPort(destinationTLSName, "443") + _, destinationPort, _ := net.SplitHostPort(connection.RemoteAddr().String()) + expectedAddress := net.JoinHostPort(destinationTLSName, destinationPort) httpTransport.DialContext = func(_ context.Context, network, address string) (net.Conn, error) { switch network { case "tcp", "tcp4", "tcp6": diff --git a/internal/restrictednet/client_test.go b/internal/restrictednet/https_test.go similarity index 96% rename from internal/restrictednet/client_test.go rename to internal/restrictednet/https_test.go index 65504f62f..7db81e600 100644 --- a/internal/restrictednet/client_test.go +++ b/internal/restrictednet/https_test.go @@ -65,7 +65,7 @@ func Test_Client_OpenHTTPS(t *testing.T) { return nil }) firewall.EXPECT().AcceptOutputFromIPPortToIPPort( - ctx, "tcp", "eth0", sourceMatcher, destination, true, + context.Background(), "tcp", "eth0", sourceMatcher, destination, true, ) const ipv6Supported = false diff --git a/internal/restrictednet/resolve.go b/internal/restrictednet/resolve.go index e14e5c9b6..aa1c3e647 100644 --- a/internal/restrictednet/resolve.go +++ b/internal/restrictednet/resolve.go @@ -142,8 +142,8 @@ func (c *Client) doHQuery(ctx context.Context, queryWire []byte, } if response.StatusCode != http.StatusOK { - return nil, fmt.Errorf("response status code is %s, data: %s", - response.Status, responseData) + return nil, fmt.Errorf("response status code is %s (data length %d)", + response.Status, len(responseData)) } responseMessage = new(dns.Msg) From 8da913d7c6ae9faded223bc7606d17327b830a62 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Fri, 5 Jun 2026 15:35:28 +0000 Subject: [PATCH 08/18] context aware connectSourceConnection --- internal/restrictednet/https.go | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/internal/restrictednet/https.go b/internal/restrictednet/https.go index 1bb5bb485..209e68f0f 100644 --- a/internal/restrictednet/https.go +++ b/internal/restrictednet/https.go @@ -34,7 +34,7 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti return nil, nil, fmt.Errorf("allowing output traffic through firewall: %w", err) } - connection, err := connectSourceConnection(fd, destinationAddrPort) + connection, err := connectSourceConnection(ctx, fd, destinationAddrPort) if err != nil { const remove = true _ = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface, @@ -129,10 +129,27 @@ func bindSourceConnection(destinationIP netip.Addr) (fd int, sourceAddr netip.Ad return fd, sourceAddr, nil } -func connectSourceConnection(fd int, destinationAddrPort netip.AddrPort) (connection net.Conn, err error) { - err = connectFD(fd, destinationAddrPort) - if err != nil { +func connectSourceConnection(ctx context.Context, fd int, destinationAddrPort netip.AddrPort) ( + connection net.Conn, err error, +) { + errCh := make(chan error) + go func() { + errCh <- connectFD(fd, destinationAddrPort) + }() + + select { + case err = <-errCh: + if err != nil { + closeFD(fd) + return nil, fmt.Errorf("connecting socket: %w", err) + } + case <-ctx.Done(): + err = ctx.Err() closeFD(fd) + connectErr := <-errCh + if connectErr != nil { + err = fmt.Errorf("%w (%w)", connectErr, err) + } return nil, fmt.Errorf("connecting socket: %w", err) } From e2256dd1b2bfe8c042f03470dc5b7fb41ffb926b Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Fri, 5 Jun 2026 15:52:51 +0000 Subject: [PATCH 09/18] moare fixes --- internal/restrictednet/client.go | 7 +++++-- internal/restrictednet/https.go | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/internal/restrictednet/client.go b/internal/restrictednet/client.go index cdcd9472c..9225a96c1 100644 --- a/internal/restrictednet/client.go +++ b/internal/restrictednet/client.go @@ -23,7 +23,10 @@ type Client struct { func New(firewall Firewall, defaultInterface string, ipv6Supported bool, upstreamResolvers []provider.Provider, -) (*Client, error) { +) *Client { + if len(upstreamResolvers) == 0 { + panic("no upstream resolvers provided") // programming error + } dohServers := make([]provider.DoHServer, len(upstreamResolvers)) for i, upstreamResolver := range upstreamResolvers { dohServers[i] = upstreamResolver.DoH @@ -36,7 +39,7 @@ func New(firewall Firewall, defaultInterface string, ipv6Supported bool, ipv6Supported: ipv6Supported, dohServers: dohServers, httpsPort: defaultHTTPSPort, - }, nil + } } func (c *Client) OpenHTTPSByDomain(ctx context.Context, domain string) ( diff --git a/internal/restrictednet/https.go b/internal/restrictednet/https.go index 209e68f0f..06c378cea 100644 --- a/internal/restrictednet/https.go +++ b/internal/restrictednet/https.go @@ -37,7 +37,7 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti connection, err := connectSourceConnection(ctx, fd, destinationAddrPort) if err != nil { const remove = true - _ = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface, + _ = c.firewall.AcceptOutputFromIPPortToIPPort(context.Background(), "tcp", c.outboundInterface, sourceAddrPort, destinationAddrPort, remove) return nil, nil, fmt.Errorf("connecting source socket: %w", err) } From dd07205b85a72373849bedcb5ff186aea063a0b2 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 9 Jun 2026 12:47:13 +0000 Subject: [PATCH 10/18] add tests --- AGENTS.md | 1 + internal/restrictednet/client.go | 23 +- internal/restrictednet/helpers_test.go | 185 ++++++++++++ internal/restrictednet/https.go | 22 +- internal/restrictednet/https_test.go | 9 +- internal/restrictednet/resolve.go | 4 +- internal/restrictednet/resolve_test.go | 399 ++++++++++++++++++++++--- internal/restrictednet/settings.go | 36 +++ 8 files changed, 609 insertions(+), 70 deletions(-) create mode 100644 internal/restrictednet/helpers_test.go create mode 100644 internal/restrictednet/settings.go diff --git a/AGENTS.md b/AGENTS.md index b7d0b3bb0..fb3f7d8ea 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -116,6 +116,7 @@ Mocking works with the `go.uber.org/mock` library, and the `mockgen` tool. - **Never** use `.AnyTimes()` on mocks. Always define the number of times a certain mock call should be called, with `.Times(3)` for example. - **Always** set the `.Return(...)` on the mock if the function returns something. - Avoid using **mock helpers** functions, prefer a bit of repetition than tight coupling and dependency + - Always define the gomock controller `ctrl` in the subtest and not in the parent test, or a subtest mock failing will crash all the other subtests. ### main.go diff --git a/internal/restrictednet/client.go b/internal/restrictednet/client.go index 9225a96c1..7b1547a69 100644 --- a/internal/restrictednet/client.go +++ b/internal/restrictednet/client.go @@ -14,30 +14,31 @@ import ( // It is not meant to be high performance, although it can be used for // multiple requests and concurrently. type Client struct { + outboundInterface string ipv6Supported bool firewall Firewall - outboundInterface string dohServers []provider.DoHServer + baseTransport *http.Transport httpsPort uint16 } -func New(firewall Firewall, defaultInterface string, ipv6Supported bool, - upstreamResolvers []provider.Provider, -) *Client { - if len(upstreamResolvers) == 0 { - panic("no upstream resolvers provided") // programming error +func New(settings Settings) *Client { + settings.setDefaults() + if err := settings.validate(); err != nil { + panic(fmt.Sprintf("invalid settings: %v", err)) // programming error } - dohServers := make([]provider.DoHServer, len(upstreamResolvers)) - for i, upstreamResolver := range upstreamResolvers { + dohServers := make([]provider.DoHServer, len(settings.UpstreamResolvers)) + for i, upstreamResolver := range settings.UpstreamResolvers { dohServers[i] = upstreamResolver.DoH } const defaultHTTPSPort = 443 return &Client{ - firewall: firewall, - outboundInterface: defaultInterface, - ipv6Supported: ipv6Supported, + outboundInterface: settings.DefaultInterface, + ipv6Supported: *settings.IPv6Supported, + firewall: settings.Firewall, dohServers: dohServers, + baseTransport: settings.BaseTransport, httpsPort: defaultHTTPSPort, } } diff --git a/internal/restrictednet/helpers_test.go b/internal/restrictednet/helpers_test.go new file mode 100644 index 000000000..cac3fd380 --- /dev/null +++ b/internal/restrictednet/helpers_test.go @@ -0,0 +1,185 @@ +package restrictednet + +import ( + "bufio" + "bytes" + "context" + "errors" + "io" + "net" + "net/http" + "net/netip" + "net/url" + "os" + "strconv" + "sync" + "syscall" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func ptrTo[T any](value T) *T { + return &value +} + +func newInterceptTransport(handler func(host string, requestBody io.Reader) (*http.Response, error)) *http.Transport { + return &http.Transport{ + DialTLSContext: func(_ context.Context, _, _ string) (net.Conn, error) { + clientConn, serverConn := net.Pipe() + go func() { + defer serverConn.Close() + + reader := bufio.NewReader(serverConn) + request, err := http.ReadRequest(reader) + if err != nil { + return + } + + response, err := handler(request.Host, request.Body) + if err != nil { + return + } + + // Read the response body and re-create it to avoid linting + // complaining that the response body must be closed. + responseData, err := io.ReadAll(response.Body) + if err != nil { + return + } + _ = response.Body.Close() + response.Body = io.NopCloser(bytes.NewReader(responseData)) + + _ = response.Write(serverConn) + }() + return clientConn, nil + }, + } +} + +func expectFirewallCallPair( + firewall *MockFirewall, + addContext context.Context, //nolint:revive + destinationIP netip.Addr, + destinationPort uint16, + addErr error, + removeErr error, +) { + destination := netip.AddrPortFrom(destinationIP, destinationPort) + sourceMatcher := listenAddrPortMatcher{} + + firewall.EXPECT().AcceptOutputFromIPPortToIPPort( + addContext, "tcp", "eth0", sourceMatcher, destination, false, + ).DoAndReturn(func( + _ context.Context, _, _ string, source, _ netip.AddrPort, _ bool, + ) error { + sourceMatcher.expected = source + return addErr + }) + + firewall.EXPECT().AcceptOutputFromIPPortToIPPort( + context.Background(), "tcp", "eth0", sourceMatcher, destination, true, + ).Return(removeErr) +} + +func urlToHostnamePort(rawURL string, port uint16) string { + parsedURL, err := url.Parse(rawURL) + if err != nil { + panic(err) // programming error in test + } + parsedURL.Host = net.JoinHostPort(parsedURL.Hostname(), strconv.FormatUint(uint64(port), 10)) + return parsedURL.String() +} + +func responseWireForQuery(t *testing.T, queryReader io.Reader, answers ...dns.RR) []byte { + t.Helper() + + queryData, err := io.ReadAll(queryReader) + require.NoError(t, err) + + query := new(dns.Msg) + err = query.Unpack(queryData) + require.NoError(t, err) + + response := new(dns.Msg) + response.SetReply(query) + response.Answer = append(response.Answer, answers...) + + wire, err := response.Pack() + require.NoError(t, err) + return wire +} + +func startTCPAccepter(t *testing.T) (port uint16) { + t.Helper() + + // Find a port available for both TCP IPv4 and TCP IPv6 + listeners := make([]net.Listener, 2) // IPv4 + IPv6 + netConfig := net.ListenConfig{} + var listenersToClose []net.Listener + for t.Context().Err() == nil { + // Find an available port for IPv4 + listeningAddress := netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), 0) + listener, err := netConfig.Listen(t.Context(), "tcp", listeningAddress.String()) + require.NoError(t, err) + listeners[0] = listener + port = uint16(listener.Addr().(*net.TCPAddr).Port) //nolint:gosec,forcetypeassert + + // Check if that port is also available for IPv6 + listeningAddress = netip.AddrPortFrom( + netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}), + port, + ) + listener, err = netConfig.Listen(t.Context(), "tcp", listeningAddress.String()) + if err == nil { + listeners[1] = listener + break // success, we found a port available for both IPv4 and IPv6 + } + var opErr *net.OpError + if errors.As(err, &opErr) { + var sysErr *os.SyscallError + if errors.As(opErr.Err, &sysErr) && errors.Is(sysErr.Err, syscall.EADDRINUSE) { + // Port found for IPv4 is already in use for IPv6, try another port + // We don't close the IPv4 listener yet to make sure we don't get the same port again from the OS. + listenersToClose = append(listenersToClose, listeners[0]) + continue + } + } + } + + for _, listener := range listenersToClose { + err := listener.Close() + assert.NoError(t, err) + } + + var ready sync.WaitGroup + ready.Add(len(listeners)) + for _, listener := range listeners { + t.Cleanup(func() { + err := listener.Close() + assert.NoError(t, err) + }) + + go func() { + ready.Done() + for { + connection, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) && t.Context().Err() != nil { + return + } + assert.NoError(t, err) + return + } + err = connection.Close() + assert.NoError(t, err) + } + }() + } + + ready.Wait() + + return port +} diff --git a/internal/restrictednet/https.go b/internal/restrictednet/https.go index 06c378cea..d61f78d12 100644 --- a/internal/restrictednet/https.go +++ b/internal/restrictednet/https.go @@ -42,7 +42,7 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti return nil, nil, fmt.Errorf("connecting source socket: %w", err) } - httpClient = newHTTPSClient(destinationTLSName, connection) + httpClient = newHTTPSClient(c.baseTransport, destinationTLSName, connection) cleanup = func() error { var errs []error httpClient.CloseIdleConnections() @@ -64,21 +64,21 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti return httpClient, cleanup, nil } -func newHTTPSClient(destinationTLSName string, connection net.Conn) *http.Client { - httpTransport := http.DefaultTransport.(*http.Transport).Clone() //nolint:forcetypeassert - httpTransport.Proxy = nil - httpTransport.MaxIdleConns = 1 - httpTransport.MaxIdleConnsPerHost = 1 - httpTransport.MaxConnsPerHost = 1 - httpTransport.IdleConnTimeout = time.Second - httpTransport.TLSClientConfig = &tls.Config{ +func newHTTPSClient(baseTransport *http.Transport, destinationTLSName string, connection net.Conn) *http.Client { + transport := baseTransport.Clone() + transport.Proxy = nil + transport.MaxIdleConns = 1 + transport.MaxIdleConnsPerHost = 1 + transport.MaxConnsPerHost = 1 + transport.IdleConnTimeout = time.Second + transport.TLSClientConfig = &tls.Config{ MinVersion: tls.VersionTLS12, ServerName: destinationTLSName, } _, destinationPort, _ := net.SplitHostPort(connection.RemoteAddr().String()) expectedAddress := net.JoinHostPort(destinationTLSName, destinationPort) - httpTransport.DialContext = func(_ context.Context, network, address string) (net.Conn, error) { + transport.DialContext = func(_ context.Context, network, address string) (net.Conn, error) { switch network { case "tcp", "tcp4", "tcp6": default: @@ -93,7 +93,7 @@ func newHTTPSClient(destinationTLSName string, connection net.Conn) *http.Client const timeout = 5 * time.Second return &http.Client{ Timeout: timeout, - Transport: httpTransport, + Transport: transport, } } diff --git a/internal/restrictednet/https_test.go b/internal/restrictednet/https_test.go index 7db81e600..02e36fd20 100644 --- a/internal/restrictednet/https_test.go +++ b/internal/restrictednet/https_test.go @@ -70,8 +70,13 @@ func Test_Client_OpenHTTPS(t *testing.T) { const ipv6Supported = false upstreamResolvers := []provider.Provider{provider.Google()} - client, err := New(firewall, "eth0", ipv6Supported, upstreamResolvers) - require.NoError(t, err) + settings := Settings{ + Firewall: firewall, + DefaultInterface: "eth0", + IPv6Supported: ptrTo(ipv6Supported), + UpstreamResolvers: upstreamResolvers, + } + client := New(settings) client.httpsPort = listeningPort httpClient, cleanup, err := client.OpenHTTPS(ctx, "api.example.com", netip.MustParseAddr("127.0.0.1")) diff --git a/internal/restrictednet/resolve.go b/internal/restrictednet/resolve.go index aa1c3e647..2feffeb52 100644 --- a/internal/restrictednet/resolve.go +++ b/internal/restrictednet/resolve.go @@ -79,11 +79,11 @@ func (c *Client) resolveOneQuestionType(ctx context.Context, responseMessage, err := c.doHQuery(ctx, queryWire, dohURL, dohServerIP) switch { case err != nil: - errs = append(errs, fmt.Errorf("querying DoH server %q at %s: %w", + errs = append(errs, fmt.Errorf("querying DoH server %q (ip %s): %w", dohServer.URL, dohServerIP, err)) continue case responseMessage.Rcode != dns.RcodeSuccess: - errs = append(errs, fmt.Errorf("querying DoH server %q at %s: DNS rcode %s", + errs = append(errs, fmt.Errorf("querying DoH server %q (ip %s): DNS rcode %s", dohServer.URL, dohServerIP, dns.RcodeToString[responseMessage.Rcode])) continue } diff --git a/internal/restrictednet/resolve_test.go b/internal/restrictednet/resolve_test.go index 51762778b..972b5ff12 100644 --- a/internal/restrictednet/resolve_test.go +++ b/internal/restrictednet/resolve_test.go @@ -1,80 +1,391 @@ package restrictednet import ( + "bytes" + "context" + "errors" + "io" "net" + "net/http" "net/netip" + "net/url" + "sync/atomic" "testing" + "github.com/golang/mock/gomock" "github.com/miekg/dns" + "github.com/qdm12/dns/v2/pkg/provider" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func Test_answersToNetipAddrs(t *testing.T) { +func Test_Client_ResolveName(t *testing.T) { t.Parallel() testCases := map[string]struct { - message *dns.Msg - expected []netip.Addr - errorIsNil bool + ipv6Supported bool + upstreamResolvers []provider.Provider + expectedAddresses []netip.Addr + errorContains string + expectedDestIPs []netip.Addr + responder func(host string, requestBody io.Reader) (*http.Response, error) }{ - "nil_message": { - message: nil, - expected: nil, - errorIsNil: true, - }, - "no_answers": { - message: &dns.Msg{}, - expected: []netip.Addr{}, - errorIsNil: true, + "success_single_server_ipv4": { + upstreamResolvers: []provider.Provider{{ + DoH: provider.DoHServer{ + URL: "https://resolver-1.local/dns-query", + IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, + }, + }}, + expectedAddresses: []netip.Addr{netip.MustParseAddr("1.1.1.1")}, + expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, + responder: func(_ string, requestBody io.Reader) (*http.Response, error) { + wire := responseWireForQuery(t, requestBody, &dns.A{ + Hdr: dns.RR_Header{Name: "github.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, + A: net.IP{1, 1, 1, 1}, + }) + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil + }, }, - "a_record": { - message: &dns.Msg{ - Answer: []dns.RR{ - &dns.A{ - Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, - A: net.IP{1, 1, 1, 1}, + "fallback_between_servers": { + upstreamResolvers: []provider.Provider{ + { + DoH: provider.DoHServer{ + URL: "https://resolver-1.local/dns-query", + IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, + }, + }, + { + DoH: provider.DoHServer{ + URL: "https://resolver-2.local/dns-query", + IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, }, }, }, - expected: []netip.Addr{netip.MustParseAddr("1.1.1.1")}, - errorIsNil: true, + expectedAddresses: []netip.Addr{netip.MustParseAddr("2.2.2.2")}, + expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("127.0.0.1")}, + responder: func(host string, requestBody io.Reader) (*http.Response, error) { + if host == "resolver-1.local" || + len(host) > len("resolver-1.local:") && host[:len("resolver-1.local:")] == "resolver-1.local:" { + return &http.Response{ + StatusCode: http.StatusBadGateway, + Status: "502 Bad Gateway", + Body: io.NopCloser(bytes.NewReader([]byte("bad gateway"))), + }, nil + } + wire := responseWireForQuery(t, requestBody, &dns.A{ + Hdr: dns.RR_Header{Name: "github.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, + A: net.IP{2, 2, 2, 2}, + }) + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil + }, }, - "aaaa_record": { - message: &dns.Msg{ - Answer: []dns.RR{ - &dns.AAAA{ - Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET}, - AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88}, - }, + "fallback_between_ips": { + upstreamResolvers: []provider.Provider{{ + DoH: provider.DoHServer{ + URL: "https://resolver.local/dns-query", + IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("127.0.0.1")}, + }, + }}, + expectedAddresses: []netip.Addr{netip.MustParseAddr("1.1.1.2")}, + expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("127.0.0.1")}, + responder: func() func(host string, requestBody io.Reader) (*http.Response, error) { + var calls atomic.Int32 + return func(_ string, requestBody io.Reader) (*http.Response, error) { + if calls.Add(1) == 1 { // first call fails + return &http.Response{ + StatusCode: http.StatusNotFound, + Status: "502 Bad Gateway", + Body: io.NopCloser(bytes.NewReader([]byte("bad gateway"))), + }, nil + } + wire := responseWireForQuery(t, requestBody, &dns.A{ + Hdr: dns.RR_Header{Name: "github.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, + A: net.IP{1, 1, 1, 2}, + }) + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil + } + }(), //nolint:bodyclose + }, + "dns_rcode_error_servfail": { + upstreamResolvers: []provider.Provider{{ + DoH: provider.DoHServer{ + URL: "https://resolver.local/dns-query", + IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, }, + }}, + errorContains: "SERVFAIL", + expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, + responder: func(_ string, requestBody io.Reader) (*http.Response, error) { + queryWire, err := io.ReadAll(requestBody) + require.NoError(t, err) + query := new(dns.Msg) + err = query.Unpack(queryWire) + require.NoError(t, err) + response := new(dns.Msg) + response.SetReply(query) + response.Rcode = dns.RcodeServerFailure + wire, err := response.Pack() + require.NoError(t, err) + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil }, - expected: []netip.Addr{netip.MustParseAddr("2001:4860:4860::8888")}, - errorIsNil: true, }, - "mixed_records": { - message: &dns.Msg{ - Answer: []dns.RR{ - &dns.A{ - Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, - A: net.IP{1, 1, 1, 1}, - }, - &dns.AAAA{ - Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET}, - AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88}, - }, + "no_answer": { + upstreamResolvers: []provider.Provider{{ + DoH: provider.DoHServer{ + URL: "https://resolver.local/dns-query", + IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, }, + }}, + expectedAddresses: nil, + expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, + responder: func(_ string, requestBody io.Reader) (*http.Response, error) { + wire := responseWireForQuery(t, requestBody) + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil + }, + }, + "ipv6_preference": { + ipv6Supported: true, + upstreamResolvers: []provider.Provider{{ + DoH: provider.DoHServer{ + URL: "https://resolver.local/dns-query", + IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, + IPv6: []netip.Addr{netip.MustParseAddr("::1")}, + }, + }}, + expectedAddresses: []netip.Addr{netip.MustParseAddr("2001:4860:4860::8888")}, + expectedDestIPs: []netip.Addr{ + netip.MustParseAddr("::1"), + netip.MustParseAddr("::1"), + netip.MustParseAddr("127.0.0.1"), + }, + responder: func(_ string, requestBody io.Reader) (*http.Response, error) { + queryWire, err := io.ReadAll(requestBody) + require.NoError(t, err) + query := new(dns.Msg) + err = query.Unpack(queryWire) + require.NoError(t, err) + if len(query.Question) > 0 && query.Question[0].Qtype == dns.TypeA { + wire := responseWireForQuery(t, bytes.NewReader(queryWire)) + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil + } + wire := responseWireForQuery(t, bytes.NewReader(queryWire), &dns.AAAA{ + Hdr: dns.RR_Header{Name: "github.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET}, + AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88}, + }) + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil + }, + }, + "all_servers_fail": { + upstreamResolvers: []provider.Provider{ + {DoH: provider.DoHServer{ + URL: "https://resolver-1.local/dns-query", + IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, + }}, + {DoH: provider.DoHServer{ + URL: "https://resolver-2.local/dns-query", + IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, + }}, + }, + errorContains: "resolving host", + expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("127.0.0.1")}, + responder: func(_ string, _ io.Reader) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusBadGateway, + Status: "502 Bad Gateway", + Body: io.NopCloser(bytes.NewReader([]byte("bad gateway"))), + }, nil }, - expected: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("2001:4860:4860::8888")}, - errorIsNil: true, }, } for testName, testCase := range testCases { t.Run(testName, func(t *testing.T) { t.Parallel() + ctrl := gomock.NewController(t) - addresses := answersToNetipAddrs(testCase.message) + firewall := NewMockFirewall(ctrl) + port := startTCPAccepter(t) + + for _, destinationIP := range testCase.expectedDestIPs { + expectFirewallCallPair(firewall, t.Context(), destinationIP, port, nil, nil) + } + + resolvers := make([]provider.Provider, len(testCase.upstreamResolvers)) + copy(resolvers, testCase.upstreamResolvers) + for i := range resolvers { + resolvers[i].DoH.URL = urlToHostnamePort(resolvers[i].DoH.URL, port) + } + + settings := Settings{ + DefaultInterface: "eth0", + IPv6Supported: ptrTo(testCase.ipv6Supported), + Firewall: firewall, + UpstreamResolvers: resolvers, + BaseTransport: newInterceptTransport(testCase.responder), + } + client := New(settings) + client.httpsPort = port + + addresses, err := client.ResolveName(t.Context(), "github.com") + assert.Equal(t, testCase.expectedAddresses, addresses) + if testCase.errorContains != "" { + require.Error(t, err) + assert.ErrorContains(t, err, testCase.errorContains) + } else { + require.NoError(t, err) + } + }) + } +} + +func Test_Client_doHQuery(t *testing.T) { + t.Parallel() + + query := new(dns.Msg) + query.SetQuestion("example.com.", dns.TypeA) + queryWire, err := query.Pack() + require.NoError(t, err) + + responseWire := responseWireForQuery(t, bytes.NewReader(queryWire), &dns.A{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, + A: net.IP{1, 1, 1, 1}, + }) + + testCases := map[string]struct { + response *http.Response + addFirewallRuleErr error + removeFirewallRuleErr error + errorContains string + expectedIPs []netip.Addr + }{ + "success": { + response: &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(responseWire))}, + expectedIPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")}, + }, + "http_status_not_ok": { + response: &http.Response{ + StatusCode: http.StatusBadGateway, + Status: "502 Bad Gateway", + Body: io.NopCloser(bytes.NewReader([]byte("bad gateway"))), + }, + errorContains: "response status code is 502 Bad Gateway", + }, + "malformed_dns_response": { + response: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString("not-dns")), + }, + errorContains: "parsing DoH response", + }, + "cleanup_error": { + response: &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(responseWire))}, + removeFirewallRuleErr: errors.New("cleanup failed"), + errorContains: "cleaning up https connection: removing output traffic rule: cleanup failed", + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + firewall := NewMockFirewall(ctrl) + port := startTCPAccepter(t) + + expectFirewallCallPair( + firewall, + context.Background(), + netip.MustParseAddr("127.0.0.1"), + port, + testCase.addFirewallRuleErr, + testCase.removeFirewallRuleErr, + ) + + settings := Settings{ + DefaultInterface: "eth0", + IPv6Supported: ptrTo(false), + Firewall: firewall, + UpstreamResolvers: []provider.Provider{provider.Google()}, + BaseTransport: newInterceptTransport(func(_ string, _ io.Reader) (*http.Response, error) { + return testCase.response, nil + }), + } + client := New(settings) + client.httpsPort = port + + dohURL, err := url.Parse(urlToHostnamePort("https://resolver.local/dns-query", port)) + require.NoError(t, err) + + message, err := client.doHQuery( + context.Background(), + queryWire, + dohURL, + netip.MustParseAddr("127.0.0.1"), + ) + + if testCase.errorContains != "" { + require.Error(t, err) + assert.ErrorContains(t, err, testCase.errorContains) + return + } + + require.NoError(t, err) + addresses := answersToNetipAddrs(message) + assert.Equal(t, testCase.expectedIPs, addresses) + }) + } +} + +func Test_answersToNetipAddrs(t *testing.T) { + t.Parallel() + testCases := map[string]struct { + message *dns.Msg + expected []netip.Addr + }{ + "nil_message": {}, + "no_answers": { + message: &dns.Msg{}, + expected: []netip.Addr{}, + }, + "a_record": { + message: &dns.Msg{Answer: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, + A: net.IP{1, 1, 1, 1}, + }, + }}, + expected: []netip.Addr{netip.MustParseAddr("1.1.1.1")}, + }, + "aaaa_record": { + message: &dns.Msg{Answer: []dns.RR{ + &dns.AAAA{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET}, + AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88}, + }, + }}, + expected: []netip.Addr{netip.MustParseAddr("2001:4860:4860::8888")}, + }, + "mixed_records": { + message: &dns.Msg{Answer: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, + A: net.IP{1, 1, 1, 1}, + }, + &dns.AAAA{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET}, + AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88}, + }, + }}, + expected: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("2001:4860:4860::8888")}, + }, + } + + for testName, testCase := range testCases { + t.Run(testName, func(t *testing.T) { + t.Parallel() + addresses := answersToNetipAddrs(testCase.message) assert.Equal(t, testCase.expected, addresses) }) } diff --git a/internal/restrictednet/settings.go b/internal/restrictednet/settings.go new file mode 100644 index 000000000..4b943b52c --- /dev/null +++ b/internal/restrictednet/settings.go @@ -0,0 +1,36 @@ +package restrictednet + +import ( + "errors" + "net/http" + + "github.com/qdm12/dns/v2/pkg/provider" +) + +type Settings struct { + DefaultInterface string + IPv6Supported *bool + Firewall Firewall + UpstreamResolvers []provider.Provider + BaseTransport *http.Transport +} + +func (s *Settings) setDefaults() { + if s.BaseTransport == nil { + s.BaseTransport = http.DefaultTransport.(*http.Transport) //nolint:forcetypeassert + } +} + +func (s *Settings) validate() error { + switch { + case s.DefaultInterface == "": + return errors.New("default interface is not set") + case s.IPv6Supported == nil: + return errors.New("IPv6 support field is not set") + case s.Firewall == nil: + return errors.New("firewall is not set") + case len(s.UpstreamResolvers) == 0: + return errors.New("no upstream resolvers provided") + } + return nil +} From b5366b9e440cef2a15ff06d61dcba95d9e1bd7e7 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 9 Jun 2026 14:04:32 +0000 Subject: [PATCH 11/18] Change tests to be more integration oriented --- internal/restrictednet/client.go | 30 ++- internal/restrictednet/helpers_test.go | 180 ------------- internal/restrictednet/https.go | 51 ++-- internal/restrictednet/https_test.go | 63 +++-- internal/restrictednet/resolve.go | 16 +- internal/restrictednet/resolve_test.go | 340 ++----------------------- internal/restrictednet/settings.go | 8 - 7 files changed, 125 insertions(+), 563 deletions(-) diff --git a/internal/restrictednet/client.go b/internal/restrictednet/client.go index 7b1547a69..82091a75c 100644 --- a/internal/restrictednet/client.go +++ b/internal/restrictednet/client.go @@ -4,7 +4,10 @@ import ( "context" "errors" "fmt" + "net" "net/http" + "net/netip" + "strconv" "github.com/qdm12/dns/v2/pkg/provider" ) @@ -18,12 +21,9 @@ type Client struct { ipv6Supported bool firewall Firewall dohServers []provider.DoHServer - baseTransport *http.Transport - httpsPort uint16 } func New(settings Settings) *Client { - settings.setDefaults() if err := settings.validate(); err != nil { panic(fmt.Sprintf("invalid settings: %v", err)) // programming error } @@ -32,30 +32,38 @@ func New(settings Settings) *Client { dohServers[i] = upstreamResolver.DoH } - const defaultHTTPSPort = 443 return &Client{ outboundInterface: settings.DefaultInterface, ipv6Supported: *settings.IPv6Supported, firewall: settings.Firewall, dohServers: dohServers, - baseTransport: settings.BaseTransport, - httpsPort: defaultHTTPSPort, } } -func (c *Client) OpenHTTPSByDomain(ctx context.Context, domain string) ( +func (c *Client) OpenHTTPSByDomain(ctx context.Context, hostname string) ( httpClient *http.Client, cleanup func() error, err error, ) { - resolvedIPs, err := c.ResolveName(ctx, domain) + host, portStr, err := net.SplitHostPort(hostname) + if err != nil { + return nil, nil, fmt.Errorf("splitting host and port: %w", err) + } + resolvedIPs, err := c.ResolveName(ctx, host) if err != nil { return nil, nil, fmt.Errorf("resolving name: %w", err) } else if len(resolvedIPs) == 0 { - return nil, nil, fmt.Errorf("no IP address found for name %q", domain) + return nil, nil, fmt.Errorf("no IP address found for name %q", host) + } + + portUint, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return nil, nil, fmt.Errorf("parsing port: %w", err) } + port := uint16(portUint) errs := make([]error, 0, len(resolvedIPs)) for _, ip := range resolvedIPs { - httpClient, cleanup, err := c.OpenHTTPS(ctx, domain, ip) + addrPort := netip.AddrPortFrom(ip, port) + httpClient, cleanup, err := c.OpenHTTPS(ctx, host, addrPort) if err != nil { errs = append(errs, fmt.Errorf("for %s: %w", ip, err)) continue @@ -63,5 +71,5 @@ func (c *Client) OpenHTTPSByDomain(ctx context.Context, domain string) ( return httpClient, cleanup, nil } - return nil, nil, fmt.Errorf("opening HTTPS to %s: %w", domain, errors.Join(errs...)) + return nil, nil, fmt.Errorf("opening HTTPS to %s: %w", hostname, errors.Join(errs...)) } diff --git a/internal/restrictednet/helpers_test.go b/internal/restrictednet/helpers_test.go index cac3fd380..54070c324 100644 --- a/internal/restrictednet/helpers_test.go +++ b/internal/restrictednet/helpers_test.go @@ -1,185 +1,5 @@ package restrictednet -import ( - "bufio" - "bytes" - "context" - "errors" - "io" - "net" - "net/http" - "net/netip" - "net/url" - "os" - "strconv" - "sync" - "syscall" - "testing" - - "github.com/miekg/dns" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - func ptrTo[T any](value T) *T { return &value } - -func newInterceptTransport(handler func(host string, requestBody io.Reader) (*http.Response, error)) *http.Transport { - return &http.Transport{ - DialTLSContext: func(_ context.Context, _, _ string) (net.Conn, error) { - clientConn, serverConn := net.Pipe() - go func() { - defer serverConn.Close() - - reader := bufio.NewReader(serverConn) - request, err := http.ReadRequest(reader) - if err != nil { - return - } - - response, err := handler(request.Host, request.Body) - if err != nil { - return - } - - // Read the response body and re-create it to avoid linting - // complaining that the response body must be closed. - responseData, err := io.ReadAll(response.Body) - if err != nil { - return - } - _ = response.Body.Close() - response.Body = io.NopCloser(bytes.NewReader(responseData)) - - _ = response.Write(serverConn) - }() - return clientConn, nil - }, - } -} - -func expectFirewallCallPair( - firewall *MockFirewall, - addContext context.Context, //nolint:revive - destinationIP netip.Addr, - destinationPort uint16, - addErr error, - removeErr error, -) { - destination := netip.AddrPortFrom(destinationIP, destinationPort) - sourceMatcher := listenAddrPortMatcher{} - - firewall.EXPECT().AcceptOutputFromIPPortToIPPort( - addContext, "tcp", "eth0", sourceMatcher, destination, false, - ).DoAndReturn(func( - _ context.Context, _, _ string, source, _ netip.AddrPort, _ bool, - ) error { - sourceMatcher.expected = source - return addErr - }) - - firewall.EXPECT().AcceptOutputFromIPPortToIPPort( - context.Background(), "tcp", "eth0", sourceMatcher, destination, true, - ).Return(removeErr) -} - -func urlToHostnamePort(rawURL string, port uint16) string { - parsedURL, err := url.Parse(rawURL) - if err != nil { - panic(err) // programming error in test - } - parsedURL.Host = net.JoinHostPort(parsedURL.Hostname(), strconv.FormatUint(uint64(port), 10)) - return parsedURL.String() -} - -func responseWireForQuery(t *testing.T, queryReader io.Reader, answers ...dns.RR) []byte { - t.Helper() - - queryData, err := io.ReadAll(queryReader) - require.NoError(t, err) - - query := new(dns.Msg) - err = query.Unpack(queryData) - require.NoError(t, err) - - response := new(dns.Msg) - response.SetReply(query) - response.Answer = append(response.Answer, answers...) - - wire, err := response.Pack() - require.NoError(t, err) - return wire -} - -func startTCPAccepter(t *testing.T) (port uint16) { - t.Helper() - - // Find a port available for both TCP IPv4 and TCP IPv6 - listeners := make([]net.Listener, 2) // IPv4 + IPv6 - netConfig := net.ListenConfig{} - var listenersToClose []net.Listener - for t.Context().Err() == nil { - // Find an available port for IPv4 - listeningAddress := netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), 0) - listener, err := netConfig.Listen(t.Context(), "tcp", listeningAddress.String()) - require.NoError(t, err) - listeners[0] = listener - port = uint16(listener.Addr().(*net.TCPAddr).Port) //nolint:gosec,forcetypeassert - - // Check if that port is also available for IPv6 - listeningAddress = netip.AddrPortFrom( - netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}), - port, - ) - listener, err = netConfig.Listen(t.Context(), "tcp", listeningAddress.String()) - if err == nil { - listeners[1] = listener - break // success, we found a port available for both IPv4 and IPv6 - } - var opErr *net.OpError - if errors.As(err, &opErr) { - var sysErr *os.SyscallError - if errors.As(opErr.Err, &sysErr) && errors.Is(sysErr.Err, syscall.EADDRINUSE) { - // Port found for IPv4 is already in use for IPv6, try another port - // We don't close the IPv4 listener yet to make sure we don't get the same port again from the OS. - listenersToClose = append(listenersToClose, listeners[0]) - continue - } - } - } - - for _, listener := range listenersToClose { - err := listener.Close() - assert.NoError(t, err) - } - - var ready sync.WaitGroup - ready.Add(len(listeners)) - for _, listener := range listeners { - t.Cleanup(func() { - err := listener.Close() - assert.NoError(t, err) - }) - - go func() { - ready.Done() - for { - connection, err := listener.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) && t.Context().Err() != nil { - return - } - assert.NoError(t, err) - return - } - err = connection.Close() - assert.NoError(t, err) - } - }() - } - - ready.Wait() - - return port -} diff --git a/internal/restrictednet/https.go b/internal/restrictednet/https.go index d61f78d12..08ae73504 100644 --- a/internal/restrictednet/https.go +++ b/internal/restrictednet/https.go @@ -17,15 +17,13 @@ import ( // OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination. // The returned cleanup function must be called to remove the temporary firewall rule and close connections. -func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, destinationIP netip.Addr, +func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, destinationAddrPort netip.AddrPort, ) (httpClient *http.Client, cleanup func() error, err error) { - fd, sourceAddrPort, err := bindSourceConnection(destinationIP) + fd, sourceAddrPort, err := bindSourceConnection(destinationAddrPort.Addr()) if err != nil { return nil, nil, fmt.Errorf("binding source port: %w", err) } - destinationAddrPort := netip.AddrPortFrom(destinationIP, c.httpsPort) - const remove = false err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface, sourceAddrPort, destinationAddrPort, remove) @@ -42,7 +40,8 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti return nil, nil, fmt.Errorf("connecting source socket: %w", err) } - httpClient = newHTTPSClient(c.baseTransport, destinationTLSName, connection) + dial := makeDial(connection, destinationTLSName) + httpClient = newHTTPSClient(destinationTLSName, dial) cleanup = func() error { var errs []error httpClient.CloseIdleConnections() @@ -53,7 +52,7 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err)) } err = connection.Close() - if err != nil { + if err != nil && !errors.Is(err, net.ErrClosed) { errs = append(errs, fmt.Errorf("closing connection: %w", err)) } if len(errs) > 0 { @@ -64,21 +63,31 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti return httpClient, cleanup, nil } -func newHTTPSClient(baseTransport *http.Transport, destinationTLSName string, connection net.Conn) *http.Client { - transport := baseTransport.Clone() - transport.Proxy = nil - transport.MaxIdleConns = 1 - transport.MaxIdleConnsPerHost = 1 - transport.MaxConnsPerHost = 1 - transport.IdleConnTimeout = time.Second - transport.TLSClientConfig = &tls.Config{ - MinVersion: tls.VersionTLS12, - ServerName: destinationTLSName, +type dialFunc func(ctx context.Context, network, address string) (net.Conn, error) + +func newHTTPSClient(destinationTLSName string, dial dialFunc) *http.Client { + const timeout = 5 * time.Second + transport := &http.Transport{ + MaxIdleConns: 1, + MaxIdleConnsPerHost: 1, + MaxConnsPerHost: 1, + IdleConnTimeout: time.Second, + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + ServerName: destinationTLSName, + }, + DialContext: dial, + } + return &http.Client{ + Timeout: timeout, + Transport: transport, } +} +func makeDial(connection net.Conn, tlsName string) dialFunc { _, destinationPort, _ := net.SplitHostPort(connection.RemoteAddr().String()) - expectedAddress := net.JoinHostPort(destinationTLSName, destinationPort) - transport.DialContext = func(_ context.Context, network, address string) (net.Conn, error) { + expectedAddress := net.JoinHostPort(tlsName, destinationPort) + return func(_ context.Context, network, address string) (net.Conn, error) { switch network { case "tcp", "tcp4", "tcp6": default: @@ -89,12 +98,6 @@ func newHTTPSClient(baseTransport *http.Transport, destinationTLSName string, co } return connection, nil } - - const timeout = 5 * time.Second - return &http.Client{ - Timeout: timeout, - Transport: transport, - } } func bindSourceConnection(destinationIP netip.Addr) (fd int, sourceAddr netip.AddrPort, err error) { diff --git a/internal/restrictednet/https_test.go b/internal/restrictednet/https_test.go index 02e36fd20..b488f5053 100644 --- a/internal/restrictednet/https_test.go +++ b/internal/restrictednet/https_test.go @@ -2,12 +2,14 @@ package restrictednet import ( "context" - "net" + "fmt" + "net/http" "net/netip" "testing" "github.com/golang/mock/gomock" "github.com/qdm12/dns/v2/pkg/provider" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -33,31 +35,40 @@ func (m listenAddrPortMatcher) String() string { return "is a valid netip.AddrPort with a valid IP and non-zero port" } +type destinationAddrPortMatcher struct { + expected netip.AddrPort +} + +func (m destinationAddrPortMatcher) Matches(x any) bool { + ip, ok := x.(netip.AddrPort) + if !ok { + return false + } + if m.expected.IsValid() { + return ip == m.expected + } + return ip.IsValid() && ip.Port() == m.expected.Port() +} + +func (m destinationAddrPortMatcher) String() string { + if m.expected.IsValid() { + return "is the same as " + m.expected.String() + } + return "matches the port " + fmt.Sprint(m.expected.Port()) +} + func Test_Client_OpenHTTPS(t *testing.T) { t.Parallel() ctx := t.Context() + ctrl := gomock.NewController(t) - netConfig := net.ListenConfig{} - listener, err := netConfig.Listen(ctx, "tcp", "127.0.0.1:0") - require.NoError(t, err) - t.Cleanup(func() { - _ = listener.Close() - }) - listeningPort := uint16(listener.Addr().(*net.TCPAddr).Port) //nolint:gosec,forcetypeassert - go func() { - connection, acceptErr := listener.Accept() - if acceptErr == nil { - _ = connection.Close() - } - }() + const destinationTLSName = "one.one.one.one" + destinationAddrPort := netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 443) - ctrl := gomock.NewController(t) firewall := NewMockFirewall(ctrl) - - destination := netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), listeningPort) sourceMatcher := listenAddrPortMatcher{} firewall.EXPECT().AcceptOutputFromIPPortToIPPort( - ctx, "tcp", "eth0", sourceMatcher, destination, false, + ctx, "tcp", "eth0", sourceMatcher, destinationAddrPort, false, ).DoAndReturn(func(_ context.Context, _, _ string, source, _ netip.AddrPort, _ bool, ) error { @@ -65,7 +76,7 @@ func Test_Client_OpenHTTPS(t *testing.T) { return nil }) firewall.EXPECT().AcceptOutputFromIPPortToIPPort( - context.Background(), "tcp", "eth0", sourceMatcher, destination, true, + context.Background(), "tcp", "eth0", sourceMatcher, destinationAddrPort, true, ) const ipv6Supported = false @@ -77,13 +88,23 @@ func Test_Client_OpenHTTPS(t *testing.T) { UpstreamResolvers: upstreamResolvers, } client := New(settings) - client.httpsPort = listeningPort - httpClient, cleanup, err := client.OpenHTTPS(ctx, "api.example.com", netip.MustParseAddr("127.0.0.1")) + httpClient, cleanup, err := client.OpenHTTPS(ctx, destinationTLSName, destinationAddrPort) require.NoError(t, err) require.NotNil(t, httpClient) require.NotNil(t, cleanup) + request, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://"+destinationTLSName, nil) + require.NoError(t, err) + + response, err := httpClient.Do(request) + t.Cleanup(func() { + response.Body.Close() + }) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, response.StatusCode) + err = cleanup() require.NoError(t, err) } diff --git a/internal/restrictednet/resolve.go b/internal/restrictednet/resolve.go index 2feffeb52..8a15c39a5 100644 --- a/internal/restrictednet/resolve.go +++ b/internal/restrictednet/resolve.go @@ -76,15 +76,17 @@ func (c *Client) resolveOneQuestionType(ctx context.Context, dohServerIPs = append(dohServerIPs, dohServer.IPv4...) for _, dohServerIP := range dohServerIPs { - responseMessage, err := c.doHQuery(ctx, queryWire, dohURL, dohServerIP) + const defaultDoHPort = 443 + dohServerAddrPort := netip.AddrPortFrom(dohServerIP, defaultDoHPort) + responseMessage, err := c.doHQuery(ctx, queryWire, dohURL, dohServerAddrPort) switch { case err != nil: - errs = append(errs, fmt.Errorf("querying DoH server %q (ip %s): %w", - dohServer.URL, dohServerIP, err)) + errs = append(errs, fmt.Errorf("querying DoH server %q (%s): %w", + dohServer.URL, dohServerAddrPort, err)) continue case responseMessage.Rcode != dns.RcodeSuccess: - errs = append(errs, fmt.Errorf("querying DoH server %q (ip %s): DNS rcode %s", - dohServer.URL, dohServerIP, dns.RcodeToString[responseMessage.Rcode])) + errs = append(errs, fmt.Errorf("querying DoH server %q (%s): DNS rcode %s", + dohServer.URL, dohServerAddrPort, dns.RcodeToString[responseMessage.Rcode])) continue } addresses := answersToNetipAddrs(responseMessage) @@ -104,9 +106,9 @@ func (c *Client) resolveOneQuestionType(ctx context.Context, } func (c *Client) doHQuery(ctx context.Context, queryWire []byte, - dohURL *url.URL, dohServerIP netip.Addr, + dohURL *url.URL, dohServerAddrPort netip.AddrPort, ) (responseMessage *dns.Msg, err error) { - httpClient, cleanup, err := c.OpenHTTPS(ctx, dohURL.Hostname(), dohServerIP) + httpClient, cleanup, err := c.OpenHTTPS(ctx, dohURL.Hostname(), dohServerAddrPort) if err != nil { return nil, fmt.Errorf("opening https connection: %w", err) } diff --git a/internal/restrictednet/resolve_test.go b/internal/restrictednet/resolve_test.go index 972b5ff12..3ef4b8476 100644 --- a/internal/restrictednet/resolve_test.go +++ b/internal/restrictednet/resolve_test.go @@ -1,15 +1,9 @@ package restrictednet import ( - "bytes" "context" - "errors" - "io" "net" - "net/http" "net/netip" - "net/url" - "sync/atomic" "testing" "github.com/golang/mock/gomock" @@ -21,320 +15,42 @@ import ( func Test_Client_ResolveName(t *testing.T) { t.Parallel() + ctx := t.Context() + ctrl := gomock.NewController(t) - testCases := map[string]struct { - ipv6Supported bool - upstreamResolvers []provider.Provider - expectedAddresses []netip.Addr - errorContains string - expectedDestIPs []netip.Addr - responder func(host string, requestBody io.Reader) (*http.Response, error) - }{ - "success_single_server_ipv4": { - upstreamResolvers: []provider.Provider{{ - DoH: provider.DoHServer{ - URL: "https://resolver-1.local/dns-query", - IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, - }, - }}, - expectedAddresses: []netip.Addr{netip.MustParseAddr("1.1.1.1")}, - expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, - responder: func(_ string, requestBody io.Reader) (*http.Response, error) { - wire := responseWireForQuery(t, requestBody, &dns.A{ - Hdr: dns.RR_Header{Name: "github.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, - A: net.IP{1, 1, 1, 1}, - }) - return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil - }, - }, - "fallback_between_servers": { - upstreamResolvers: []provider.Provider{ - { - DoH: provider.DoHServer{ - URL: "https://resolver-1.local/dns-query", - IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, - }, - }, - { - DoH: provider.DoHServer{ - URL: "https://resolver-2.local/dns-query", - IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, - }, - }, - }, - expectedAddresses: []netip.Addr{netip.MustParseAddr("2.2.2.2")}, - expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("127.0.0.1")}, - responder: func(host string, requestBody io.Reader) (*http.Response, error) { - if host == "resolver-1.local" || - len(host) > len("resolver-1.local:") && host[:len("resolver-1.local:")] == "resolver-1.local:" { - return &http.Response{ - StatusCode: http.StatusBadGateway, - Status: "502 Bad Gateway", - Body: io.NopCloser(bytes.NewReader([]byte("bad gateway"))), - }, nil - } - wire := responseWireForQuery(t, requestBody, &dns.A{ - Hdr: dns.RR_Header{Name: "github.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, - A: net.IP{2, 2, 2, 2}, - }) - return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil - }, - }, - "fallback_between_ips": { - upstreamResolvers: []provider.Provider{{ - DoH: provider.DoHServer{ - URL: "https://resolver.local/dns-query", - IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("127.0.0.1")}, - }, - }}, - expectedAddresses: []netip.Addr{netip.MustParseAddr("1.1.1.2")}, - expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("127.0.0.1")}, - responder: func() func(host string, requestBody io.Reader) (*http.Response, error) { - var calls atomic.Int32 - return func(_ string, requestBody io.Reader) (*http.Response, error) { - if calls.Add(1) == 1 { // first call fails - return &http.Response{ - StatusCode: http.StatusNotFound, - Status: "502 Bad Gateway", - Body: io.NopCloser(bytes.NewReader([]byte("bad gateway"))), - }, nil - } - wire := responseWireForQuery(t, requestBody, &dns.A{ - Hdr: dns.RR_Header{Name: "github.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, - A: net.IP{1, 1, 1, 2}, - }) - return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil - } - }(), //nolint:bodyclose - }, - "dns_rcode_error_servfail": { - upstreamResolvers: []provider.Provider{{ - DoH: provider.DoHServer{ - URL: "https://resolver.local/dns-query", - IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, - }, - }}, - errorContains: "SERVFAIL", - expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, - responder: func(_ string, requestBody io.Reader) (*http.Response, error) { - queryWire, err := io.ReadAll(requestBody) - require.NoError(t, err) - query := new(dns.Msg) - err = query.Unpack(queryWire) - require.NoError(t, err) - response := new(dns.Msg) - response.SetReply(query) - response.Rcode = dns.RcodeServerFailure - wire, err := response.Pack() - require.NoError(t, err) - return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil - }, - }, - "no_answer": { - upstreamResolvers: []provider.Provider{{ - DoH: provider.DoHServer{ - URL: "https://resolver.local/dns-query", - IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, - }, - }}, - expectedAddresses: nil, - expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, - responder: func(_ string, requestBody io.Reader) (*http.Response, error) { - wire := responseWireForQuery(t, requestBody) - return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil - }, - }, - "ipv6_preference": { - ipv6Supported: true, - upstreamResolvers: []provider.Provider{{ - DoH: provider.DoHServer{ - URL: "https://resolver.local/dns-query", - IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, - IPv6: []netip.Addr{netip.MustParseAddr("::1")}, - }, - }}, - expectedAddresses: []netip.Addr{netip.MustParseAddr("2001:4860:4860::8888")}, - expectedDestIPs: []netip.Addr{ - netip.MustParseAddr("::1"), - netip.MustParseAddr("::1"), - netip.MustParseAddr("127.0.0.1"), - }, - responder: func(_ string, requestBody io.Reader) (*http.Response, error) { - queryWire, err := io.ReadAll(requestBody) - require.NoError(t, err) - query := new(dns.Msg) - err = query.Unpack(queryWire) - require.NoError(t, err) - if len(query.Question) > 0 && query.Question[0].Qtype == dns.TypeA { - wire := responseWireForQuery(t, bytes.NewReader(queryWire)) - return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil - } - wire := responseWireForQuery(t, bytes.NewReader(queryWire), &dns.AAAA{ - Hdr: dns.RR_Header{Name: "github.com.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET}, - AAAA: net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0, 0, 0, 0, 0, 0, 0, 0, 0x88, 0x88}, - }) - return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(wire))}, nil - }, - }, - "all_servers_fail": { - upstreamResolvers: []provider.Provider{ - {DoH: provider.DoHServer{ - URL: "https://resolver-1.local/dns-query", - IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, - }}, - {DoH: provider.DoHServer{ - URL: "https://resolver-2.local/dns-query", - IPv4: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, - }}, - }, - errorContains: "resolving host", - expectedDestIPs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("127.0.0.1")}, - responder: func(_ string, _ io.Reader) (*http.Response, error) { - return &http.Response{ - StatusCode: http.StatusBadGateway, - Status: "502 Bad Gateway", - Body: io.NopCloser(bytes.NewReader([]byte("bad gateway"))), - }, nil - }, - }, + firewall := NewMockFirewall(ctrl) + sourceMatcher := listenAddrPortMatcher{} + destinationMatcher := destinationAddrPortMatcher{ + expected: netip.AddrPortFrom(netip.Addr{}, 443), } - for testName, testCase := range testCases { - t.Run(testName, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - firewall := NewMockFirewall(ctrl) - port := startTCPAccepter(t) - - for _, destinationIP := range testCase.expectedDestIPs { - expectFirewallCallPair(firewall, t.Context(), destinationIP, port, nil, nil) - } - - resolvers := make([]provider.Provider, len(testCase.upstreamResolvers)) - copy(resolvers, testCase.upstreamResolvers) - for i := range resolvers { - resolvers[i].DoH.URL = urlToHostnamePort(resolvers[i].DoH.URL, port) - } + // Add rule + firstCall := firewall.EXPECT().AcceptOutputFromIPPortToIPPort( + ctx, "tcp", "eth0", sourceMatcher, destinationMatcher, false, + ).DoAndReturn(func( + _ context.Context, _, _ string, source, destination netip.AddrPort, _ bool, + ) error { + sourceMatcher.expected = source + destinationMatcher.expected = destination + return nil + }) - settings := Settings{ - DefaultInterface: "eth0", - IPv6Supported: ptrTo(testCase.ipv6Supported), - Firewall: firewall, - UpstreamResolvers: resolvers, - BaseTransport: newInterceptTransport(testCase.responder), - } - client := New(settings) - client.httpsPort = port + // Removal rule + firewall.EXPECT().AcceptOutputFromIPPortToIPPort( + context.Background(), "tcp", "eth0", sourceMatcher, destinationMatcher, true, + ).Return(nil).After(firstCall) - addresses, err := client.ResolveName(t.Context(), "github.com") - assert.Equal(t, testCase.expectedAddresses, addresses) - if testCase.errorContains != "" { - require.Error(t, err) - assert.ErrorContains(t, err, testCase.errorContains) - } else { - require.NoError(t, err) - } - }) + settings := Settings{ + DefaultInterface: "eth0", + IPv6Supported: ptrTo(false), + Firewall: firewall, + UpstreamResolvers: []provider.Provider{provider.Cloudflare()}, } -} + client := New(settings) -func Test_Client_doHQuery(t *testing.T) { - t.Parallel() - - query := new(dns.Msg) - query.SetQuestion("example.com.", dns.TypeA) - queryWire, err := query.Pack() + addresses, err := client.ResolveName(ctx, "github.com") require.NoError(t, err) - - responseWire := responseWireForQuery(t, bytes.NewReader(queryWire), &dns.A{ - Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET}, - A: net.IP{1, 1, 1, 1}, - }) - - testCases := map[string]struct { - response *http.Response - addFirewallRuleErr error - removeFirewallRuleErr error - errorContains string - expectedIPs []netip.Addr - }{ - "success": { - response: &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(responseWire))}, - expectedIPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")}, - }, - "http_status_not_ok": { - response: &http.Response{ - StatusCode: http.StatusBadGateway, - Status: "502 Bad Gateway", - Body: io.NopCloser(bytes.NewReader([]byte("bad gateway"))), - }, - errorContains: "response status code is 502 Bad Gateway", - }, - "malformed_dns_response": { - response: &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString("not-dns")), - }, - errorContains: "parsing DoH response", - }, - "cleanup_error": { - response: &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(responseWire))}, - removeFirewallRuleErr: errors.New("cleanup failed"), - errorContains: "cleaning up https connection: removing output traffic rule: cleanup failed", - }, - } - - for name, testCase := range testCases { - t.Run(name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - firewall := NewMockFirewall(ctrl) - port := startTCPAccepter(t) - - expectFirewallCallPair( - firewall, - context.Background(), - netip.MustParseAddr("127.0.0.1"), - port, - testCase.addFirewallRuleErr, - testCase.removeFirewallRuleErr, - ) - - settings := Settings{ - DefaultInterface: "eth0", - IPv6Supported: ptrTo(false), - Firewall: firewall, - UpstreamResolvers: []provider.Provider{provider.Google()}, - BaseTransport: newInterceptTransport(func(_ string, _ io.Reader) (*http.Response, error) { - return testCase.response, nil - }), - } - client := New(settings) - client.httpsPort = port - - dohURL, err := url.Parse(urlToHostnamePort("https://resolver.local/dns-query", port)) - require.NoError(t, err) - - message, err := client.doHQuery( - context.Background(), - queryWire, - dohURL, - netip.MustParseAddr("127.0.0.1"), - ) - - if testCase.errorContains != "" { - require.Error(t, err) - assert.ErrorContains(t, err, testCase.errorContains) - return - } - - require.NoError(t, err) - addresses := answersToNetipAddrs(message) - assert.Equal(t, testCase.expectedIPs, addresses) - }) - } + assert.NotEmpty(t, addresses) } func Test_answersToNetipAddrs(t *testing.T) { diff --git a/internal/restrictednet/settings.go b/internal/restrictednet/settings.go index 4b943b52c..52c678c37 100644 --- a/internal/restrictednet/settings.go +++ b/internal/restrictednet/settings.go @@ -2,7 +2,6 @@ package restrictednet import ( "errors" - "net/http" "github.com/qdm12/dns/v2/pkg/provider" ) @@ -12,13 +11,6 @@ type Settings struct { IPv6Supported *bool Firewall Firewall UpstreamResolvers []provider.Provider - BaseTransport *http.Transport -} - -func (s *Settings) setDefaults() { - if s.BaseTransport == nil { - s.BaseTransport = http.DefaultTransport.(*http.Transport) //nolint:forcetypeassert - } } func (s *Settings) validate() error { From 29186feccc8af52cb93a9e6ab8d6d09f3395d9a1 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 9 Jun 2026 14:07:05 +0000 Subject: [PATCH 12/18] Fix ordering in cleanup function --- internal/restrictednet/https.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/internal/restrictednet/https.go b/internal/restrictednet/https.go index 08ae73504..ea08c6c8b 100644 --- a/internal/restrictednet/https.go +++ b/internal/restrictednet/https.go @@ -45,16 +45,16 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti cleanup = func() error { var errs []error httpClient.CloseIdleConnections() + err = connection.Close() + if err != nil && !errors.Is(err, net.ErrClosed) { + errs = append(errs, fmt.Errorf("closing connection: %w", err)) + } const remove = true err := c.firewall.AcceptOutputFromIPPortToIPPort(context.Background(), "tcp", c.outboundInterface, sourceAddrPort, destinationAddrPort, remove) if err != nil { errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err)) } - err = connection.Close() - if err != nil && !errors.Is(err, net.ErrClosed) { - errs = append(errs, fmt.Errorf("closing connection: %w", err)) - } if len(errs) > 0 { return errors.Join(errs...) } From 69b4e5c584653e04b18f4fb4456aa96b62bb4bfa Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 9 Jun 2026 21:11:15 +0000 Subject: [PATCH 13/18] PR feedback fixes --- internal/restrictednet/client.go | 6 +++++- internal/restrictednet/https.go | 14 +++++++++++--- internal/restrictednet/https_test.go | 6 +++--- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/internal/restrictednet/client.go b/internal/restrictednet/client.go index 82091a75c..fb070e8a0 100644 --- a/internal/restrictednet/client.go +++ b/internal/restrictednet/client.go @@ -40,7 +40,11 @@ func New(settings Settings) *Client { } } -func (c *Client) OpenHTTPSByDomain(ctx context.Context, hostname string) ( +// OpenHTTPSByHostname opens an https connection through the firewall, +// valid for up to one second, to the hostname which in the format `host:port`. +// It first resolves the domain in hostname using DNS over HTTPS and then opens +// the restricted HTTPS connection to the resolved IP. +func (c *Client) OpenHTTPSByHostname(ctx context.Context, hostname string) ( httpClient *http.Client, cleanup func() error, err error, ) { host, portStr, err := net.SplitHostPort(hostname) diff --git a/internal/restrictednet/https.go b/internal/restrictednet/https.go index ea08c6c8b..6912eff03 100644 --- a/internal/restrictednet/https.go +++ b/internal/restrictednet/https.go @@ -45,12 +45,12 @@ func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, desti cleanup = func() error { var errs []error httpClient.CloseIdleConnections() - err = connection.Close() + err := connection.Close() if err != nil && !errors.Is(err, net.ErrClosed) { errs = append(errs, fmt.Errorf("closing connection: %w", err)) } const remove = true - err := c.firewall.AcceptOutputFromIPPortToIPPort(context.Background(), "tcp", c.outboundInterface, + err = c.firewall.AcceptOutputFromIPPortToIPPort(context.Background(), "tcp", c.outboundInterface, sourceAddrPort, destinationAddrPort, remove) if err != nil { errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err)) @@ -85,9 +85,17 @@ func newHTTPSClient(destinationTLSName string, dial dialFunc) *http.Client { } func makeDial(connection net.Conn, tlsName string) dialFunc { - _, destinationPort, _ := net.SplitHostPort(connection.RemoteAddr().String()) + _, destinationPort, err := net.SplitHostPort(connection.RemoteAddr().String()) + if err != nil { + panic(err) // connection remote address should always be in the form "host:port" + } expectedAddress := net.JoinHostPort(tlsName, destinationPort) + used := false return func(_ context.Context, network, address string) (net.Conn, error) { + if used { + return nil, errors.New("dial function called more than once") + } + used = true switch network { case "tcp", "tcp4", "tcp6": default: diff --git a/internal/restrictednet/https_test.go b/internal/restrictednet/https_test.go index b488f5053..a977b5ff7 100644 --- a/internal/restrictednet/https_test.go +++ b/internal/restrictednet/https_test.go @@ -77,7 +77,7 @@ func Test_Client_OpenHTTPS(t *testing.T) { }) firewall.EXPECT().AcceptOutputFromIPPortToIPPort( context.Background(), "tcp", "eth0", sourceMatcher, destinationAddrPort, true, - ) + ).Return(nil) const ipv6Supported = false upstreamResolvers := []provider.Provider{provider.Google()} @@ -98,10 +98,10 @@ func Test_Client_OpenHTTPS(t *testing.T) { require.NoError(t, err) response, err := httpClient.Do(request) + require.NoError(t, err) t.Cleanup(func() { - response.Body.Close() + _ = response.Body.Close() }) - require.NoError(t, err) assert.Equal(t, http.StatusOK, response.StatusCode) From d28744e06d30450b90446dd8b07d5b21c623f5a6 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Thu, 11 Jun 2026 00:16:32 +0000 Subject: [PATCH 14/18] pr review changes --- .github/workflows/ci.yml | 4 ++++ .vscode/settings.json | 2 +- internal/restrictednet/client.go | 3 ++- internal/restrictednet/https.go | 2 ++ internal/restrictednet/https_test.go | 23 +++++++++++++------- internal/restrictednet/resolve.go | 29 ++++++++++++++++++++++++-- internal/restrictednet/resolve_test.go | 2 ++ internal/restrictednet/unix.go | 2 +- 8 files changed, 54 insertions(+), 13 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d3cd47f13..df57e3970 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -63,6 +63,10 @@ jobs: -v "$(pwd)/coverage.txt:/tmp/gobuild/coverage.txt" \ test-container + - name: Run integration tests in test container + run: | + docker run --rm --entrypoint "go test -tags=integration ./internal/restrictednet" test-container + - name: Verify dev cross platform compatibility run: docker build --target xcompile . diff --git a/.vscode/settings.json b/.vscode/settings.json index f7e463972..2346a6913 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -3,7 +3,7 @@ // to develop this project. "files.eol": "\n", "editor.formatOnSave": true, - "go.buildTags": "linux", + "go.buildTags": "linux,integration", "go.toolsEnvVars": { "CGO_ENABLED": "0" }, diff --git a/internal/restrictednet/client.go b/internal/restrictednet/client.go index fb070e8a0..6b355f346 100644 --- a/internal/restrictednet/client.go +++ b/internal/restrictednet/client.go @@ -41,7 +41,8 @@ func New(settings Settings) *Client { } // OpenHTTPSByHostname opens an https connection through the firewall, -// valid for up to one second, to the hostname which in the format `host:port`. +// to the hostname which in the format `host:port`. The returned cleanup +// function must be called to remove the temporary firewall rule and close connections. // It first resolves the domain in hostname using DNS over HTTPS and then opens // the restricted HTTPS connection to the resolved IP. func (c *Client) OpenHTTPSByHostname(ctx context.Context, hostname string) ( diff --git a/internal/restrictednet/https.go b/internal/restrictednet/https.go index 6912eff03..10def5c61 100644 --- a/internal/restrictednet/https.go +++ b/internal/restrictednet/https.go @@ -16,6 +16,8 @@ import ( ) // OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination. +// The returned [*http.Client] must be used sequentially only, and each request must +// have its response body fully read/discarded and then closed. // The returned cleanup function must be called to remove the temporary firewall rule and close connections. func (c *Client) OpenHTTPS(ctx context.Context, destinationTLSName string, destinationAddrPort netip.AddrPort, ) (httpClient *http.Client, cleanup func() error, err error) { diff --git a/internal/restrictednet/https_test.go b/internal/restrictednet/https_test.go index a977b5ff7..2a300cfcc 100644 --- a/internal/restrictednet/https_test.go +++ b/internal/restrictednet/https_test.go @@ -1,8 +1,11 @@ +//go:build integration + package restrictednet import ( "context" "fmt" + "io" "net/http" "net/netip" "testing" @@ -94,16 +97,20 @@ func Test_Client_OpenHTTPS(t *testing.T) { require.NotNil(t, httpClient) require.NotNil(t, cleanup) - request, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://"+destinationTLSName, nil) - require.NoError(t, err) + const requests = 2 - response, err := httpClient.Do(request) - require.NoError(t, err) - t.Cleanup(func() { - _ = response.Body.Close() - }) + for range requests { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://"+destinationTLSName, nil) + require.NoError(t, err) - assert.Equal(t, http.StatusOK, response.StatusCode) + response, err := httpClient.Do(request) + require.NoError(t, err) + _, err = io.Copy(io.Discard, response.Body) + require.NoError(t, err) + err = response.Body.Close() + require.NoError(t, err) + assert.Equal(t, http.StatusOK, response.StatusCode) + } err = cleanup() require.NoError(t, err) diff --git a/internal/restrictednet/resolve.go b/internal/restrictednet/resolve.go index 8a15c39a5..ab71a8e7c 100644 --- a/internal/restrictednet/resolve.go +++ b/internal/restrictednet/resolve.go @@ -9,6 +9,7 @@ import ( "net/http" "net/netip" "net/url" + "strconv" "github.com/miekg/dns" ) @@ -76,8 +77,16 @@ func (c *Client) resolveOneQuestionType(ctx context.Context, dohServerIPs = append(dohServerIPs, dohServer.IPv4...) for _, dohServerIP := range dohServerIPs { - const defaultDoHPort = 443 - dohServerAddrPort := netip.AddrPortFrom(dohServerIP, defaultDoHPort) + const defaultDoHPort uint16 = 443 + port := defaultDoHPort + if portStr := dohURL.Port(); portStr != "" { + port, err = parseDestinationPort(portStr) + if err != nil { + errs = append(errs, fmt.Errorf("parsing DoH server port: %w", err)) + continue + } + } + dohServerAddrPort := netip.AddrPortFrom(dohServerIP, port) responseMessage, err := c.doHQuery(ctx, queryWire, dohURL, dohServerAddrPort) switch { case err != nil: @@ -178,3 +187,19 @@ func answersToNetipAddrs(message *dns.Msg) (addresses []netip.Addr) { } return addresses } + +func parseDestinationPort(portStr string) (port uint16, err error) { + portUint, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return 0, err + } + + const maxPortUint = 65535 + switch { + case portUint == 0: + return 0, errors.New("port cannot be 0") + case portUint > maxPortUint: + return 0, fmt.Errorf("port cannot be greater than %d", maxPortUint) + } + return uint16(portUint), nil +} diff --git a/internal/restrictednet/resolve_test.go b/internal/restrictednet/resolve_test.go index 3ef4b8476..0fe602c23 100644 --- a/internal/restrictednet/resolve_test.go +++ b/internal/restrictednet/resolve_test.go @@ -1,3 +1,5 @@ +//go:build integration + package restrictednet import ( diff --git a/internal/restrictednet/unix.go b/internal/restrictednet/unix.go index 76895943e..968f8d309 100644 --- a/internal/restrictednet/unix.go +++ b/internal/restrictednet/unix.go @@ -1,4 +1,4 @@ -//go:build unix +//go:build !windows package restrictednet From 9af6aaff27669a24f9500f787d5cd80813e898f7 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Thu, 11 Jun 2026 01:17:55 +0000 Subject: [PATCH 15/18] PR feedback --- .github/workflows/ci.yml | 2 +- internal/restrictednet/client.go | 2 ++ internal/restrictednet/helpers_test.go | 2 ++ internal/restrictednet/https.go | 1 - .../restrictednet/{https_test.go => https_integration_test.go} | 0 .../{resolve_test.go => resolve_integration_test.go} | 0 6 files changed, 5 insertions(+), 2 deletions(-) rename internal/restrictednet/{https_test.go => https_integration_test.go} (100%) rename internal/restrictednet/{resolve_test.go => resolve_integration_test.go} (100%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index df57e3970..ed72f7515 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -65,7 +65,7 @@ jobs: - name: Run integration tests in test container run: | - docker run --rm --entrypoint "go test -tags=integration ./internal/restrictednet" test-container + docker run --rm --entrypoint go test-container test -tags=integration ./internal/restrictednet - name: Verify dev cross platform compatibility run: docker build --target xcompile . diff --git a/internal/restrictednet/client.go b/internal/restrictednet/client.go index 6b355f346..040557583 100644 --- a/internal/restrictednet/client.go +++ b/internal/restrictednet/client.go @@ -62,6 +62,8 @@ func (c *Client) OpenHTTPSByHostname(ctx context.Context, hostname string) ( portUint, err := strconv.ParseUint(portStr, 10, 16) if err != nil { return nil, nil, fmt.Errorf("parsing port: %w", err) + } else if portUint == 0 { + return nil, nil, errors.New("destination port cannot be 0") } port := uint16(portUint) diff --git a/internal/restrictednet/helpers_test.go b/internal/restrictednet/helpers_test.go index 54070c324..091aefd1c 100644 --- a/internal/restrictednet/helpers_test.go +++ b/internal/restrictednet/helpers_test.go @@ -1,3 +1,5 @@ +//go:build integration + package restrictednet func ptrTo[T any](value T) *T { diff --git a/internal/restrictednet/https.go b/internal/restrictednet/https.go index 10def5c61..209b5a9f8 100644 --- a/internal/restrictednet/https.go +++ b/internal/restrictednet/https.go @@ -73,7 +73,6 @@ func newHTTPSClient(destinationTLSName string, dial dialFunc) *http.Client { MaxIdleConns: 1, MaxIdleConnsPerHost: 1, MaxConnsPerHost: 1, - IdleConnTimeout: time.Second, TLSClientConfig: &tls.Config{ MinVersion: tls.VersionTLS12, ServerName: destinationTLSName, diff --git a/internal/restrictednet/https_test.go b/internal/restrictednet/https_integration_test.go similarity index 100% rename from internal/restrictednet/https_test.go rename to internal/restrictednet/https_integration_test.go diff --git a/internal/restrictednet/resolve_test.go b/internal/restrictednet/resolve_integration_test.go similarity index 100% rename from internal/restrictednet/resolve_test.go rename to internal/restrictednet/resolve_integration_test.go From 70d80f7473f66a5b04ec7e88bf590267d271ecd5 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Thu, 11 Jun 2026 13:06:05 +0000 Subject: [PATCH 16/18] context aware connectFD --- internal/restrictednet/https.go | 19 ++--------- internal/restrictednet/unix.go | 52 +++++++++++++++++++++++++++++-- internal/restrictednet/windows.go | 3 +- 3 files changed, 54 insertions(+), 20 deletions(-) diff --git a/internal/restrictednet/https.go b/internal/restrictednet/https.go index 209b5a9f8..1ad6d891b 100644 --- a/internal/restrictednet/https.go +++ b/internal/restrictednet/https.go @@ -144,24 +144,9 @@ func bindSourceConnection(destinationIP netip.Addr) (fd int, sourceAddr netip.Ad func connectSourceConnection(ctx context.Context, fd int, destinationAddrPort netip.AddrPort) ( connection net.Conn, err error, ) { - errCh := make(chan error) - go func() { - errCh <- connectFD(fd, destinationAddrPort) - }() - - select { - case err = <-errCh: - if err != nil { - closeFD(fd) - return nil, fmt.Errorf("connecting socket: %w", err) - } - case <-ctx.Done(): - err = ctx.Err() + err = connectFD(ctx, fd, destinationAddrPort) + if err != nil { closeFD(fd) - connectErr := <-errCh - if connectErr != nil { - err = fmt.Errorf("%w (%w)", connectErr, err) - } return nil, fmt.Errorf("connecting socket: %w", err) } diff --git a/internal/restrictednet/unix.go b/internal/restrictednet/unix.go index 968f8d309..387233cc7 100644 --- a/internal/restrictednet/unix.go +++ b/internal/restrictednet/unix.go @@ -3,8 +3,11 @@ package restrictednet import ( + "context" + "errors" "fmt" "net/netip" + "time" "golang.org/x/sys/unix" ) @@ -22,8 +25,53 @@ func bindFD(fd int, address netip.AddrPort) error { return unix.Bind(fd, bindAddr) } -func connectFD(fd int, destination netip.AddrPort) error { - return unix.Connect(fd, makeSockAddr(destination)) +func connectFD(ctx context.Context, fd int, destination netip.AddrPort) error { + err := unix.Connect(fd, makeSockAddr(destination)) + switch { + case err == nil: + return nil + case !errors.Is(err, unix.EINPROGRESS): + return err + } + + for { + select { + case <-ctx.Done(): + err = unix.Close(fd) + if err != nil { + return fmt.Errorf("error closing fd: %w (%w)", err, ctx.Err()) + } + return ctx.Err() + default: + wset := &unix.FdSet{} + wset.Bits[fd/64] |= 1 << (uint(fd) % 64) + eset := &unix.FdSet{} + eset.Bits[fd/64] |= 1 << (uint(fd) % 64) + const selectTimeout = 50 * time.Millisecond + timeval := unix.NsecToTimeval(int64(selectTimeout)) + + // Wait for the FD to become writable or hit an error state + n, err := unix.Select(fd+1, nil, wset, eset, &timeval) + if err != nil { + if errors.Is(err, unix.EINTR) { + continue // Syscall interrupted, try again + } + return fmt.Errorf("select error: %w", err) + } else if n == 0 { + continue // no status change yet + } + + // Check if the socket encountered an error + n, err = unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_ERROR) + if err != nil { + return fmt.Errorf("getsockopt error: %w", err) + } else if n != 0 { + return fmt.Errorf("connect failed asynchronously: %w", unix.Errno(n)) + } + + return nil + } + } } func fdToSourceAddr(fd int) (sourceAddrPort netip.AddrPort, err error) { diff --git a/internal/restrictednet/windows.go b/internal/restrictednet/windows.go index e1b88453a..454fc2c62 100644 --- a/internal/restrictednet/windows.go +++ b/internal/restrictednet/windows.go @@ -3,6 +3,7 @@ package restrictednet import ( + "context" "net/netip" ) @@ -18,7 +19,7 @@ func bindFD(fd int, address netip.AddrPort) error { panic("not implemented") } -func connectFD(fd int, destination netip.AddrPort) error { +func connectFD(ctx context.Context, fd int, destination netip.AddrPort) error { panic("not implemented") } From b44c6712179f975ae75fda3785b9c218a6fd2cb8 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Thu, 11 Jun 2026 13:36:08 +0000 Subject: [PATCH 17/18] lint fix --- internal/restrictednet/unix.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/restrictednet/unix.go b/internal/restrictednet/unix.go index 387233cc7..d91ea688e 100644 --- a/internal/restrictednet/unix.go +++ b/internal/restrictednet/unix.go @@ -44,9 +44,9 @@ func connectFD(ctx context.Context, fd int, destination netip.AddrPort) error { return ctx.Err() default: wset := &unix.FdSet{} - wset.Bits[fd/64] |= 1 << (uint(fd) % 64) + wset.Bits[fd/64] |= 1 << (uint64(fd) % 64) //nolint:gosec,mnd eset := &unix.FdSet{} - eset.Bits[fd/64] |= 1 << (uint(fd) % 64) + eset.Bits[fd/64] |= 1 << (uint64(fd) % 64) //nolint:gosec,mnd const selectTimeout = 50 * time.Millisecond timeval := unix.NsecToTimeval(int64(selectTimeout)) From 08dfd733678dcec8a97d5641b67a3856a707ca55 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Thu, 11 Jun 2026 14:01:05 +0000 Subject: [PATCH 18/18] pr review feedback --- internal/restrictednet/unix.go | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/internal/restrictednet/unix.go b/internal/restrictednet/unix.go index d91ea688e..bb52bd310 100644 --- a/internal/restrictednet/unix.go +++ b/internal/restrictednet/unix.go @@ -17,7 +17,16 @@ func closeFD(fd int) { } func newTCPSockStream(family int) (fd int, err error) { - return unix.Socket(family, unix.SOCK_STREAM, unix.IPPROTO_TCP) + fd, err = unix.Socket(family, unix.SOCK_STREAM, unix.IPPROTO_TCP) + if err != nil { + return 0, err + } + err = unix.SetNonblock(fd, true) + if err != nil { + _ = unix.Close(fd) + return 0, err + } + return fd, nil } func bindFD(fd int, address netip.AddrPort) error { @@ -37,16 +46,16 @@ func connectFD(ctx context.Context, fd int, destination netip.AddrPort) error { for { select { case <-ctx.Done(): - err = unix.Close(fd) - if err != nil { - return fmt.Errorf("error closing fd: %w (%w)", err, ctx.Err()) - } return ctx.Err() default: + bitsIndex := fd / 64 //nolint:mnd + if bitsIndex >= len(unix.FdSet{}.Bits) { + return fmt.Errorf("fd %d exceeds unix.Select FdSet capacity", fd) + } wset := &unix.FdSet{} - wset.Bits[fd/64] |= 1 << (uint64(fd) % 64) //nolint:gosec,mnd + wset.Bits[bitsIndex] |= 1 << (uint64(fd) % 64) //nolint:gosec,mnd eset := &unix.FdSet{} - eset.Bits[fd/64] |= 1 << (uint64(fd) % 64) //nolint:gosec,mnd + eset.Bits[bitsIndex] |= 1 << (uint64(fd) % 64) //nolint:gosec,mnd const selectTimeout = 50 * time.Millisecond timeval := unix.NsecToTimeval(int64(selectTimeout))