Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions backoff.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package flagsmith

import "time"

const (
initialBackoff = 200 * time.Millisecond
maxBackoff = 30 * time.Second
)

// backoff handles exponential backoff with jitter
type backoff struct {
current time.Duration
}

// newBackoff creates a new backoff instance
func newBackoff() *backoff {
return &backoff{
current: initialBackoff,
}
}

// next returns the next backoff duration and updates the current backoff
func (b *backoff) next() time.Duration {
// Add jitter between 0-1s
backoff := b.current + time.Duration(time.Now().UnixNano()%1e9)

// Double the backoff time, but cap it
if b.current < maxBackoff {
b.current *= 2
}

return backoff
}

// reset resets the backoff to initial value
func (b *backoff) reset() {
b.current = initialBackoff
}
31 changes: 31 additions & 0 deletions backoff_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package flagsmith

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestBackoff(t *testing.T) {
// Given
b := newBackoff()

// When
first := b.next()
second := b.next()
third := b.next()

// Then
assert.LessOrEqual(t, third, maxBackoff, "Backoff should not exceed max")

// Backoff increases across attempts
assert.Greater(t, second, first, "Second backoff should be greater than the first")
assert.Greater(t, third, second, "Third backoff should be greater than the second")
}

func TestBackoffReset(t *testing.T) {
b := newBackoff()
assert.Greater(t, b.next(), initialBackoff)
b.reset()
assert.Equal(t, initialBackoff, b.current, "Reset should return to initial backoff")
}
49 changes: 42 additions & 7 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"log/slog"
"strings"
"sync"
"sync/atomic"
"time"

Expand All @@ -30,6 +31,7 @@ type Client struct {
identitiesWithOverrides atomic.Value

analyticsProcessor *AnalyticsProcessor
realtime *realtime
defaultFlagHandler func(string) (Flag, error)

client *resty.Client
Expand All @@ -38,6 +40,8 @@ type Client struct {
log *slog.Logger
offlineHandler OfflineHandler
errorHandler func(handler *FlagsmithAPIError)

once sync.Once
}

// Returns context with provided EvaluationContext instance set.
Expand Down Expand Up @@ -71,9 +75,12 @@ func NewClient(apiKey string, options ...Option) *Client {
opt(c)
}
}
c.client.SetLogger(newSlogToRestyAdapter(c.log))
c.client = c.client.
SetLogger(newSlogToRestyAdapter(c.log)).
OnBeforeRequest(newRestyLogRequestMiddleware(c.log)).
OnAfterResponse(newRestyLogResponseMiddleware(c.log))

