Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ feat: add support for application state management #3360

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ type App struct {
tlsHandler *TLSHandler
// Mount fields
mountFields *mountFields
// state management
state *State
// Route stack divided by HTTP methods
stack [][]*Route
// Route stack divided by HTTP methods and route prefixes
Expand Down Expand Up @@ -515,6 +517,9 @@ func New(config ...Config) *App {
// Define mountFields
app.mountFields = newMountFields(app)

// Define state
app.state = newState()

// Override config if provided
if len(config) > 0 {
app.config = config[0]
Expand Down Expand Up @@ -952,6 +957,11 @@ func (app *App) Hooks() *Hooks {
return app.hooks
}

// State returns the state struct to store global data in order to share it between handlers.
func (app *App) State() *State {
return app.state
}

var ErrTestGotEmptyResponse = errors.New("test: got empty response")

// TestConfig is a struct holding Test settings
Expand Down
10 changes: 10 additions & 0 deletions app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1890,6 +1890,16 @@ func Test_Route_Naming_Issue_2671_2685(t *testing.T) {
require.Equal(t, "/simple-route", sRoute2.Path)
}

func Test_App_State(t *testing.T) {
t.Parallel()
app := New()

app.State().Set("key", "value")
str, ok := app.State().GetString("key")
require.True(t, ok)
require.Equal(t, "value", str)
}

// go test -v -run=^$ -bench=Benchmark_Communication_Flow -benchmem -count=4
func Benchmark_Communication_Flow(b *testing.B) {
app := New()
Expand Down
141 changes: 141 additions & 0 deletions state.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package fiber

import (
"sync"
)

// State is a key-value store for Fiber's app in order to be used as a global storage for the app's dependencies.
// It's a thread-safe implementation of a map[string]any, using sync.Map.
type State struct {
dependencies sync.Map
}

// NewState creates a new instance of State.
func newState() *State {
return &State{
dependencies: sync.Map{},
}
}

// Set sets a key-value pair in the State.
func (s *State) Set(key string, value any) {
s.dependencies.Store(key, value)
}

// Get retrieves a value from the State.
func (s *State) Get(key string) (any, bool) {
return s.dependencies.Load(key)
}

// GetString retrieves a string value from the State.
func (s *State) GetString(key string) (string, bool) {
dep, ok := s.Get(key)
if ok {
depString, okCast := dep.(string)
return depString, okCast
}

return "", false
}

// GetInt retrieves an int value from the State.
func (s *State) GetInt(key string) (int, bool) {
dep, ok := s.Get(key)
if ok {
depInt, okCast := dep.(int)
return depInt, okCast
}

return 0, false
}

// GetBool retrieves a bool value from the State.
func (s *State) GetBool(key string) (value, ok bool) { //nolint:nonamedreturns // Better idea to use named returns here
dep, ok := s.Get(key)
if ok {
depBool, okCast := dep.(bool)
return depBool, okCast
}

return false, false
}

// GetFloat64 retrieves a float64 value from the State.
func (s *State) GetFloat64(key string) (float64, bool) {
dep, ok := s.Get(key)
if ok {
depFloat64, okCast := dep.(float64)
return depFloat64, okCast
}

return 0, false
}

// MustGet retrieves a value from the State and panics if the key is not found.
func (s *State) MustGet(key string) any {
if dep, ok := s.Get(key); ok {
return dep
}

panic("state: dependency not found!")
}

// MustGetString retrieves a string value from the State and panics if the key is not found.
func (s *State) Delete(key string) {
s.dependencies.Delete(key)
}

// Reset resets the State.
func (s *State) Clear() {
s.dependencies.Clear()
}

// Keys retrieves all the keys from the State.
func (s *State) Keys() []string {
keys := make([]string, 0)
Copy link
Member Author

@efectn efectn Mar 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we can preallocate the keys slice using s.Len() to reduce the memory allocation; however, it will make the method a little bit slower.

s.dependencies.Range(func(key, _ any) bool {
keyStr, ok := key.(string)
if !ok {
return false
}

Check warning on line 100 in state.go

View check run for this annotation

Codecov / codecov/patch

state.go#L99-L100

Added lines #L99 - L100 were not covered by tests

keys = append(keys, keyStr)
return true
})

return keys
}

// Len retrieves the number of dependencies in the State.
func (s *State) Len() int {
length := 0
s.dependencies.Range(func(_, _ any) bool {
length++
return true
})

return length
}

// GetState retrieves a value from the State and casts it to the desired type.
func GetState[T any](s *State, key string) (T, bool) {
dep, ok := s.Get(key)

if ok {
depT, okCast := dep.(T)
return depT, okCast
}

var zeroVal T
return zeroVal, false
}

// MustGetState retrieves a value from the State and casts it to the desired type, panicking if the key is not found.
func MustGetState[T any](s *State, key string) T {
dep, ok := GetState[T](s, key)
if !ok {
panic("state: dependency not found!")
}

return dep
}
Loading
Loading