Skip to content

Commit 1cda069

Browse files
committed
proxy: add skipping invoice creation on request
1 parent 2821ba3 commit 1cda069

File tree

2 files changed

+88
-29
lines changed

2 files changed

+88
-29
lines changed

proxy/proxy.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
173173
// Determine auth level required to access service and dispatch request
174174
// accordingly.
175175
authLevel := target.AuthRequired(r)
176+
skipInvoiceCreation := target.SkipInvoiceCreation(r)
176177
switch {
177178
case authLevel.IsOn():
178179
// Determine if the header contains the authentication
@@ -182,6 +183,14 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
182183
// resources.
183184
acceptAuth := p.authenticator.Accept(&r.Header, resourceName)
184185
if !acceptAuth {
186+
if skipInvoiceCreation {
187+
addCorsHeaders(w.Header())
188+
sendDirectResponse(w, r, http.StatusUnauthorized,
189+
"unauthorized")
190+
191+
return
192+
}
193+
185194
price, err := target.pricer.GetPrice(r.Context(), r)
186195
if err != nil {
187196
prefixLog.Errorf("error getting "+

proxy/proxy_test.go

Lines changed: 79 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,11 @@ var (
4848
)
4949

5050
type testCase struct {
51-
name string
52-
auth auth.Level
53-
authWhitelist []string
54-
wantBackendErr bool
51+
name string
52+
auth auth.Level
53+
authWhitelist []string
54+
authSkipInvoiceCreationPaths []string
55+
wantBackendErr bool
5556
}
5657

5758
// helloServer is a simple server that implements the GreeterServer interface.
@@ -98,6 +99,15 @@ func TestProxyHTTP(t *testing.T) {
9899
name: "with whitelist",
99100
auth: "on",
100101
authWhitelist: []string{"^/http/white.*$"},
102+
}, {
103+
name: "no whitelist with skip",
104+
auth: "on",
105+
authSkipInvoiceCreationPaths: []string{"^/http/skip.*$"},
106+
}, {
107+
name: "with whitelist with skip",
108+
auth: "on",
109+
authWhitelist: []string{"^/http/white.*$"},
110+
authSkipInvoiceCreationPaths: []string{"^/http/skip.*$"},
101111
}}
102112

103113
for _, tc := range testCases {
@@ -182,12 +192,13 @@ func TestProxyHTTPBlocklist(t *testing.T) {
182192
func runHTTPTest(t *testing.T, tc *testCase, method string) {
183193
// Create a list of services to proxy between.
184194
services := []*proxy.Service{{
185-
Address: testTargetServiceAddress,
186-
HostRegexp: testHostRegexp,
187-
PathRegexp: testPathRegexpHTTP,
188-
Protocol: "http",
189-
Auth: tc.auth,
190-
AuthWhitelistPaths: tc.authWhitelist,
195+
Address: testTargetServiceAddress,
196+
HostRegexp: testHostRegexp,
197+
PathRegexp: testPathRegexpHTTP,
198+
Protocol: "http",
199+
Auth: tc.auth,
200+
AuthWhitelistPaths: tc.authWhitelist,
201+
AuthSkipInvoiceCreationPaths: tc.authSkipInvoiceCreationPaths,
191202
}}
192203

193204
mockAuth := auth.NewMockAuthenticator()
@@ -261,8 +272,33 @@ func runHTTPTest(t *testing.T, tc *testCase, method string) {
261272
require.EqualValues(t, len(bodyBytes), resp.ContentLength)
262273
}
263274

275+
// Make sure that if we query a URL that is on the skip invoice
276+
// creation list, we get a 401 if auth fails.
277+
if len(tc.authSkipInvoiceCreationPaths) > 0 {
278+
urlToSkip := fmt.Sprintf("http://%s/http/skip", testProxyAddr)
279+
reqToSkip, err := http.NewRequest(method, urlToSkip, nil)
280+
require.NoError(t, err)
281+
282+
if method == "POST" {
283+
reqToSkip.Header.Add("Content-Type", "application/json")
284+
reqToSkip.Body = io.NopCloser(strings.NewReader(`{}`))
285+
}
286+
287+
respSkipped, err := client.Do(reqToSkip)
288+
require.NoError(t, err)
289+
290+
require.Equal(t, http.StatusUnauthorized, respSkipped.StatusCode)
291+
require.Equal(t, "401 Unauthorized", respSkipped.Status)
292+
293+
bodySkippedContent, err := io.ReadAll(respSkipped.Body)
294+
require.NoError(t, err)
295+
require.Equal(t, "unauthorized\n", string(bodySkippedContent))
296+
require.EqualValues(t, len(bodySkippedContent), respSkipped.ContentLength)
297+
_ = respSkipped.Body.Close()
298+
}
299+
264300
// Make sure that if the Auth header is set, the client's request is
265-
// proxied to the backend service.
301+
// proxied to the backend service for a non-skipped, non-whitelisted path.
266302
req, err = http.NewRequest(method, url, nil)
267303
require.NoError(t, err)
268304
req.Header.Add("Authorization", "foobar")
@@ -297,6 +333,12 @@ func TestProxyGRPC(t *testing.T) {
297333
authWhitelist: []string{
298334
"^/proxy_test\\.Greeter/SayHelloNoAuth.*$",
299335
},
336+
}, {
337+
name: "gRPC no whitelist with skip for SayHello",
338+
auth: "on",
339+
authSkipInvoiceCreationPaths: []string{
340+
`^/proxy_test[.]Greeter/SayHello.*$`,
341+
},
300342
}}
301343

302344
for _, tc := range testCases {
@@ -343,13 +385,14 @@ func runGRPCTest(t *testing.T, tc *testCase) {
343385

344386
// Create a list of services to proxy between.
345387
services := []*proxy.Service{{
346-
Address: testTargetServiceAddress,
347-
HostRegexp: testHostRegexp,
348-
PathRegexp: testPathRegexpGRPC,
349-
Protocol: "https",
350-
TLSCertPath: certFile,
351-
Auth: tc.auth,
352-
AuthWhitelistPaths: tc.authWhitelist,
388+
Address: testTargetServiceAddress,
389+
HostRegexp: testHostRegexp,
390+
PathRegexp: testPathRegexpGRPC,
391+
Protocol: "https",
392+
TLSCertPath: certFile,
393+
Auth: tc.auth,
394+
AuthWhitelistPaths: tc.authWhitelist,
395+
AuthSkipInvoiceCreationPaths: tc.authSkipInvoiceCreationPaths,
353396
}}
354397

355398
// Create the proxy server and start serving on TLS.
@@ -393,17 +436,24 @@ func runGRPCTest(t *testing.T, tc *testCase) {
393436
grpc.Trailer(&captureMetadata),
394437
)
395438
require.Error(t, err)
396-
require.True(t, l402.IsPaymentRequired(err))
397-
398-
// We expect the WWW-Authenticate header field to be set to an L402
399-
// auth response.
400-
expectedHeaderContent, _ := mockAuth.FreshChallengeHeader("", 0)
401-
capturedHeader := captureMetadata.Get("WWW-Authenticate")
402-
require.Len(t, capturedHeader, 2)
403-
require.Equal(
404-
t, expectedHeaderContent.Values("WWW-Authenticate"),
405-
capturedHeader,
406-
)
439+
if len(tc.authSkipInvoiceCreationPaths) > 0 {
440+
statusErr, ok := status.FromError(err)
441+
require.True(t, ok)
442+
require.Equal(t, codes.Internal, statusErr.Code())
443+
require.Equal(t, "unauthorized", statusErr.Message())
444+
} else {
445+
require.True(t, l402.IsPaymentRequired(err))
446+
447+
// We expect the WWW-Authenticate header field to be set to an L402
448+
// auth response.
449+
expectedHeaderContent, _ := mockAuth.FreshChallengeHeader("", 0)
450+
capturedHeader := captureMetadata.Get("WWW-Authenticate")
451+
require.Len(t, capturedHeader, 2)
452+
require.Equal(
453+
t, expectedHeaderContent.Values("WWW-Authenticate"),
454+
capturedHeader,
455+
)
456+
}
407457

408458
// Make sure that if we query an URL that is on the whitelist, we don't
409459
// get the 402 response.

0 commit comments

Comments
 (0)