diff --git a/proxy/proxy.go b/proxy/proxy.go index 0882b62..7dfb5d4 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -134,6 +134,12 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + // If the request is a gRPC request, we need to set the Content-Type + // header to application/grpc. + if strings.HasPrefix(r.Header.Get(hdrContentType), hdrTypeGrpc) { + w.Header().Set(hdrContentType, hdrTypeGrpc) + } + // Requests that can't be matched to a service backend will be // dispatched to the static file server. If the file exists in the // static file folder it will be served, otherwise the static server @@ -166,6 +172,7 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Determine auth level required to access service and dispatch request // accordingly. authLevel := target.AuthRequired(r) + skipInvoiceCreation := target.SkipInvoiceCreation(r) switch { case authLevel.IsOn(): // Determine if the header contains the authentication @@ -175,6 +182,16 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { // resources. acceptAuth := p.authenticator.Accept(&r.Header, resourceName) if !acceptAuth { + if skipInvoiceCreation { + addCorsHeaders(w.Header()) + sendDirectResponse( + w, r, http.StatusUnauthorized, + "unauthorized", + ) + + return + } + price, err := target.pricer.GetPrice(r.Context(), r) if err != nil { prefixLog.Errorf("error getting "+ diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 64a2372..48ab200 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -48,10 +48,11 @@ var ( ) type testCase struct { - name string - auth auth.Level - authWhitelist []string - wantBackendErr bool + name string + auth auth.Level + authWhitelist []string + authSkipInvoiceCreationPaths []string + wantBackendErr bool } // helloServer is a simple server that implements the GreeterServer interface. @@ -98,6 +99,15 @@ func TestProxyHTTP(t *testing.T) { name: "with whitelist", auth: "on", authWhitelist: []string{"^/http/white.*$"}, + }, { + name: "no whitelist with skip", + auth: "on", + authSkipInvoiceCreationPaths: []string{"^/http/skip.*$"}, + }, { + name: "with whitelist with skip", + auth: "on", + authWhitelist: []string{"^/http/white.*$"}, + authSkipInvoiceCreationPaths: []string{"^/http/skip.*$"}, }} for _, tc := range testCases { @@ -182,12 +192,13 @@ func TestProxyHTTPBlocklist(t *testing.T) { func runHTTPTest(t *testing.T, tc *testCase, method string) { // Create a list of services to proxy between. services := []*proxy.Service{{ - Address: testTargetServiceAddress, - HostRegexp: testHostRegexp, - PathRegexp: testPathRegexpHTTP, - Protocol: "http", - Auth: tc.auth, - AuthWhitelistPaths: tc.authWhitelist, + Address: testTargetServiceAddress, + HostRegexp: testHostRegexp, + PathRegexp: testPathRegexpHTTP, + Protocol: "http", + Auth: tc.auth, + AuthWhitelistPaths: tc.authWhitelist, + AuthSkipInvoiceCreationPaths: tc.authSkipInvoiceCreationPaths, }} mockAuth := auth.NewMockAuthenticator() @@ -261,8 +272,33 @@ func runHTTPTest(t *testing.T, tc *testCase, method string) { require.EqualValues(t, len(bodyBytes), resp.ContentLength) } + // Make sure that if we query a URL that is on the skip invoice + // creation list, we get a 401 if auth fails. + if len(tc.authSkipInvoiceCreationPaths) > 0 { + urlToSkip := fmt.Sprintf("http://%s/http/skip", testProxyAddr) + reqToSkip, err := http.NewRequest(method, urlToSkip, nil) + require.NoError(t, err) + + if method == "POST" { + reqToSkip.Header.Add("Content-Type", "application/json") + reqToSkip.Body = io.NopCloser(strings.NewReader(`{}`)) + } + + respSkipped, err := client.Do(reqToSkip) + require.NoError(t, err) + + require.Equal(t, http.StatusUnauthorized, respSkipped.StatusCode) + require.Equal(t, "401 Unauthorized", respSkipped.Status) + + bodySkippedContent, err := io.ReadAll(respSkipped.Body) + require.NoError(t, err) + require.Equal(t, "unauthorized\n", string(bodySkippedContent)) + require.EqualValues(t, len(bodySkippedContent), respSkipped.ContentLength) + _ = respSkipped.Body.Close() + } + // Make sure that if the Auth header is set, the client's request is - // proxied to the backend service. + // proxied to the backend service for a non-skipped, non-whitelisted path. req, err = http.NewRequest(method, url, nil) require.NoError(t, err) req.Header.Add("Authorization", "foobar") @@ -297,6 +333,12 @@ func TestProxyGRPC(t *testing.T) { authWhitelist: []string{ "^/proxy_test\\.Greeter/SayHelloNoAuth.*$", }, + }, { + name: "gRPC no whitelist with skip for SayHello", + auth: "on", + authSkipInvoiceCreationPaths: []string{ + `^/proxy_test[.]Greeter/SayHello.*$`, + }, }} for _, tc := range testCases { @@ -343,13 +385,14 @@ func runGRPCTest(t *testing.T, tc *testCase) { // Create a list of services to proxy between. services := []*proxy.Service{{ - Address: testTargetServiceAddress, - HostRegexp: testHostRegexp, - PathRegexp: testPathRegexpGRPC, - Protocol: "https", - TLSCertPath: certFile, - Auth: tc.auth, - AuthWhitelistPaths: tc.authWhitelist, + Address: testTargetServiceAddress, + HostRegexp: testHostRegexp, + PathRegexp: testPathRegexpGRPC, + Protocol: "https", + TLSCertPath: certFile, + Auth: tc.auth, + AuthWhitelistPaths: tc.authWhitelist, + AuthSkipInvoiceCreationPaths: tc.authSkipInvoiceCreationPaths, }} // Create the proxy server and start serving on TLS. @@ -393,17 +436,24 @@ func runGRPCTest(t *testing.T, tc *testCase) { grpc.Trailer(&captureMetadata), ) require.Error(t, err) - require.True(t, l402.IsPaymentRequired(err)) - - // We expect the WWW-Authenticate header field to be set to an L402 - // auth response. - expectedHeaderContent, _ := mockAuth.FreshChallengeHeader("", 0) - capturedHeader := captureMetadata.Get("WWW-Authenticate") - require.Len(t, capturedHeader, 2) - require.Equal( - t, expectedHeaderContent.Values("WWW-Authenticate"), - capturedHeader, - ) + if len(tc.authSkipInvoiceCreationPaths) > 0 { + statusErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, codes.Internal, statusErr.Code()) + require.Equal(t, "unauthorized", statusErr.Message()) + } else { + require.True(t, l402.IsPaymentRequired(err)) + + // We expect the WWW-Authenticate header field to be set to an L402 + // auth response. + expectedHeaderContent, _ := mockAuth.FreshChallengeHeader("", 0) + capturedHeader := captureMetadata.Get("WWW-Authenticate") + require.Len(t, capturedHeader, 2) + require.Equal( + t, expectedHeaderContent.Values("WWW-Authenticate"), + capturedHeader, + ) + } // Make sure that if we query an URL that is on the whitelist, we don't // get the 402 response. diff --git a/proxy/service.go b/proxy/service.go index 0248e88..d7b9626 100644 --- a/proxy/service.go +++ b/proxy/service.go @@ -103,6 +103,13 @@ type Service struct { // /package_name.ServiceName/MethodName AuthWhitelistPaths []string `long:"authwhitelistpaths" description:"List of regular expressions for paths that don't require authentication'"` + // AuthSkipInvoiceCreationPaths is an optional list of regular + // expressions that are matched against the path of the URL of a + // request. If the request URL matches any of those regular + // expressions, the call will not try to create an invoice for the + // request, but still try to do the l402 authentication. + AuthSkipInvoiceCreationPaths []string `long:"authskipinvoicecreationpaths" description:"List of regular expressions for paths that will skip invoice creation'"` + // compiledHostRegexp is the compiled host regex. compiledHostRegexp *regexp.Regexp @@ -112,6 +119,10 @@ type Service struct { // compiledAuthWhitelistPaths is the compiled auth whitelist paths. compiledAuthWhitelistPaths []*regexp.Regexp + // compiledAuthSkipInvoiceCreationPaths is the compiled auth skip + // invoice creation paths. + compiledAuthSkipInvoiceCreationPaths []*regexp.Regexp + freebieDB freebie.DB pricer pricer.Pricer } @@ -144,6 +155,20 @@ func (s *Service) AuthRequired(r *http.Request) auth.Level { return s.Auth } +// SkipInvoiceCreation determines if an invoice should be created for a +// given request. +func (s *Service) SkipInvoiceCreation(r *http.Request) bool { + for _, pathRegexp := range s.compiledAuthSkipInvoiceCreationPaths { + if pathRegexp.MatchString(r.URL.Path) { + log.Tracef("Req path [%s] matches skip entry "+ + "[%s].", r.URL.Path, pathRegexp) + return true + } + } + + return false +} + // prepareServices prepares the backend service configurations to be used by the // proxy. func prepareServices(services []*Service) error { @@ -195,7 +220,7 @@ func prepareServices(services []*Service) error { // Compile the host regex. compiledHostRegexp, err := regexp.Compile(service.HostRegexp) if err != nil { - return fmt.Errorf("error compiling host regex: %v", err) + return fmt.Errorf("error compiling host regex: %w", err) } service.compiledHostRegexp = compiledHostRegexp @@ -206,7 +231,7 @@ func prepareServices(services []*Service) error { ) if err != nil { return fmt.Errorf("error compiling path "+ - "regex: %v", err) + "regex: %w", err) } service.compiledPathRegexp = compiledPathRegexp } @@ -222,13 +247,34 @@ func prepareServices(services []*Service) error { regExp, err := regexp.Compile(entry) if err != nil { return fmt.Errorf("error validating auth "+ - "whitelist: %v", err) + "whitelist: %w", err) } service.compiledAuthWhitelistPaths = append( service.compiledAuthWhitelistPaths, regExp, ) } + service.compiledAuthSkipInvoiceCreationPaths = make( + []*regexp.Regexp, 0, len( + service.AuthSkipInvoiceCreationPaths, + ), + ) + + // Make sure all skip invoice creation regular expression + // entries actually compile so we run into an eventual panic + // during startup and not only when the request happens. + for _, entry := range service.AuthSkipInvoiceCreationPaths { + regExp, err := regexp.Compile(entry) + if err != nil { + return fmt.Errorf("error validating skip "+ + "invoice creation whitelist: %w", err) + } + service.compiledAuthSkipInvoiceCreationPaths = append( + service.compiledAuthSkipInvoiceCreationPaths, + regExp, + ) + } + // If dynamic prices are enabled then use the provided // DynamicPrice options to initialise a gRPC backed // pricer client. diff --git a/sample-conf.yaml b/sample-conf.yaml index 5e19b80..b5e1502 100644 --- a/sample-conf.yaml +++ b/sample-conf.yaml @@ -151,6 +151,17 @@ services: # dynamicprice.enabled is set to true. price: 0 + # A list of regular expressions for path that are free of charge. + authwhitelistpaths: + - '^/freebieservice.*$' + + # A list of regular expressions for path that will skip invoice creation, + # but still try to do the l402 authentication. This is useful for streaming + # services, as they are not supported to be the initial request to receive + # a L402. + authskipinvoicecreationpaths: + - '^/streamingservice.*$' + # Options to use for connection to the price serving gRPC server. dynamicprice: # Whether or not a gRPC server is available to query price data from. If