diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index fc66ea2a..0d19fdf4 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -44,4 +44,6 @@ jobs: run: go build -v ./... - name: Test - run: go test -v -race ./... + run: | + go test -v -race ./... + go test -tags=test ./... diff --git a/client.go b/client.go index ced66dc9..f764b143 100644 --- a/client.go +++ b/client.go @@ -4,6 +4,9 @@ import ( "context" "fmt" "log/slog" + "net/http" + "reflect" + "runtime" "strings" "sync/atomic" "time" @@ -34,6 +37,7 @@ type Client struct { defaultFlagHandler func(string) (Flag, error) client *resty.Client + httpClient *http.Client ctxLocalEval context.Context ctxAnalytics context.Context log *slog.Logger @@ -52,26 +56,64 @@ func GetEvaluationContextFromCtx(ctx context.Context) (ec EvaluationContext, ok return ec, ok } +func getOptionQualifiedName(opt Option) string { + return runtime.FuncForPC(reflect.ValueOf(opt).Pointer()).Name() +} + +func isClientOption(name string) bool { + return strings.Contains(name, OptionWithHTTPClient) || strings.Contains(name, OptionWithRestyClient) +} + // NewClient creates instance of Client with given configuration. func NewClient(apiKey string, options ...Option) *Client { c := &Client{ apiKey: apiKey, config: defaultConfig(), - client: resty.New(), + } + + customClientCount := 0 + for _, opt := range options { + name := getOptionQualifiedName(opt) + if isClientOption(name) { + customClientCount = customClientCount + 1 + if customClientCount > 1 { + panic("Only one client option can be provided") + } + opt(c) + } + } + + // If a resty custom client has been provided, client is already set - otherwise we use a custom http client or default to a resty + if c.client == nil { + if c.httpClient != nil { + c.client = resty.NewWithClient(c.httpClient) + c.config.userProvidedClient = true + } else { + c.client = resty.New() + } + } else { + c.config.userProvidedClient = true } c.client.SetHeaders(map[string]string{ "Accept": "application/json", EnvironmentKeyHeader: c.apiKey, }) - c.client.SetTimeout(c.config.timeout) + + if c.client.GetClient().Timeout == 0 { + c.client.SetTimeout(c.config.timeout) + } + c.log = createLogger() for _, opt := range options { - if opt != nil { - opt(c) + name := getOptionQualifiedName(opt) + if isClientOption(name) { + continue } + opt(c) } + c.client = c.client. SetLogger(newSlogToRestyAdapter(c.log)). OnBeforeRequest(newRestyLogRequestMiddleware(c.log)). diff --git a/client_http_test.go b/client_http_test.go new file mode 100644 index 00000000..b5592cd5 --- /dev/null +++ b/client_http_test.go @@ -0,0 +1,48 @@ +//go:build test + +package flagsmith + +import ( + "testing" + "time" + + "github.com/go-resty/resty/v2" + "github.com/stretchr/testify/assert" +) + +func (c *Client) ExposeRestyClient() *resty.Client { + return c.client +} + +func TestCustomClientRetriesAreSet(t *testing.T) { + retryCount := 5 + + customResty := resty.New(). + SetRetryCount(retryCount). + SetRetryWaitTime(10 * time.Millisecond) + + client := NewClient("env-key", WithRestyClient(customResty)) + + internal := client.ExposeRestyClient() + assert.Equal(t, retryCount, internal.RetryCount) + assert.Equal(t, 10*time.Millisecond, internal.RetryWaitTime) +} + +func TestCustomRestyClientTimeoutIsNotOverriddenWithDefaultTimeout(t *testing.T) { + customResty := resty.New().SetTimeout(13 * time.Millisecond) + + client := NewClient("env-key", WithRestyClient(customResty)) + + internal := client.ExposeRestyClient() + + assert.Equal(t, 13*time.Millisecond, internal.GetClient().Timeout) +} + +func TestCustomRestyClientHasDefaultTimeoutIfNotProvided(t *testing.T) { + customResty := resty.New() + + client := NewClient("env-key", WithRestyClient(customResty)) + + internal := client.ExposeRestyClient() + assert.Equal(t, 10*time.Second, internal.GetClient().Timeout) +} diff --git a/client_test.go b/client_test.go index 1a32d53d..bdee4669 100644 --- a/client_test.go +++ b/client_test.go @@ -15,6 +15,7 @@ import ( flagsmith "github.com/Flagsmith/flagsmith-go-client/v4" "github.com/Flagsmith/flagsmith-go-client/v4/fixtures" + "github.com/go-resty/resty/v2" "github.com/stretchr/testify/assert" ) @@ -1019,3 +1020,160 @@ type writerFunc func(p []byte) (n int, err error) func (f writerFunc) Write(p []byte) (n int, err error) { return f(p) } + +// Helper function to implement a header interceptor. +func roundTripperWithHeader(key, value string) http.RoundTripper { + return &injectHeaderTransport{key: key, value: value} +} + +type injectHeaderTransport struct { + key string + value string +} + +func (t *injectHeaderTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req.Header.Set(t.key, t.value) + return http.DefaultTransport.RoundTrip(req) +} + +func TestCustomHTTPClientIsUsed(t *testing.T) { + ctx := context.Background() + + hasCustomHeader := false + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + assert.Equal(t, "/api/v1/flags/", req.URL.Path) + assert.Equal(t, fixtures.EnvironmentAPIKey, req.Header.Get("x-Environment-Key")) + if req.Header.Get("X-Test-Client") == "http" { + hasCustomHeader = true + } + rw.Header().Set("Content-Type", "application/json") + rw.WriteHeader(http.StatusOK) + _, err := io.WriteString(rw, fixtures.FlagsJson) + assert.NoError(t, err) + })) + defer server.Close() + + customClient := &http.Client{ + Transport: roundTripperWithHeader("X-Test-Client", "http"), + } + + client := flagsmith.NewClient(fixtures.EnvironmentAPIKey, + flagsmith.WithHTTPClient(customClient), + flagsmith.WithBaseURL(server.URL+"/api/v1/")) + + flags, err := client.GetFlags(ctx, nil) + assert.Equal(t, 1, len(flags.AllFlags())) + assert.NoError(t, err) + assert.True(t, hasCustomHeader, "Expected http header") + flag, err := flags.GetFlag(fixtures.Feature1Name) + assert.NoError(t, err) + assert.Equal(t, fixtures.Feature1Value, flag.Value) +} + +func TestCustomRestyClientIsUsed(t *testing.T) { + ctx := context.Background() + + hasCustomHeader := false + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if req.Header.Get("X-Custom-Test-Header") == "resty" { + hasCustomHeader = true + } + rw.Header().Set("Content-Type", "application/json") + rw.WriteHeader(http.StatusOK) + _, err := io.WriteString(rw, fixtures.FlagsJson) + assert.NoError(t, err) + })) + defer server.Close() + + restyClient := resty.New(). + SetHeader("X-Custom-Test-Header", "resty") + + client := flagsmith.NewClient(fixtures.EnvironmentAPIKey, + flagsmith.WithRestyClient(restyClient), + flagsmith.WithBaseURL(server.URL+"/api/v1/")) + + flags, err := client.GetFlags(ctx, nil) + assert.NoError(t, err) + assert.Equal(t, 1, len(flags.AllFlags())) + assert.True(t, hasCustomHeader, "Expected custom resty header") +} + +func TestRestyClientOverridesHTTPClientShouldPanic(t *testing.T) { + httpClient := &http.Client{ + Transport: roundTripperWithHeader("X-Test-Client", "http"), + } + + restyClient := resty.New(). + SetHeader("X-Test-Client", "resty") + + assert.Panics(t, func() { + _ = flagsmith.NewClient(fixtures.EnvironmentAPIKey, + flagsmith.WithHTTPClient(httpClient), + flagsmith.WithRestyClient(restyClient), + flagsmith.WithBaseURL("http://example.com/api/v1/")) + }, "Expected panic when both HTTP and Resty clients are provided") +} + +func TestDefaultRestyClientIsUsed(t *testing.T) { + ctx := context.Background() + + serverCalled := false + + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + serverCalled = true + + assert.Equal(t, "/api/v1/flags/", req.URL.Path) + assert.Equal(t, fixtures.EnvironmentAPIKey, req.Header.Get("x-Environment-Key")) + + rw.Header().Set("Content-Type", "application/json") + rw.WriteHeader(http.StatusOK) + _, err := io.WriteString(rw, fixtures.FlagsJson) + assert.NoError(t, err) + })) + defer server.Close() + + client := flagsmith.NewClient(fixtures.EnvironmentAPIKey, + flagsmith.WithBaseURL(server.URL+"/api/v1/")) + + flags, err := client.GetFlags(ctx, nil) + + assert.NoError(t, err) + assert.True(t, serverCalled, "Expected server to be") + assert.Equal(t, 1, len(flags.AllFlags())) +} + +func TestCustomClientOptionsShoudPanic(t *testing.T) { + restyClient := resty.New() + + testCases := []struct { + name string + option flagsmith.Option + }{ + { + name: "WithRequestTimeout", + option: flagsmith.WithRequestTimeout(5 * time.Second), + }, + { + name: "WithRetries", + option: flagsmith.WithRetries(3, time.Second), + }, + { + name: "WithCustomHeaders", + option: flagsmith.WithCustomHeaders(map[string]string{"X-Custom": "value"}), + }, + { + name: "WithProxy", + option: flagsmith.WithProxy("http://proxy.example.com"), + }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + assert.Panics(t, func() { + _ = flagsmith.NewClient(fixtures.EnvironmentAPIKey, + flagsmith.WithRestyClient(restyClient), + test.option) + }, "Expected panic when using %s with custom resty client", test.name) + }) + } +} diff --git a/config.go b/config.go index 334b822b..859a520b 100644 --- a/config.go +++ b/config.go @@ -27,6 +27,7 @@ type config struct { realtimeBaseUrl string useRealtime bool polling bool + userProvidedClient bool } // defaultConfig returns default configuration. @@ -36,5 +37,6 @@ func defaultConfig() config { timeout: DefaultTimeout, envRefreshInterval: time.Second * 60, realtimeBaseUrl: DefaultRealtimeBaseUrl, + userProvidedClient: false, } } diff --git a/options.go b/options.go index 01dec3da..3b6436c5 100644 --- a/options.go +++ b/options.go @@ -2,10 +2,18 @@ package flagsmith import ( "context" + "net/http" "strings" "time" "log/slog" + + "github.com/go-resty/resty/v2" +) + +const ( + OptionWithHTTPClient = "WithHTTPClient" + OptionWithRestyClient = "WithRestyClient" ) type Option func(c *Client) @@ -27,6 +35,8 @@ var _ = []Option{ WithRealtimeBaseURL(""), WithLogger(nil), WithSlogLogger(nil), + WithRestyClient(nil), + WithHTTPClient(nil), } func WithBaseURL(url string) Option { @@ -55,6 +65,9 @@ func WithRemoteEvaluation() Option { func WithRequestTimeout(timeout time.Duration) Option { return func(c *Client) { + if c.config.userProvidedClient { + panic("options modifying the client can not be used with a custom client") + } c.client.SetTimeout(timeout) } } @@ -79,6 +92,9 @@ func WithAnalytics(ctx context.Context) Option { func WithRetries(count int, waitTime time.Duration) Option { return func(c *Client) { + if c.config.userProvidedClient { + panic("options modifying the client can not be used with a custom client") + } c.client.SetRetryCount(count) c.client.SetRetryWaitTime(waitTime) } @@ -86,6 +102,9 @@ func WithRetries(count int, waitTime time.Duration) Option { func WithCustomHeaders(headers map[string]string) Option { return func(c *Client) { + if c.config.userProvidedClient { + panic("options modifying the client can not be used with a custom client") + } c.client.SetHeaders(headers) } } @@ -114,6 +133,9 @@ func WithSlogLogger(logger *slog.Logger) Option { // The proxyURL argument is a string representing the URL of the proxy server to use, e.g. "http://proxy.example.com:8080". func WithProxy(proxyURL string) Option { return func(c *Client) { + if c.config.userProvidedClient { + panic("options modifying the client can not be used with a custom client") + } c.client.SetProxy(proxyURL) } } @@ -165,3 +187,19 @@ func WithPolling() Option { c.config.polling = true } } + +func WithHTTPClient(httpClient *http.Client) Option { + return func(c *Client) { + if httpClient != nil { + c.httpClient = httpClient + } + } +} + +func WithRestyClient(restyClient *resty.Client) Option { + return func(c *Client) { + if restyClient != nil { + c.client = restyClient + } + } +}