Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions descope/api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ var (
exchangeAccessKey: "auth/accesskey/exchange",
},
mgmt: mgmtEndpoints{
license: "mgmt/license",
tenantCreate: "mgmt/tenant/create",
tenantUpdate: "mgmt/tenant/update",
tenantDelete: "mgmt/tenant/delete",
Expand Down Expand Up @@ -346,6 +347,7 @@ type authEndpoints struct {
}

type mgmtEndpoints struct {
license string
tenantCreate string
tenantUpdate string
tenantDelete string
Expand Down Expand Up @@ -1482,6 +1484,10 @@ func (e *endpoints) ManagementDescoperSearch() string {
return path.Join(e.version, e.mgmt.descoperSearch)
}

func (e *endpoints) ManagementLicense() string {
return path.Join(e.version, e.mgmt.license)
}

type sdkInfo struct {
name string
version string
Expand Down Expand Up @@ -1526,6 +1532,7 @@ type Client struct {
externalRequestID func(context.Context) string
Conf ClientParams
sdkInfo *sdkInfo
licenseType string // License type from handshake (free/pro/enterprise)
}
type HTTPResponse struct {
Req *http.Request
Expand Down Expand Up @@ -1801,6 +1808,28 @@ func (c *Client) addDescopeHeaders(req *http.Request) {
req.Header.Set("x-descope-sdk-sha", c.sdkInfo.sha)
req.Header.Set("x-descope-sdk-uuid", instanceUUID)
req.Header.Set("x-descope-project-id", c.Conf.ProjectID)
if c.licenseType != "" {
req.Header.Set("x-descope-license", c.licenseType)
}
}

func (c *Client) FetchLicense(ctx context.Context) (string, error) {
var resp struct {
LicenseType string `json:"licenseType"`
}
opts := &HTTPRequest{ResBodyObj: &resp}
_, err := c.DoGetRequest(ctx, Routes.ManagementLicense(), opts, "")
if err != nil {
return "", err
}
if resp.LicenseType == "" {
return "", fmt.Errorf("empty license type returned from server")
}
return resp.LicenseType, nil
}

func (c *Client) SetLicenseType(licenseType string) {
c.licenseType = licenseType
}

func getSDKInfo() *sdkInfo {
Expand Down
251 changes: 251 additions & 0 deletions descope/api/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,3 +329,254 @@ func TestBaseURLForProjectID(t *testing.T) {
assert.EqualValues(t, useURL, baseURLForProjectID("Puse12aAc4T2V93bddihGEx2Ryhc8e5Zfoobar"))
assert.EqualValues(t, useURL, baseURLForProjectID("Puse12aAc4T2V93bddihGEx2Ryhc8e5Z"))
}

// License Header Tests

func TestFetchLicense_Success(t *testing.T) {
projectID := "test-project"
expectedLicenseType := "enterprise"

c := NewClient(ClientParams{
ProjectID: projectID,
ManagementKey: "test-key",
DefaultClient: mocks.NewTestClient(func(r *http.Request) (*http.Response, error) {
// Verify the request is to the correct endpoint
assert.Contains(t, r.URL.Path, "mgmt/license")
assert.EqualValues(t, http.MethodGet, r.Method)

// Return a successful response with license type
responseBody := fmt.Sprintf(`{"licenseType": "%s"}`, expectedLicenseType)
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(responseBody)),
}, nil
}),
})

licenseType, err := c.FetchLicense(context.Background())
require.NoError(t, err)
assert.EqualValues(t, expectedLicenseType, licenseType)
}

func TestFetchLicense_APIError(t *testing.T) {
projectID := "test-project"

c := NewClient(ClientParams{
ProjectID: projectID,
ManagementKey: "test-key",
DefaultClient: mocks.NewTestClient(func(_ *http.Request) (*http.Response, error) {
// Return an error response
return &http.Response{
StatusCode: http.StatusInternalServerError,
Body: io.NopCloser(strings.NewReader(`{"errorCode": "E999999", "errorDescription": "Internal error"}`)),
}, nil
}),
})

licenseType, err := c.FetchLicense(context.Background())
require.Error(t, err)
assert.Empty(t, licenseType)
}

func TestFetchLicense_NetworkError(t *testing.T) {
projectID := "test-project"
expectedErr := fmt.Errorf("network error")

c := NewClient(ClientParams{
ProjectID: projectID,
ManagementKey: "test-key",
DefaultClient: mocks.NewTestClient(func(_ *http.Request) (*http.Response, error) {
return nil, expectedErr
}),
})

licenseType, err := c.FetchLicense(context.Background())
require.Error(t, err)
assert.Empty(t, licenseType)
assert.Contains(t, err.Error(), "network error")
}

func TestSetLicenseType(t *testing.T) {
projectID := "test-project"
c := NewClient(ClientParams{ProjectID: projectID})

// Initially, licenseType should be empty
assert.Empty(t, c.licenseType)

// Set license type
expectedLicenseType := "pro"
c.SetLicenseType(expectedLicenseType)
assert.EqualValues(t, expectedLicenseType, c.licenseType)

newLicenseType := "enterprise"
c.SetLicenseType(newLicenseType)
assert.EqualValues(t, newLicenseType, c.licenseType)

// Set to empty string
c.SetLicenseType("")
assert.Empty(t, c.licenseType)
}

func TestLicenseHeader_AddedWhenSet(t *testing.T) {
projectID := "test-project"
expectedLicenseType := "enterprise"
headerChecked := false

c := NewClient(ClientParams{
ProjectID: projectID,
DefaultClient: mocks.NewTestClient(func(r *http.Request) (*http.Response, error) {
// Verify the license header is present
actualLicense := r.Header.Get("x-descope-license")
assert.EqualValues(t, expectedLicenseType, actualLicense)
headerChecked = true
return &http.Response{StatusCode: http.StatusOK}, nil
}),
})

// Set the license type
c.SetLicenseType(expectedLicenseType)

// Make a request and verify the header is added
_, err := c.DoPostRequest(context.Background(), "test-path", nil, nil, "")
require.NoError(t, err)
assert.True(t, headerChecked, "License header check was not performed")
}

func TestLicenseHeader_NotAddedWhenEmpty(t *testing.T) {
projectID := "test-project"
headerChecked := false

c := NewClient(ClientParams{
ProjectID: projectID,
DefaultClient: mocks.NewTestClient(func(r *http.Request) (*http.Response, error) {
// Verify the license header is NOT present
actualLicense := r.Header.Get("x-descope-license")
assert.Empty(t, actualLicense, "License header should not be present when licenseType is empty")
headerChecked = true
return &http.Response{StatusCode: http.StatusOK}, nil
}),
})

// Do NOT set license type (should remain empty)
assert.Empty(t, c.licenseType)

// Make a request and verify the header is NOT added
_, err := c.DoPostRequest(context.Background(), "test-path", nil, nil, "")
require.NoError(t, err)
assert.True(t, headerChecked, "License header check was not performed")
}

func TestLicenseHeader_AddedToAllRequestTypes(t *testing.T) {
projectID := "test-project"
expectedLicenseType := "pro"

tests := []struct {
name string
requestFunc func(*Client) error
}{
{
name: "POST request",
requestFunc: func(c *Client) error {
_, err := c.DoPostRequest(context.Background(), "test-path", nil, nil, "")
return err
},
},
{
name: "GET request",
requestFunc: func(c *Client) error {
_, err := c.DoGetRequest(context.Background(), "test-path", nil, "")
return err
},
},
{
name: "PUT request",
requestFunc: func(c *Client) error {
_, err := c.DoPutRequest(context.Background(), "test-path", nil, nil, "")
return err
},
},
{
name: "DELETE request",
requestFunc: func(c *Client) error {
_, err := c.DoDeleteRequest(context.Background(), "test-path", nil, "")
return err
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
headerChecked := false

c := NewClient(ClientParams{
ProjectID: projectID,
DefaultClient: mocks.NewTestClient(func(r *http.Request) (*http.Response, error) {
// Verify the license header is present
actualLicense := r.Header.Get("x-descope-license")
assert.EqualValues(t, expectedLicenseType, actualLicense)
headerChecked = true
return &http.Response{StatusCode: http.StatusOK}, nil
}),
})

// Set the license type
c.SetLicenseType(expectedLicenseType)

// Execute the request
err := tt.requestFunc(c)
require.NoError(t, err)
assert.True(t, headerChecked, "License header check was not performed")
})
}
}

func TestLicenseHeader_UpdatedDynamically(t *testing.T) {
projectID := "test-project"
requestCount := 0

c := NewClient(ClientParams{
ProjectID: projectID,
DefaultClient: mocks.NewTestClient(func(r *http.Request) (*http.Response, error) {
requestCount++
actualLicense := r.Header.Get("x-descope-license")

switch requestCount {
case 1:
// First request: no license header
assert.Empty(t, actualLicense)
case 2:
// Second request: "free" license
assert.EqualValues(t, "free", actualLicense)
case 3:
// Third request: "enterprise" license
assert.EqualValues(t, "enterprise", actualLicense)
case 4:
// Fourth request: no license header again
assert.Empty(t, actualLicense)
}

return &http.Response{StatusCode: http.StatusOK}, nil
}),
})

// Request 1: No license set
_, err := c.DoPostRequest(context.Background(), "test-path", nil, nil, "")
require.NoError(t, err)

// Request 2: Set to "free"
c.SetLicenseType("free")
_, err = c.DoPostRequest(context.Background(), "test-path", nil, nil, "")
require.NoError(t, err)

// Request 3: Update to "enterprise"
c.SetLicenseType("enterprise")
_, err = c.DoPostRequest(context.Background(), "test-path", nil, nil, "")
require.NoError(t, err)

// Request 4: Clear license
c.SetLicenseType("")
_, err = c.DoPostRequest(context.Background(), "test-path", nil, nil, "")
require.NoError(t, err)

assert.EqualValues(t, 4, requestCount, "Expected 4 requests to be made")
}
14 changes: 14 additions & 0 deletions descope/client/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package client

import (
"context"
"time"

"github.com/descope/go-sdk/descope"
"github.com/descope/go-sdk/descope/api"
"github.com/descope/go-sdk/descope/internal/auth"
Expand Down Expand Up @@ -86,6 +89,17 @@ func NewWithConfig(config *Config) (*DescopeClient, error) {
CertificateVerify: config.CertificateVerify,
RequestTimeout: config.RequestTimeout,
})

if config.ManagementKey != "" {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if licenseType, err := mgmtClient.FetchLicense(ctx); err != nil {
logger.LogInfo("License handshake failed, continuing without header: %v", err)
} else {
mgmtClient.SetLicenseType(licenseType)
}
}

managementService := mgmt.NewManagement(mgmt.ManagementParams{ProjectID: config.ProjectID, FGACacheURL: config.FGACacheURL}, provider, mgmtClient)

return &DescopeClient{Auth: authService, Management: managementService, config: config}, nil
Expand Down