c.log.Debug("initialising Flagsmith client",
c.log.Info("initialising Flagsmith client",
"base_url", c.config.baseURL,
"local_evaluation", c.config.localEvaluation,
"offline", c.config.offlineMode,
Expand Down Expand Up @@ -101,11 +108,15 @@ func NewClient(apiKey string, options ...Option) *Client {
if !strings.HasPrefix(apiKey, "ser.") {
panic("In order to use local evaluation, please generate a server key in the environment settings page.")
}
if c.config.polling || !c.config.useRealtime {
// Poll indefinitely
go c.pollEnvironment(c.ctxLocalEval, true)
}
if c.config.useRealtime {
go c.startRealtimeUpdates(c.ctxLocalEval)
} else {
go c.pollEnvironment(c.ctxLocalEval)
// Poll until we get the environment once
go c.pollEnvironment(c.ctxLocalEval, false)
}

}
// Initialise analytics processor
if c.config.enableAnalytics {
Expand Down Expand Up @@ -333,26 +344,42 @@ func (c *Client) getEnvironmentFlagsFromEnvironment() (Flags, error) {
), nil
}

func (c *Client) pollEnvironment(ctx context.Context) {
func (c *Client) pollEnvironment(ctx context.Context, pollForever bool) {
log := c.log.With(slog.String("worker", "poll"))
update := func() {
log.Debug("polling environment")
ctx, cancel := context.WithTimeout(ctx, c.config.envRefreshInterval)
defer cancel()
err := c.UpdateEnvironment(ctx)
if err != nil {
c.log.Error("failed to update environment", "error", err)
log.Error("failed to update environment", "error", err)
}
}
update()
ticker := time.NewTicker(c.config.envRefreshInterval)
defer func() {
ticker.Stop()
log.Info("polling stopped")
}()
for {
select {
case <-ticker.C:
if !pollForever {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be moved to real-time module/package; pollEnvironment should not be responsible for this

// Check if environment was successfully fetched
if _, ok := c.environment.Load().(*environments.EnvironmentModel); ok {
if !pollForever {
c.log.Debug("environment initialised")
return
}
}
}
update()
case <-ctx.Done():
return
}
}
}

func (c *Client) UpdateEnvironment(ctx context.Context) error {
var env environments.EnvironmentModel
resp, err := c.client.NewRequest().
Expand Down Expand Up @@ -385,6 +412,14 @@ func (c *Client) UpdateEnvironment(ctx context.Context) error {
c.identitiesWithOverrides.Store(identitiesWithOverrides)

c.log.Info("environment updated", "environment", env.APIKey)
c.once.Do(func() {
if c.config.useRealtime && c.realtime == nil {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should not be part of this method; UpdateEnvironment doesn't need to know anything about real-time

streamURL := c.config.realtimeBaseUrl + "sse/environments/" + env.APIKey + "/stream"
c.realtime = newRealtime(c, c.ctxLocalEval, streamURL, env.UpdatedAt)
c.log.Debug("environment initialised, starting realtime updates")
go c.realtime.start()
}
})
return nil
}

Expand Down
42 changes: 42 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -977,3 +977,45 @@ func TestWithSlogLogger(t *testing.T) {
t.Log(logStr)
assert.Contains(t, logStr, "initialising Flagsmith client")
}

func TestWithPollingWorksWithRealtime(t *testing.T) {
ctx := context.Background()
server := httptest.NewServer(http.HandlerFunc(fixtures.EnvironmentDocumentHandler))
defer server.Close()

// guard against data race from goroutines logging at the same time
var logOutput strings.Builder
var logMu sync.Mutex
slogLogger := slog.New(slog.NewTextHandler(writerFunc(func(p []byte) (n int, err error) {
logMu.Lock()
defer logMu.Unlock()
return logOutput.Write(p)
}), &slog.HandlerOptions{
Level: slog.LevelDebug,
}))

// Given
_ = flagsmith.NewClient(fixtures.EnvironmentAPIKey,
flagsmith.WithSlogLogger(slogLogger),
flagsmith.WithLocalEvaluation(ctx),
flagsmith.WithRealtime(),
flagsmith.WithPolling(),
flagsmith.WithBaseURL(server.URL+"/api/v1/"))

// When
time.Sleep(500 * time.Millisecond)

// Then
logMu.Lock()
logStr := logOutput.String()
logMu.Unlock()
assert.Contains(t, logStr, "worker=poll")
assert.Contains(t, logStr, "worker=realtime")
}

// writerFunc implements io.Writer
type writerFunc func(p []byte) (n int, err error)

func (f writerFunc) Write(p []byte) (n int, err error) {
return f(p)
}
1 change: 1 addition & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type config struct {
offlineMode bool
realtimeBaseUrl string
useRealtime bool
polling bool
}

// defaultConfig returns default configuration.
Expand Down
55 changes: 55 additions & 0 deletions logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ import (
"log/slog"
"os"
"strings"
"time"

"github.com/go-resty/resty/v2"
)

// Logger is the interface used for logging by flagsmith client. This interface defines the methods
Expand Down Expand Up @@ -118,3 +121,55 @@ func createLogger() *slog.Logger {
Level: slog.LevelDebug,
}))
}

const (
contextLoggerKey contextKey = contextKey("logger")
contextStartTimeKey contextKey = contextKey("startTime")
)

// restySlogLogger implements a [resty.Logger] using a [slog.Logger].
type restySlogLogger struct {
logger *slog.Logger
}

func newRestyLogRequestMiddleware(logger *slog.Logger) resty.RequestMiddleware {
return func(c *resty.Client, req *resty.Request) error {
// Create a child logger with request metadata
reqLogger := logger.WithGroup("http").With(
"method", req.Method,
"url", req.URL,
)
reqLogger.Debug("request")

// Store the logger in this request's context, and use it in the response
req.SetContext(context.WithValue(req.Context(), contextLoggerKey, reqLogger))

// Time the current request
req.SetContext(context.WithValue(req.Context(), contextStartTimeKey, time.Now()))

return nil
}
}

func newRestyLogResponseMiddleware(logger *slog.Logger) resty.ResponseMiddleware {
return func(client *resty.Client, resp *resty.Response) error {
// Retrieve the logger and start time from context
reqLogger, _ := resp.Request.Context().Value(contextLoggerKey).(*slog.Logger)
startTime, _ := resp.Request.Context().Value(contextStartTimeKey).(time.Time)

if reqLogger == nil {
reqLogger = logger
}
reqLogger = reqLogger.With(
slog.Int("status", resp.StatusCode()),
slog.Duration("duration", time.Since(startTime)),
slog.Int64("content_length", resp.Size()),
)
if resp.IsError() {
reqLogger.Error("error response")
} else {
reqLogger.Debug("response")
}
return nil
}
}
8 changes: 8 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ var _ = []Option{
WithCustomHeaders(nil),
WithDefaultHandler(nil),
WithProxy(""),
WithPolling(),
WithRealtime(),
WithRealtimeBaseURL(""),
WithLogger(nil),
Expand Down Expand Up @@ -157,3 +158,10 @@ func WithRealtimeBaseURL(url string) Option {
c.config.realtimeBaseUrl = url
}
}

// WithPolling makes it so that the client will poll for updates even when WithRealtime is used.
func WithPolling() Option {
return func(c *Client) {
c.config.polling = true
}
}
Loading