Skip to content

Add Invoice skip list #176

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

// If the request is a gRPC request, we need to set the Content-Type
// header to application/grpc.
if strings.HasPrefix(r.Header.Get(hdrContentType), hdrTypeGrpc) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, this might fix a long-standing error that we see sometimes. I just hope it doesn't introduce any side effects (I remember the whole header handling for gRPC being super brittle with many weird behaviors).

w.Header().Set(hdrContentType, hdrTypeGrpc)
}

// Requests that can't be matched to a service backend will be
// dispatched to the static file server. If the file exists in the
// static file folder it will be served, otherwise the static server
Expand Down Expand Up @@ -166,6 +172,7 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Determine auth level required to access service and dispatch request
// accordingly.
authLevel := target.AuthRequired(r)
skipInvoiceCreation := target.SkipInvoiceCreation(r)
switch {
case authLevel.IsOn():
// Determine if the header contains the authentication
Expand All @@ -175,6 +182,16 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// resources.
acceptAuth := p.authenticator.Accept(&r.Header, resourceName)
if !acceptAuth {
if skipInvoiceCreation {
addCorsHeaders(w.Header())
sendDirectResponse(
w, r, http.StatusUnauthorized,
"unauthorized",
)

return
}

price, err := target.pricer.GetPrice(r.Context(), r)
if err != nil {
prefixLog.Errorf("error getting "+
Expand Down
108 changes: 79 additions & 29 deletions proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,11 @@ var (
)

type testCase struct {
name string
auth auth.Level
authWhitelist []string
wantBackendErr bool
name string
auth auth.Level
authWhitelist []string
authSkipInvoiceCreationPaths []string
wantBackendErr bool
}

// helloServer is a simple server that implements the GreeterServer interface.
Expand Down Expand Up @@ -98,6 +99,15 @@ func TestProxyHTTP(t *testing.T) {
name: "with whitelist",
auth: "on",
authWhitelist: []string{"^/http/white.*$"},
}, {
name: "no whitelist with skip",
auth: "on",
authSkipInvoiceCreationPaths: []string{"^/http/skip.*$"},
}, {
name: "with whitelist with skip",
auth: "on",
authWhitelist: []string{"^/http/white.*$"},
authSkipInvoiceCreationPaths: []string{"^/http/skip.*$"},
}}

for _, tc := range testCases {
Expand Down Expand Up @@ -182,12 +192,13 @@ func TestProxyHTTPBlocklist(t *testing.T) {
func runHTTPTest(t *testing.T, tc *testCase, method string) {
// Create a list of services to proxy between.
services := []*proxy.Service{{
Address: testTargetServiceAddress,
HostRegexp: testHostRegexp,
PathRegexp: testPathRegexpHTTP,
Protocol: "http",
Auth: tc.auth,
AuthWhitelistPaths: tc.authWhitelist,
Address: testTargetServiceAddress,
HostRegexp: testHostRegexp,
PathRegexp: testPathRegexpHTTP,
Protocol: "http",
Auth: tc.auth,
AuthWhitelistPaths: tc.authWhitelist,
AuthSkipInvoiceCreationPaths: tc.authSkipInvoiceCreationPaths,
}}

mockAuth := auth.NewMockAuthenticator()
Expand Down Expand Up @@ -261,8 +272,33 @@ func runHTTPTest(t *testing.T, tc *testCase, method string) {
require.EqualValues(t, len(bodyBytes), resp.ContentLength)
}

// Make sure that if we query a URL that is on the skip invoice
// creation list, we get a 401 if auth fails.
if len(tc.authSkipInvoiceCreationPaths) > 0 {
urlToSkip := fmt.Sprintf("http://%s/http/skip", testProxyAddr)
reqToSkip, err := http.NewRequest(method, urlToSkip, nil)
require.NoError(t, err)

if method == "POST" {
reqToSkip.Header.Add("Content-Type", "application/json")
reqToSkip.Body = io.NopCloser(strings.NewReader(`{}`))
}

respSkipped, err := client.Do(reqToSkip)
require.NoError(t, err)

require.Equal(t, http.StatusUnauthorized, respSkipped.StatusCode)
require.Equal(t, "401 Unauthorized", respSkipped.Status)

bodySkippedContent, err := io.ReadAll(respSkipped.Body)
require.NoError(t, err)
require.Equal(t, "unauthorized\n", string(bodySkippedContent))
require.EqualValues(t, len(bodySkippedContent), respSkipped.ContentLength)
_ = respSkipped.Body.Close()
}

// Make sure that if the Auth header is set, the client's request is
// proxied to the backend service.
// proxied to the backend service for a non-skipped, non-whitelisted path.
req, err = http.NewRequest(method, url, nil)
require.NoError(t, err)
req.Header.Add("Authorization", "foobar")
Expand Down Expand Up @@ -297,6 +333,12 @@ func TestProxyGRPC(t *testing.T) {
authWhitelist: []string{
"^/proxy_test\\.Greeter/SayHelloNoAuth.*$",
},
}, {
name: "gRPC no whitelist with skip for SayHello",
auth: "on",
authSkipInvoiceCreationPaths: []string{
`^/proxy_test[.]Greeter/SayHello.*$`,
},
}}

for _, tc := range testCases {
Expand Down Expand Up @@ -343,13 +385,14 @@ func runGRPCTest(t *testing.T, tc *testCase) {

// Create a list of services to proxy between.
services := []*proxy.Service{{
Address: testTargetServiceAddress,
HostRegexp: testHostRegexp,
PathRegexp: testPathRegexpGRPC,
Protocol: "https",
TLSCertPath: certFile,
Auth: tc.auth,
AuthWhitelistPaths: tc.authWhitelist,
Address: testTargetServiceAddress,
HostRegexp: testHostRegexp,
PathRegexp: testPathRegexpGRPC,
Protocol: "https",
TLSCertPath: certFile,
Auth: tc.auth,
AuthWhitelistPaths: tc.authWhitelist,
AuthSkipInvoiceCreationPaths: tc.authSkipInvoiceCreationPaths,
}}

// Create the proxy server and start serving on TLS.
Expand Down Expand Up @@ -393,17 +436,24 @@ func runGRPCTest(t *testing.T, tc *testCase) {
grpc.Trailer(&captureMetadata),
)
require.Error(t, err)
require.True(t, l402.IsPaymentRequired(err))

// We expect the WWW-Authenticate header field to be set to an L402
// auth response.
expectedHeaderContent, _ := mockAuth.FreshChallengeHeader("", 0)
capturedHeader := captureMetadata.Get("WWW-Authenticate")
require.Len(t, capturedHeader, 2)
require.Equal(
t, expectedHeaderContent.Values("WWW-Authenticate"),
capturedHeader,
)
if len(tc.authSkipInvoiceCreationPaths) > 0 {
statusErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, codes.Internal, statusErr.Code())
require.Equal(t, "unauthorized", statusErr.Message())
} else {
require.True(t, l402.IsPaymentRequired(err))

// We expect the WWW-Authenticate header field to be set to an L402
// auth response.
expectedHeaderContent, _ := mockAuth.FreshChallengeHeader("", 0)
capturedHeader := captureMetadata.Get("WWW-Authenticate")
require.Len(t, capturedHeader, 2)
require.Equal(
t, expectedHeaderContent.Values("WWW-Authenticate"),
capturedHeader,
)
}

// Make sure that if we query an URL that is on the whitelist, we don't
// get the 402 response.
Expand Down
52 changes: 49 additions & 3 deletions proxy/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ type Service struct {
// /package_name.ServiceName/MethodName
AuthWhitelistPaths []string `long:"authwhitelistpaths" description:"List of regular expressions for paths that don't require authentication'"`

// AuthSkipInvoiceCreationPaths is an optional list of regular
// expressions that are matched against the path of the URL of a
// request. If the request URL matches any of those regular
// expressions, the call will not try to create an invoice for the
// request, but still try to do the l402 authentication.
AuthSkipInvoiceCreationPaths []string `long:"authskipinvoicecreationpaths" description:"List of regular expressions for paths that will skip invoice creation'"`

// compiledHostRegexp is the compiled host regex.
compiledHostRegexp *regexp.Regexp

Expand All @@ -112,6 +119,10 @@ type Service struct {
// compiledAuthWhitelistPaths is the compiled auth whitelist paths.
compiledAuthWhitelistPaths []*regexp.Regexp

// compiledAuthSkipInvoiceCreationPaths is the compiled auth skip
// invoice creation paths.
compiledAuthSkipInvoiceCreationPaths []*regexp.Regexp

freebieDB freebie.DB
pricer pricer.Pricer
}
Expand Down Expand Up @@ -144,6 +155,20 @@ func (s *Service) AuthRequired(r *http.Request) auth.Level {
return s.Auth
}

// SkipInvoiceCreation determines if an invoice should be created for a
// given request.
func (s *Service) SkipInvoiceCreation(r *http.Request) bool {
for _, pathRegexp := range s.compiledAuthSkipInvoiceCreationPaths {
if pathRegexp.MatchString(r.URL.Path) {
log.Tracef("Req path [%s] matches skip entry "+
"[%s].", r.URL.Path, pathRegexp)
return true
}
}

return false
}

// prepareServices prepares the backend service configurations to be used by the
// proxy.
func prepareServices(services []*Service) error {
Expand Down Expand Up @@ -195,7 +220,7 @@ func prepareServices(services []*Service) error {
// Compile the host regex.
compiledHostRegexp, err := regexp.Compile(service.HostRegexp)
if err != nil {
return fmt.Errorf("error compiling host regex: %v", err)
return fmt.Errorf("error compiling host regex: %w", err)
}
service.compiledHostRegexp = compiledHostRegexp

Expand All @@ -206,7 +231,7 @@ func prepareServices(services []*Service) error {
)
if err != nil {
return fmt.Errorf("error compiling path "+
"regex: %v", err)
"regex: %w", err)
}
service.compiledPathRegexp = compiledPathRegexp
}
Expand All @@ -222,13 +247,34 @@ func prepareServices(services []*Service) error {
regExp, err := regexp.Compile(entry)
if err != nil {
return fmt.Errorf("error validating auth "+
"whitelist: %v", err)
"whitelist: %w", err)
}
service.compiledAuthWhitelistPaths = append(
service.compiledAuthWhitelistPaths, regExp,
)
}

service.compiledAuthSkipInvoiceCreationPaths = make(
[]*regexp.Regexp, 0, len(
service.AuthSkipInvoiceCreationPaths,
),
)

// Make sure all skip invoice creation regular expression
// entries actually compile so we run into an eventual panic
// during startup and not only when the request happens.
for _, entry := range service.AuthSkipInvoiceCreationPaths {
regExp, err := regexp.Compile(entry)
if err != nil {
return fmt.Errorf("error validating skip "+
"invoice creation whitelist: %w", err)
}
service.compiledAuthSkipInvoiceCreationPaths = append(
service.compiledAuthSkipInvoiceCreationPaths,
regExp,
)
}

// If dynamic prices are enabled then use the provided
// DynamicPrice options to initialise a gRPC backed
// pricer client.
Expand Down
11 changes: 11 additions & 0 deletions sample-conf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,17 @@ services:
# dynamicprice.enabled is set to true.
price: 0

# A list of regular expressions for path that are free of charge.
authwhitelistpaths:
- '^/freebieservice.*$'

# A list of regular expressions for path that will skip invoice creation,
# but still try to do the l402 authentication. This is useful for streaming
# services, as they are not supported to be the initial request to receive
# a L402.
authskipinvoicecreationpaths:
- '^/streamingservice.*$'

# Options to use for connection to the price serving gRPC server.
dynamicprice:
# Whether or not a gRPC server is available to query price data from. If
Expand Down
Loading