Skip to content

Commit 6084c3d

Browse files
Merge pull request #129 from dotindustries/feat/patch-netdial-dns-resolution-via-warpgrid-shim-bye1t93yco4s503zgcoqdqn7
feat: patch net.Dial DNS resolution via WarpGrid shim
2 parents e0cae78 + a581ed8 commit 6084c3d

File tree

10 files changed

+721
-4
lines changed

10 files changed

+721
-4
lines changed

crates/warpgrid-host/tests/integration_go_dns_dial.rs

Lines changed: 408 additions & 0 deletions
Large diffs are not rendered by default.

packages/warpgrid-go/net/dial.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ func (d *Dialer) Dial(network, address string) (net.Conn, error) {
6262
return nil, &net.OpError{
6363
Op: "dial",
6464
Net: network,
65-
Err: &net.DNSError{
65+
Err: &DNSError{
6666
Err: err.Error(),
6767
Name: host,
6868
IsNotFound: true,
@@ -74,7 +74,7 @@ func (d *Dialer) Dial(network, address string) (net.Conn, error) {
7474
return nil, &net.OpError{
7575
Op: "dial",
7676
Net: network,
77-
Err: &net.DNSError{
77+
Err: &DNSError{
7878
Err: "no addresses found",
7979
Name: host,
8080
IsNotFound: true,
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// Non-WASI fallback convenience functions for DNS-aware dialing.
2+
//
3+
// On standard Go (non-WASI), there is no WarpGrid DNS shim backend.
4+
// These functions fall through to the standard library's net.Dial so
5+
// that code importing this package compiles and works in native
6+
// development and testing environments.
7+
8+
//go:build !wasip1 && !wasip2
9+
10+
package net
11+
12+
import (
13+
"net"
14+
"time"
15+
)
16+
17+
// Dial connects to the address on the named network.
18+
//
19+
// On non-WASI targets this delegates directly to net.Dial from the
20+
// standard library since no WarpGrid DNS shim is available.
21+
func Dial(network, address string) (net.Conn, error) {
22+
return net.Dial(network, address)
23+
}
24+
25+
// DialTimeout is like Dial but with a connection timeout.
26+
//
27+
// On non-WASI targets this delegates to net.DialTimeout.
28+
func DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
29+
return net.DialTimeout(network, address, timeout)
30+
}

packages/warpgrid-go/net/dial_test.go

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,53 @@ func TestDial_DNSErrorContainsHostname(t *testing.T) {
387387
}
388388
}
389389

390+
// ── Package-level Dial() convenience function tests ─────────────────
391+
392+
func TestPackageDial_HostnameResolvedViaDNS(t *testing.T) {
393+
// Package-level Dial on non-WASI falls through to net.Dial, so it
394+
// won't use the WarpGrid DNS shim. We test that the function at
395+
// least compiles and works with an IP literal (which bypasses DNS
396+
// on all platforms).
397+
addr, cleanup := startEchoServer(t)
398+
defer cleanup()
399+
400+
conn, err := wgnet.Dial("tcp", addr)
401+
if err != nil {
402+
t.Fatalf("wgnet.Dial failed: %v", err)
403+
}
404+
defer conn.Close()
405+
406+
// Echo round-trip
407+
message := "Hello from package-level Dial!"
408+
_, err = conn.Write([]byte(message))
409+
if err != nil {
410+
t.Fatalf("Write failed: %v", err)
411+
}
412+
413+
buf := make([]byte, len(message))
414+
_, err = io.ReadFull(conn, buf)
415+
if err != nil {
416+
t.Fatalf("Read failed: %v", err)
417+
}
418+
419+
if string(buf) != message {
420+
t.Fatalf("expected %q, got %q", message, string(buf))
421+
}
422+
}
423+
424+
func TestPackageDialTimeout_RespectsTimeout(t *testing.T) {
425+
start := time.Now()
426+
_, err := wgnet.DialTimeout("tcp", "192.0.2.1:65535", 200*time.Millisecond)
427+
elapsed := time.Since(start)
428+
429+
if err == nil {
430+
t.Fatal("expected error dialing unreachable address")
431+
}
432+
if elapsed > 5*time.Second {
433+
t.Fatalf("DialTimeout not respected: took %v (expected <5s with 200ms timeout)", elapsed)
434+
}
435+
}
436+
390437
// ── ConnectTimeout tests ────────────────────────────────────────────
391438

392439
func TestDial_ConnectTimeoutIsApplied(t *testing.T) {
@@ -433,7 +480,7 @@ func TestDial_DNSFailureWrapsInnerDNSError(t *testing.T) {
433480
}
434481

435482
// The inner error should be *net.DNSError with correct fields
436-
var dnsErr *net.DNSError
483+
var dnsErr *wgnet.DNSError
437484
if !errors.As(opErr.Err, &dnsErr) {
438485
t.Fatalf("expected inner *net.DNSError, got %T: %v", opErr.Err, opErr.Err)
439486
}
@@ -462,7 +509,7 @@ func TestDial_DNSEmptyResultWrapsInnerDNSError(t *testing.T) {
462509
t.Fatalf("expected *net.OpError, got %T: %v", err, err)
463510
}
464511

465-
var dnsErr *net.DNSError
512+
var dnsErr *wgnet.DNSError
466513
if !errors.As(opErr.Err, &dnsErr) {
467514
t.Fatalf("expected inner *net.DNSError, got %T: %v", opErr.Err, opErr.Err)
468515
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// WASI-specific convenience functions for DNS-aware dialing.
2+
//
3+
// On WASI targets, DefaultDialer() returns a Dialer wired to the
4+
// WarpGrid DNS shim backend. The package-level Dial() and DialTimeout()
5+
// functions provide a drop-in API that resolves hostnames via the shim
6+
// before connecting.
7+
//
8+
// This file is only compiled when targeting WASI (wasip1 or wasip2).
9+
10+
//go:build wasip1 || wasip2
11+
12+
package net
13+
14+
import (
15+
"net"
16+
"time"
17+
18+
"github.com/anthropics/warpgrid/packages/warpgrid-go/dns"
19+
)
20+
21+
// DefaultDialer returns a Dialer configured with the WASI DNS backend.
22+
// Use this when you need to customise timeouts or other Dialer fields.
23+
func DefaultDialer() *Dialer {
24+
return NewDialer(dns.DefaultResolver())
25+
}
26+
27+
// Dial connects to the address on the named network, resolving
28+
// hostnames via the WarpGrid DNS shim. IP literals bypass DNS.
29+
//
30+
// This is the package-level convenience wrapper around DefaultDialer().Dial.
31+
func Dial(network, address string) (net.Conn, error) {
32+
return DefaultDialer().Dial(network, address)
33+
}
34+
35+
// DialTimeout is like Dial but with a per-address connection timeout.
36+
func DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
37+
d := DefaultDialer()
38+
d.ConnectTimeout = timeout
39+
return d.Dial(network, address)
40+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// DNSError polyfill for environments where net.DNSError is unavailable.
2+
//
3+
// TinyGo's net package does not include net.DNSError, so we define a
4+
// compatible type here. On standard Go, we use net.DNSError directly
5+
// (see dnserror_std.go). This file is only compiled under TinyGo.
6+
7+
//go:build tinygo
8+
9+
package net
10+
11+
// DNSError represents a DNS lookup failure, compatible with Go's
12+
// net.DNSError interface. This polyfill is used by TinyGo builds
13+
// where net.DNSError is not available.
14+
type DNSError struct {
15+
Err string
16+
Name string
17+
IsNotFound bool
18+
}
19+
20+
func (e *DNSError) Error() string {
21+
s := "lookup " + e.Name
22+
if e.Err != "" {
23+
s += ": " + e.Err
24+
}
25+
return s
26+
}
27+
28+
func (e *DNSError) Timeout() bool { return false }
29+
func (e *DNSError) Temporary() bool { return false }
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// DNSError alias for standard Go (non-TinyGo) builds.
2+
//
3+
// On standard Go, net.DNSError is available directly. We re-export it
4+
// so that dial.go can use DNSError without conditional imports.
5+
6+
//go:build !tinygo
7+
8+
package net
9+
10+
import "net"
11+
12+
// DNSError is an alias for net.DNSError on standard Go builds.
13+
type DNSError = net.DNSError
706 KB
Binary file not shown.
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
module go-dns-dial-guest
2+
3+
go 1.22.0
4+
5+
require github.com/anthropics/warpgrid/packages/warpgrid-go v0.0.0
6+
7+
replace github.com/anthropics/warpgrid/packages/warpgrid-go => ../../../packages/warpgrid-go
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
// Package main is a WASI guest fixture that exercises DNS resolution
2+
// through the WarpGrid DNS shim (warpgrid_shim dns_resolve).
3+
//
4+
// Each exported function tests a specific DNS resolution scenario via
5+
// dns.DefaultResolver(). The host integration test registers a service
6+
// registry mapping test hostnames to known IPs, then calls these exports
7+
// and validates the results.
8+
//
9+
// Build: tinygo build -target=wasi -buildmode=c-shared -o go-dns-dial-guest.wasm .
10+
//
11+
// Reactor mode (-buildmode=c-shared) is required so that //go:wasmexport
12+
// functions work after _initialize.
13+
package main
14+
15+
import (
16+
"errors"
17+
"fmt"
18+
"net"
19+
"strings"
20+
"unsafe"
21+
22+
wgdns "github.com/anthropics/warpgrid/packages/warpgrid-go/dns"
23+
wgnet "github.com/anthropics/warpgrid/packages/warpgrid-go/net"
24+
)
25+
26+
func main() {}
27+
28+
// ── Exported test functions ─────────────────────────────────────────
29+
//
30+
// Each function returns a packed uint64: high 32 bits = pointer to
31+
// result string, low 32 bits = length. The host reads the string from
32+
// linear memory. Format: "OK:<data>" on success, "ERR:<message>" on failure.
33+
34+
// testResolveRegistry resolves a hostname expected to be in the service
35+
// registry and returns the first resolved IP address.
36+
//
37+
//go:wasmexport test-resolve-registry
38+
func testResolveRegistry() uint64 {
39+
resolver := wgdns.DefaultResolver()
40+
ips, err := resolver.Resolve("echo-server.test.warp.local")
41+
if err != nil {
42+
return writeResult(fmt.Sprintf("ERR:resolve failed: %v", err))
43+
}
44+
if len(ips) == 0 {
45+
return writeResult("ERR:no addresses returned")
46+
}
47+
return writeResult(fmt.Sprintf("OK:%s", ips[0].String()))
48+
}
49+
50+
// testResolveMultiple resolves a hostname with multiple A records and
51+
// returns all addresses comma-separated.
52+
//
53+
//go:wasmexport test-resolve-multiple
54+
func testResolveMultiple() uint64 {
55+
resolver := wgdns.DefaultResolver()
56+
ips, err := resolver.Resolve("multi.test.warp.local")
57+
if err != nil {
58+
return writeResult(fmt.Sprintf("ERR:resolve failed: %v", err))
59+
}
60+
if len(ips) == 0 {
61+
return writeResult("ERR:no addresses returned")
62+
}
63+
parts := make([]string, len(ips))
64+
for i, ip := range ips {
65+
parts[i] = ip.String()
66+
}
67+
return writeResult(fmt.Sprintf("OK:%s", strings.Join(parts, ",")))
68+
}
69+
70+
// testResolveNonexistent attempts to resolve a hostname that does not
71+
// exist and returns the error message.
72+
//
73+
//go:wasmexport test-resolve-nonexistent
74+
func testResolveNonexistent() uint64 {
75+
resolver := wgdns.DefaultResolver()
76+
_, err := resolver.Resolve("nonexistent.invalid")
77+
if err == nil {
78+
return writeResult("ERR:expected error for nonexistent hostname, got nil")
79+
}
80+
return writeResult(fmt.Sprintf("OK:%s", err.Error()))
81+
}
82+
83+
// testResolveIPLiteral verifies that IP literals bypass DNS resolution
84+
// and are returned directly.
85+
//
86+
//go:wasmexport test-resolve-ip-literal
87+
func testResolveIPLiteral() uint64 {
88+
resolver := wgdns.DefaultResolver()
89+
ips, err := resolver.Resolve("192.168.1.1")
90+
if err != nil {
91+
return writeResult(fmt.Sprintf("ERR:resolve failed: %v", err))
92+
}
93+
if len(ips) != 1 {
94+
return writeResult(fmt.Sprintf("ERR:expected 1 address, got %d", len(ips)))
95+
}
96+
if ips[0].String() != "192.168.1.1" {
97+
return writeResult(fmt.Sprintf("ERR:expected 192.168.1.1, got %s", ips[0].String()))
98+
}
99+
return writeResult("OK:192.168.1.1")
100+
}
101+
102+
// testDialerDNSErrorWrapping exercises the Dialer with a hostname that
103+
// fails DNS resolution and verifies proper error wrapping.
104+
//
105+
//go:wasmexport test-dialer-dns-error
106+
func testDialerDNSErrorWrapping() uint64 {
107+
dialer := wgnet.DefaultDialer()
108+
_, err := dialer.Dial("tcp", "nonexistent.invalid:5432")
109+
if err == nil {
110+
return writeResult("ERR:expected error, got nil")
111+
}
112+
113+
// Verify *net.OpError wrapping
114+
var opErr *net.OpError
115+
if !errors.As(err, &opErr) {
116+
return writeResult(fmt.Sprintf("ERR:expected *net.OpError, got %T: %v", err, err))
117+
}
118+
if opErr.Op != "dial" {
119+
return writeResult(fmt.Sprintf("ERR:expected Op=dial, got %s", opErr.Op))
120+
}
121+
122+
// Verify inner *wgnet.DNSError
123+
var dnsErr *wgnet.DNSError
124+
if !errors.As(opErr.Err, &dnsErr) {
125+
return writeResult(fmt.Sprintf("ERR:expected inner *DNSError, got %T: %v", opErr.Err, opErr.Err))
126+
}
127+
if dnsErr.Name != "nonexistent.invalid" {
128+
return writeResult(fmt.Sprintf("ERR:DNSError.Name=%q, want nonexistent.invalid", dnsErr.Name))
129+
}
130+
131+
return writeResult("OK:correctly wrapped as *net.OpError{*DNSError}")
132+
}
133+
134+
// ── Helper: pack result string into uint64 (ptr << 32 | len) ───────
135+
136+
var resultBuf []byte
137+
138+
func writeResult(s string) uint64 {
139+
resultBuf = []byte(s)
140+
ptr := uint64(uintptr(unsafe.Pointer(&resultBuf[0])))
141+
length := uint64(len(resultBuf))
142+
return (ptr << 32) | length
143+
}

0 commit comments

Comments
 (0)