diff --git a/backoff.go b/backoff.go new file mode 100644 index 00000000..c3308c89 --- /dev/null +++ b/backoff.go @@ -0,0 +1,50 @@ +package flagsmith + +import ( + "context" + "time" +) + +const ( + initialBackoff = 200 * time.Millisecond + maxBackoff = 30 * time.Second +) + +// backoff handles exponential backoff with jitter. +type backoff struct { + current time.Duration +} + +// newBackoff creates a new backoff instance. +func newBackoff() *backoff { + return &backoff{ + current: initialBackoff, + } +} + +// next returns the next backoff duration and updates the current backoff. +func (b *backoff) next() time.Duration { + // Add jitter between 0-1s + backoff := b.current + time.Duration(time.Now().UnixNano()%1e9) + + // Double the backoff time, but cap it + if b.current < maxBackoff { + b.current *= 2 + } + + return backoff +} + +// reset resets the backoff to initial value. +func (b *backoff) reset() { + b.current = initialBackoff +} + +// wait waits for the current backoff time, or until ctx is done. +func (b *backoff) wait(ctx context.Context) { + select { + case <-ctx.Done(): + return + case <-time.After(b.next()): + } +} diff --git a/backoff_test.go b/backoff_test.go new file mode 100644 index 00000000..1149825a --- /dev/null +++ b/backoff_test.go @@ -0,0 +1,31 @@ +package flagsmith + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBackoff(t *testing.T) { + // Given + b := newBackoff() + + // When + first := b.next() + second := b.next() + third := b.next() + + // Then + assert.LessOrEqual(t, third, maxBackoff, "Backoff should not exceed max") + + // Backoff increases across attempts + assert.Greater(t, second, first, "Second backoff should be greater than the first") + assert.Greater(t, third, second, "Third backoff should be greater than the second") +} + +func TestBackoffReset(t *testing.T) { + b := newBackoff() + assert.Greater(t, b.next(), initialBackoff) + b.reset() + assert.Equal(t, initialBackoff, b.current, "Reset should return to initial backoff") +} diff --git a/client.go b/client.go index eccc2367..ced66dc9 100644 --- a/client.go +++ b/client.go @@ -30,6 +30,7 @@ type Client struct { identitiesWithOverrides atomic.Value analyticsProcessor *AnalyticsProcessor + realtime *realtime defaultFlagHandler func(string) (Flag, error) client *resty.Client @@ -76,7 +77,7 @@ func NewClient(apiKey string, options ...Option) *Client { OnBeforeRequest(newRestyLogRequestMiddleware(c.log)). OnAfterResponse(newRestyLogResponseMiddleware(c.log)) - c.log.Debug("initialising Flagsmith client", + c.log.Info("initialising Flagsmith client", "base_url", c.config.baseURL, "local_evaluation", c.config.localEvaluation, "offline", c.config.offlineMode, @@ -104,10 +105,13 @@ func NewClient(apiKey string, options ...Option) *Client { if !strings.HasPrefix(apiKey, "ser.") { panic("In order to use local evaluation, please generate a server key in the environment settings page.") } + if c.config.polling || !c.config.useRealtime { + // Poll indefinitely + go c.pollEnvironment(c.ctxLocalEval, true) + } if c.config.useRealtime { - go c.startRealtimeUpdates(c.ctxLocalEval) - } else { - go c.pollEnvironment(c.ctxLocalEval) + // Poll until we get the environment once + go c.pollThenStartRealtime(c.ctxLocalEval) } } // Initialise analytics processor @@ -336,26 +340,76 @@ func (c *Client) getEnvironmentFlagsFromEnvironment() (Flags, error) { ), nil } -func (c *Client) pollEnvironment(ctx context.Context) { +func (c *Client) pollEnvironment(ctx context.Context, pollForever bool) { + log := c.log.With(slog.String("worker", "poll")) update := func() { - ctx, cancel := context.WithTimeout(ctx, c.config.envRefreshInterval) + log.Debug("polling environment") + ctx, cancel := context.WithTimeout(ctx, c.config.timeout) defer cancel() err := c.UpdateEnvironment(ctx) if err != nil { - c.log.Error("failed to update environment", "error", err) + log.Error("failed to update environment", "error", err) } } update() ticker := time.NewTicker(c.config.envRefreshInterval) + defer func() { + ticker.Stop() + log.Info("polling stopped") + }() for { select { case <-ticker.C: + if !pollForever { + // Check if environment was successfully fetched + if _, ok := c.environment.Load().(*environments.EnvironmentModel); ok { + if !pollForever { + c.log.Debug("environment initialised") + return + } + } + } update() case <-ctx.Done(): return } } } + +func (c *Client) pollThenStartRealtime(ctx context.Context) { + b := newBackoff() + update := func() { + c.log.Debug("polling environment") + ctx, cancel := context.WithTimeout(ctx, c.config.envRefreshInterval) + defer cancel() + err := c.UpdateEnvironment(ctx) + if err != nil { + c.log.Error("failed to update environment", "error", err) + b.wait(ctx) + } + } + update() + defer func() { + c.log.Info("initial polling stopped") + }() + for { + select { + case <-ctx.Done(): + return + default: + // If environment was fetched, start realtime and finish + if env, ok := c.environment.Load().(*environments.EnvironmentModel); ok { + streamURL := c.config.realtimeBaseUrl + "sse/environments/" + env.APIKey + "/stream" + c.log.Debug("environment initialised, starting realtime updates") + c.realtime = newRealtime(c, ctx, streamURL, env.UpdatedAt) + go c.realtime.start() + return + } + update() + } + } +} + func (c *Client) UpdateEnvironment(ctx context.Context) error { var env environments.EnvironmentModel resp, err := c.client.NewRequest(). @@ -380,6 +434,11 @@ func (c *Client) UpdateEnvironment(ctx context.Context) error { } return f } + isNew := false + previousEnv := c.environment.Load() + if previousEnv == nil || env.UpdatedAt.After(previousEnv.(*environments.EnvironmentModel).UpdatedAt) { + isNew = true + } c.environment.Store(&env) identitiesWithOverrides := make(map[string]identities.IdentityModel) for _, id := range env.IdentityOverrides { @@ -387,7 +446,10 @@ func (c *Client) UpdateEnvironment(ctx context.Context) error { } c.identitiesWithOverrides.Store(identitiesWithOverrides) - c.log.Info("environment updated", "environment", env.APIKey) + if isNew { + c.log.Info("environment updated", "environment", env.APIKey, "updated_at", env.UpdatedAt) + } + return nil } diff --git a/client_test.go b/client_test.go index 0508593e..1a32d53d 100644 --- a/client_test.go +++ b/client_test.go @@ -977,3 +977,45 @@ func TestWithSlogLogger(t *testing.T) { t.Log(logStr) assert.Contains(t, logStr, "initialising Flagsmith client") } + +func TestWithPollingWorksWithRealtime(t *testing.T) { + ctx := context.Background() + server := httptest.NewServer(http.HandlerFunc(fixtures.EnvironmentDocumentHandler)) + defer server.Close() + + // guard against data race from goroutines logging at the same time + var logOutput strings.Builder + var logMu sync.Mutex + slogLogger := slog.New(slog.NewTextHandler(writerFunc(func(p []byte) (n int, err error) { + logMu.Lock() + defer logMu.Unlock() + return logOutput.Write(p) + }), &slog.HandlerOptions{ + Level: slog.LevelDebug, + })) + + // Given + _ = flagsmith.NewClient(fixtures.EnvironmentAPIKey, + flagsmith.WithSlogLogger(slogLogger), + flagsmith.WithLocalEvaluation(ctx), + flagsmith.WithRealtime(), + flagsmith.WithPolling(), + flagsmith.WithBaseURL(server.URL+"/api/v1/")) + + // When + time.Sleep(500 * time.Millisecond) + + // Then + logMu.Lock() + logStr := logOutput.String() + logMu.Unlock() + assert.Contains(t, logStr, "worker=poll") + assert.Contains(t, logStr, "worker=realtime") +} + +// writerFunc implements io.Writer. +type writerFunc func(p []byte) (n int, err error) + +func (f writerFunc) Write(p []byte) (n int, err error) { + return f(p) +} diff --git a/config.go b/config.go index ff0472ae..334b822b 100644 --- a/config.go +++ b/config.go @@ -26,6 +26,7 @@ type config struct { offlineMode bool realtimeBaseUrl string useRealtime bool + polling bool } // defaultConfig returns default configuration. diff --git a/options.go b/options.go index 256231c6..01dec3da 100644 --- a/options.go +++ b/options.go @@ -22,6 +22,7 @@ var _ = []Option{ WithCustomHeaders(nil), WithDefaultHandler(nil), WithProxy(""), + WithPolling(), WithRealtime(), WithRealtimeBaseURL(""), WithLogger(nil), @@ -157,3 +158,10 @@ func WithRealtimeBaseURL(url string) Option { c.config.realtimeBaseUrl = url } } + +// WithPolling makes it so that the client will poll for updates even when WithRealtime is used. +func WithPolling() Option { + return func(c *Client) { + c.config.polling = true + } +} diff --git a/realtime.go b/realtime.go index 06e0f73e..8bc74559 100644 --- a/realtime.go +++ b/realtime.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "log/slog" "net/http" "strings" @@ -13,60 +14,105 @@ import ( "github.com/Flagsmith/flagsmith-go-client/v4/flagengine/environments" ) -func (c *Client) startRealtimeUpdates(ctx context.Context) { - err := c.UpdateEnvironment(ctx) - if err != nil { - panic("Failed to fetch the environment while configuring real-time updates") +// realtime handles the SSE connection and reconnection logic. +type realtime struct { + client *Client + ctx context.Context + log *slog.Logger + streamURL string + envUpdatedAt time.Time + backoff *backoff +} + +// newRealtime creates a new realtime instance. +func newRealtime(client *Client, ctx context.Context, streamURL string, envUpdatedAt time.Time) *realtime { + return &realtime{ + client: client, + ctx: ctx, + log: client.log.With( + slog.String("worker", "realtime"), + slog.String("stream", streamURL), + ), + streamURL: streamURL, + envUpdatedAt: envUpdatedAt, + backoff: newBackoff(), } - env, _ := c.environment.Load().(*environments.EnvironmentModel) - stream_url := c.config.realtimeBaseUrl + "sse/environments/" + env.APIKey + "/stream" - envUpdatedAt := env.UpdatedAt - log := c.log.With( - slog.String("worker", "realtime"), - slog.String("stream", stream_url), - ) +} + +// start begins the realtime connection process. +func (r *realtime) start() { + r.log.Debug("connecting to realtime") defer func() { - log.Info("realtime stopped") + r.log.Info("stopped") }() for { select { - case <-ctx.Done(): + case <-r.ctx.Done(): return default: - resp, err := http.Get(stream_url) - if err != nil { - log.Error("failed to connect to realtime stream", "error", err) - continue + if err := r.connect(); err != nil { + r.log.Error("failed to connect", "error", err) + r.backoff.wait(r.ctx) } - defer resp.Body.Close() - log.Info("connected") - - scanner := bufio.NewScanner(resp.Body) - for scanner.Scan() { - line := scanner.Text() - if strings.HasPrefix(line, "data: ") { - parsedTime, err := parseUpdatedAtFromSSE(line) - if err != nil { - log.Error("failed to parse event message", "error", err, "message", line) - continue - } - if parsedTime.After(envUpdatedAt) { - err = c.UpdateEnvironment(ctx) - if err != nil { - log.Error("failed to update environment after receiving event", "error", err) - continue - } - env, _ := c.environment.Load().(*environments.EnvironmentModel) - envUpdatedAt = env.UpdatedAt - } + } + } +} + +// connect establishes and maintains the SSE connection. +func (r *realtime) connect() error { + resp, err := http.Get(r.streamURL) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("error response connecting to stream: %d", resp.StatusCode) + } + + r.log.Info("connected") + r.backoff.reset() + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + select { + case <-r.ctx.Done(): + return r.ctx.Err() + default: + line := scanner.Text() + if strings.HasPrefix(line, "data: ") { + if err := r.handleEvent(line); err != nil { + r.log.Error("failed to handle event", "error", err, "message", line) } } - if err := scanner.Err(); err != nil { - log.Error("error reading from realtime stream", "error", err) - } } } + + if err := scanner.Err(); err != nil { + r.log.Error("failed to read from stream", "error", err) + return err + } + + return nil +} + +// handleEvent processes a single SSE event. +func (r *realtime) handleEvent(line string) error { + parsedTime, err := parseUpdatedAtFromSSE(line) + if err != nil { + return err + } + + if parsedTime.After(r.envUpdatedAt) { + if err := r.client.UpdateEnvironment(r.ctx); err != nil { + return err + } + if env, ok := r.client.environment.Load().(*environments.EnvironmentModel); ok { + r.envUpdatedAt = env.UpdatedAt + } + } + return nil } + func parseUpdatedAtFromSSE(line string) (time.Time, error) { var eventData struct { UpdatedAt float64 `json:"updated_at"`