Skip to content

Commit e4782db

Browse files
alnrory-bot
authored andcommitted
fix: don't leak info from SSRF protection
GitOrigin-RevId: 49d2f068100fc87a061089e595dace8a1b6519a0
1 parent 52b5efb commit e4782db

File tree

4 files changed

+65
-142
lines changed

4 files changed

+65
-142
lines changed

driver/registry_sql_test.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ import (
2323

2424
"github.com/ory/hydra/v2/client"
2525
"github.com/ory/hydra/v2/driver/config"
26+
"github.com/ory/hydra/v2/fosite"
2627
"github.com/ory/hydra/v2/persistence/sql"
2728
"github.com/ory/x/configx"
2829
"github.com/ory/x/dbal"
29-
"github.com/ory/x/httpx"
3030
"github.com/ory/x/logrusx"
3131
"github.com/ory/x/popx"
3232
"github.com/ory/x/randx"
@@ -39,6 +39,12 @@ func init() {
3939
func TestGetJWKSFetcherStrategyHostEnforcement(t *testing.T) {
4040
t.Parallel()
4141

42+
ts := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
43+
t.Fatal("Should not be called")
44+
writer.WriteHeader(http.StatusOK)
45+
}))
46+
t.Cleanup(ts.Close)
47+
4248
r, err := New(t.Context(), WithAutoMigrate(), WithConfigOptions(
4349
configx.WithValues(map[string]any{
4450
config.KeyDSN: dbal.NewSQLiteTestDatabase(t),
@@ -49,8 +55,10 @@ func TestGetJWKSFetcherStrategyHostEnforcement(t *testing.T) {
4955
))
5056
require.NoError(t, err)
5157

52-
_, err = r.OAuth2Config().GetJWKSFetcherStrategy(t.Context()).Resolve(t.Context(), "http://localhost:8080", true)
53-
require.ErrorAs(t, err, new(httpx.ErrPrivateIPAddressDisallowed))
58+
_, err = r.OAuth2Config().GetJWKSFetcherStrategy(t.Context()).Resolve(t.Context(), ts.URL, true)
59+
rfcErr, ok := errors.AsType[*fosite.RFC6749Error](err)
60+
require.True(t, ok, "expected a fosite.RFC6749Error, got %T: %+v", err, err)
61+
require.Contains(t, rfcErr.DebugField, "no route to host")
5462
}
5563

5664
func TestRegistrySQL_newKeyStrategy_handlesNetworkError(t *testing.T) {
@@ -133,7 +141,7 @@ func TestRegistrySQL_HTTPClient(t *testing.T) {
133141

134142
t.Run("case=does not match exception glob", func(t *testing.T) {
135143
_, err := r.HTTPClient(t.Context()).Get(ts.URL + "/foo")
136-
assert.ErrorContains(t, err, "prohibited IP address")
144+
assert.Error(t, err)
137145
})
138146
}
139147

oryx/httpx/private_ip_validator.go

Lines changed: 0 additions & 85 deletions
This file was deleted.

oryx/httpx/ssrf.go

Lines changed: 50 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package httpx
55

66
import (
77
"context"
8+
"errors"
89
"net"
910
"net/http"
1011
"net/http/httptrace"
@@ -43,33 +44,12 @@ func (n noInternalIPRoundTripper) RoundTrip(request *http.Request) (*http.Respon
4344
}
4445

4546
var (
46-
prohibitInternalAllowIPv6 http.RoundTripper
47-
allowInternalAllowIPv6 http.RoundTripper
48-
)
49-
50-
func init() {
51-
t, d := newDefaultTransport()
52-
d.Control = ssrf.New(
47+
prohibitInternalAllowIPv6 http.RoundTripper = OTELTraceTransport(ssrfTransport(
5348
ssrf.WithAnyPort(),
5449
ssrf.WithNetworks("tcp4", "tcp6"),
55-
).Safe
56-
prohibitInternalAllowIPv6 = OTELTraceTransport(t)
57-
}
50+
))
5851

59-
func init() {
60-
t, d := newDefaultTransport()
61-
d.Control = ssrf.New(
62-
ssrf.WithAnyPort(),
63-
ssrf.WithNetworks("tcp4"),
64-
).Safe
65-
t.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
66-
return d.DialContext(ctx, "tcp4", addr)
67-
}
68-
}
69-
70-
func init() {
71-
t, d := newDefaultTransport()
72-
d.Control = ssrf.New(
52+
allowInternalAllowIPv6 http.RoundTripper = OTELTraceTransport(ssrfTransport(
7353
ssrf.WithAnyPort(),
7454
ssrf.WithNetworks("tcp4", "tcp6"),
7555
ssrf.WithAllowedV4Prefixes(
@@ -83,46 +63,64 @@ func init() {
8363
netip.MustParsePrefix("::1/128"), // Loopback (RFC 4193)
8464
netip.MustParsePrefix("fc00::/7"), // Unique Local (RFC 4193)
8565
),
86-
).Safe
87-
allowInternalAllowIPv6 = OTELTraceTransport(t)
88-
}
89-
90-
func init() {
91-
t, d := newDefaultTransport()
92-
d.Control = ssrf.New(
93-
ssrf.WithAnyPort(),
94-
ssrf.WithNetworks("tcp4"),
95-
ssrf.WithAllowedV4Prefixes(
96-
netip.MustParsePrefix("10.0.0.0/8"), // Private-Use (RFC 1918)
97-
netip.MustParsePrefix("127.0.0.0/8"), // Loopback (RFC 1122, Section 3.2.1.3))
98-
netip.MustParsePrefix("169.254.0.0/16"), // Link Local (RFC 3927)
99-
netip.MustParsePrefix("172.16.0.0/12"), // Private-Use (RFC 1918)
100-
netip.MustParsePrefix("192.168.0.0/16"), // Private-Use (RFC 1918)
101-
),
102-
ssrf.WithAllowedV6Prefixes(
103-
netip.MustParsePrefix("::1/128"), // Loopback (RFC 4193)
104-
netip.MustParsePrefix("fc00::/7"), // Unique Local (RFC 4193)
105-
),
106-
).Safe
107-
t.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
108-
return d.DialContext(ctx, "tcp4", addr)
109-
}
110-
}
66+
))
67+
)
11168

112-
func newDefaultTransport() (*http.Transport, *net.Dialer) {
69+
func ssrfTransport(opt ...ssrf.Option) *http.Transport {
11370
dialer := net.Dialer{
11471
Timeout: 30 * time.Second,
11572
KeepAlive: 30 * time.Second,
11673
}
74+
dialer.Control = ssrf.New(opt...).Safe
75+
dial := func(ctx context.Context, network string, address string) (net.Conn, error) {
76+
c, err := dialer.DialContext(ctx, network, address)
77+
if err == nil {
78+
return c, nil
79+
}
80+
81+
if dnsErr, ok := errors.AsType[*net.DNSError](err); ok {
82+
dnsErr.Server = "" // mask our DNS server's IP address
83+
return nil, err
84+
}
85+
86+
if !errors.Is(err, ssrf.ErrProhibitedIP) {
87+
return nil, err
88+
}
89+
90+
host, _, _ := net.SplitHostPort(address)
91+
_, addrErr := netip.ParseAddrPort(address)
92+
if addrErr != nil {
93+
// We were given a DNS name: the error we return must look like a DNS error.
94+
return nil, &net.OpError{
95+
Op: "dial",
96+
Net: network,
97+
Addr: nil,
98+
Err: &net.DNSError{
99+
Err: "no such host",
100+
Name: host,
101+
Server: "",
102+
IsTimeout: false,
103+
IsTemporary: false,
104+
IsNotFound: true,
105+
},
106+
}
107+
}
108+
return nil, &net.OpError{
109+
Op: "dial",
110+
Net: network,
111+
Addr: nil,
112+
Err: errors.New("no route to host"),
113+
}
114+
}
117115
return &http.Transport{
118116
Proxy: http.ProxyFromEnvironment,
119-
DialContext: dialer.DialContext,
117+
DialContext: dial,
120118
ForceAttemptHTTP2: true,
121119
MaxIdleConns: 100,
122120
IdleConnTimeout: 90 * time.Second,
123121
TLSHandshakeTimeout: 10 * time.Second,
124122
ExpectContinueTimeout: 1 * time.Second,
125-
}, &dialer
123+
}
126124
}
127125

128126
// OTELTraceTransport wraps the given http.Transport with OpenTelemetry instrumentation.

oryx/httpx/wait_for.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ func WaitForEndpointWithClient(ctx context.Context, client *http.Client, endpoin
3131
if err != nil {
3232
return err
3333
}
34-
defer res.Body.Close()
34+
defer func() {
35+
_ = res.Body.Close()
36+
}()
3537

3638
body, err := io.ReadAll(res.Body)
3739
if err != nil {

0 commit comments

Comments
 (0)