diff --git a/client.go b/client.go index 7046417d..2d3bb49f 100644 --- a/client.go +++ b/client.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log/slog" + "net/url" "strings" "time" @@ -77,6 +78,17 @@ func NewClient(apiKey string, options ...Option) (*Client, error) { } } + c.log.Debug("initialising Flagsmith client", + slog.String("api_url", c.baseURL), + slog.Bool("local_eval", c.localEvaluation), + slog.Duration("environment_refresh_interval", c.envRefreshInterval), + slog.Bool("realtime", c.useRealtime), + slog.String("realtime_url", c.realtimeBaseUrl), + slog.Bool("offline", c.state.IsOffline()), + slog.Duration("timeout", c.timeout), + slog.String("proxy", c.proxy), + ) + if c.state.IsOffline() { return c, nil } @@ -99,15 +111,32 @@ func NewClient(apiKey string, options ...Option) (*Client, error) { if c.localEvaluation { if !strings.HasPrefix(apiKey, "ser.") { - return nil, errors.New("using local flag evaluation requires a server-side SDK key; got " + apiKey) + return nil, errors.New("local flag evaluation requires a server-side SDK key; got " + apiKey) + } + if c.envRefreshInterval == 0 && !c.useRealtime { + return nil, errors.New("local flag evaluation requires a non-zero refresh interval or enabling real-time updates") } + + // Fail fast if we can't fetch the initial environment within the timeout + ctxWithTimeout, cancel := context.WithTimeout(c.ctxLocalEval, c.timeout) + defer cancel() + c.log.Debug("fetching initial environment") + env, err := c.updateAndReturnEnvironment(ctxWithTimeout) + if err != nil { + return nil, fmt.Errorf("failed to fetch initial environment: %w", err) + } + if c.useRealtime { - go c.startRealtimeUpdates(c.ctxLocalEval) - } else { + streamPath, err := url.JoinPath(c.realtimeBaseUrl, "sse/environments", env.APIKey, "stream") + if err != nil { + return nil, fmt.Errorf("failed to build stream URL: %w", err) + } + go c.startRealtimeUpdates(c.ctxLocalEval, streamPath) + } + if c.envRefreshInterval > 0 { go c.pollEnvironment(c.ctxLocalEval) } } - return c, nil } @@ -158,6 +187,11 @@ func (c *Client) GetFlags(ctx context.Context, ec EvaluationContext) (f Flags, e // UpdateEnvironment fetches the current environment state from the Flagsmith API. It is called periodically when using // [WithLocalEvaluation], or when [WithRealtime] is enabled and an update event was received. func (c *Client) UpdateEnvironment(ctx context.Context) error { + _, err := c.updateAndReturnEnvironment(ctx) + return err +} + +func (c *Client) updateAndReturnEnvironment(ctx context.Context) (*environments.EnvironmentModel, error) { var env environments.EnvironmentModel resp, err := c.client. NewRequest(). @@ -167,16 +201,16 @@ func (c *Client) UpdateEnvironment(ctx context.Context) error { Get(c.baseURL + "environment-document/") if err != nil { - return c.handleError(&APIError{Err: err}) + return nil, c.handleError(&APIError{Err: err}) } if resp.IsError() { e := &APIError{response: resp.RawResponse} - return c.handleError(e) + return nil, c.handleError(e) } c.state.SetEnvironment(&env) c.log.Info("environment updated", "environment", env.APIKey) - return nil + return &env, nil } // GetIdentitySegments returns the segments that this evaluation context is a part of. It requires a local environment @@ -309,21 +343,20 @@ func (c *Client) getIdentityFlagsFromEnvironment(identifier string, traits map[s } func (c *Client) pollEnvironment(ctx context.Context) { - update := func() { - ctx, cancel := context.WithTimeout(ctx, c.envRefreshInterval) - defer cancel() - err := c.UpdateEnvironment(ctx) - if err != nil { - c.log.Error("pollEnvironment failed", "error", err) - } - } - update() ticker := time.NewTicker(c.envRefreshInterval) for { select { case <-ticker.C: + env, ok := c.state.GetEnvironment() + if ok && time.Since(env.UpdatedAt) < c.envRefreshInterval { + c.log.Debug("environment is already fresh, skipping poll") + continue + } c.log.Debug("polling environment") - update() + err := c.UpdateEnvironment(ctx) + if err != nil { + c.log.Error("pollEnvironment failed", "error", err) + } case <-ctx.Done(): return } diff --git a/client_test.go b/client_test.go index db353779..e8502b5b 100644 --- a/client_test.go +++ b/client_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "sync" + "sync/atomic" "testing" "time" @@ -100,15 +101,10 @@ func TestClientUpdatesEnvironmentOnStartForLocalEvaluation(t *testing.T) { func TestClientUpdatesEnvironmentOnEachRefresh(t *testing.T) { // Given ctx := context.Background() - actualEnvironmentRefreshCounter := struct { - mu sync.Mutex - count int - }{} + var actualEnvironmentRefreshCounter atomic.Uint64 expectedEnvironmentRefreshCount := 3 server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - actualEnvironmentRefreshCounter.mu.Lock() - actualEnvironmentRefreshCounter.count++ - actualEnvironmentRefreshCounter.mu.Unlock() + actualEnvironmentRefreshCounter.Add(1) assert.Equal(t, req.URL.Path, "/api/v1/environment-document/") assert.Equal(t, fixtures.EnvironmentAPIKey, req.Header.Get("X-Environment-Key")) @@ -133,8 +129,7 @@ func TestClientUpdatesEnvironmentOnEachRefresh(t *testing.T) { // one when the client starts and 2 // for each time the refresh interval expires - actualEnvironmentRefreshCounter.mu.Lock() - assert.Equal(t, expectedEnvironmentRefreshCount, actualEnvironmentRefreshCounter.count) + assert.Equal(t, expectedEnvironmentRefreshCount, int(actualEnvironmentRefreshCounter.Load())) } func TestGetFlags(t *testing.T) { diff --git a/realtime.go b/realtime.go index e6c5d95a..a9082e07 100644 --- a/realtime.go +++ b/realtime.go @@ -7,33 +7,21 @@ import ( "errors" "log/slog" "net/http" - "net/url" "strings" "time" ) -func (c *Client) startRealtimeUpdates(ctx context.Context) { - err := c.UpdateEnvironment(ctx) - if err != nil { - panic("Failed to fetch the environment while configuring real-time updates") - } - +func (c *Client) startRealtimeUpdates(ctx context.Context, stream string) { env, _ := c.state.GetEnvironment() envUpdatedAt := env.UpdatedAt log := c.log.With("environment", env.APIKey, "current_updated_at", &envUpdatedAt) - streamPath, err := url.JoinPath(c.realtimeBaseUrl, "sse/environments", env.APIKey, "stream") - if err != nil { - log.Error("failed to build stream URL", "error", err) - panic(err) - } - for { select { case <-ctx.Done(): return default: - resp, err := http.Get(streamPath) + resp, err := http.Get(stream) if err != nil { log.Error("failed to connect to realtime service", "error", err) continue