Skip to content

Commit

Permalink
chore: Add tests for flags handlers (#8)
Browse files Browse the repository at this point in the history
* Add test for GetAllFeatureFlags

* add test for createFlags

* Bump go

* Add test for flags handlers
  • Loading branch information
thomaspoignant authored Oct 25, 2024
1 parent f38e84e commit 07a189c
Show file tree
Hide file tree
Showing 12 changed files with 1,076 additions and 41 deletions.
1 change: 1 addition & 0 deletions dao/err/postgres_error_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package daoErr
import (
"database/sql"
"errors"

"github.com/lib/pq"
)

Expand Down
1 change: 1 addition & 0 deletions dao/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package dao

import (
"context"

daoErr "github.com/go-feature-flag/app-api/dao/err"
"github.com/go-feature-flag/app-api/model"
)
Expand Down
43 changes: 43 additions & 0 deletions dao/inmemory_impl_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dao
import (
"context"
"fmt"

daoErr "github.com/go-feature-flag/app-api/dao/err"
"github.com/go-feature-flag/app-api/model"
_ "github.com/lib/pq" // we import the driver used by sqlx
Expand All @@ -22,11 +23,23 @@ type InMemoryMockDao struct {

// GetFlags return all the flags
func (m *InMemoryMockDao) GetFlags(ctx context.Context) ([]model.FeatureFlag, daoErr.DaoError) {
if ctx.Value("error") != nil {
if err, ok := ctx.Value("error").(daoErr.DaoErrorCode); ok {
return nil, daoErr.NewDaoError(err, fmt.Errorf("error on get flags"))
}
return nil, daoErr.NewDaoError(daoErr.UnknownError, fmt.Errorf("error on get flags"))
}
return m.flags, nil
}

// GetFlagByID return a flag by its ID
func (m *InMemoryMockDao) GetFlagByID(ctx context.Context, id string) (model.FeatureFlag, daoErr.DaoError) {
if ctx.Value("error") != nil {
if err, ok := ctx.Value("error").(daoErr.DaoErrorCode); ok {
return model.FeatureFlag{}, daoErr.NewDaoError(err, fmt.Errorf("error on get flag by id"))
}
return model.FeatureFlag{}, daoErr.NewDaoError(daoErr.UnknownError, fmt.Errorf("error on get flag by id"))
}
for _, flag := range m.flags {
if flag.ID == id {
return flag, nil
Expand All @@ -37,6 +50,12 @@ func (m *InMemoryMockDao) GetFlagByID(ctx context.Context, id string) (model.Fea

// GetFlagByName return a flag by its name
func (m *InMemoryMockDao) GetFlagByName(ctx context.Context, name string) (model.FeatureFlag, daoErr.DaoError) {
if ctx.Value("error") != nil {
if err, ok := ctx.Value("error").(daoErr.DaoErrorCode); ok {
return model.FeatureFlag{}, daoErr.NewDaoError(err, fmt.Errorf("error on get flag by name"))
}
return model.FeatureFlag{}, daoErr.NewDaoError(daoErr.UnknownError, fmt.Errorf("error on get flag by name"))
}
for _, flag := range m.flags {
if flag.Name == name {
return flag, nil
Expand All @@ -47,11 +66,24 @@ func (m *InMemoryMockDao) GetFlagByName(ctx context.Context, name string) (model

// CreateFlag create a new flag, return the id of the flag
func (m *InMemoryMockDao) CreateFlag(ctx context.Context, flag model.FeatureFlag) (string, daoErr.DaoError) {
if ctx.Value("error_create") != nil {
if err, ok := ctx.Value("error_create").(daoErr.DaoErrorCode); ok {
return "", daoErr.NewDaoError(err, fmt.Errorf("error creating flag"))
}
return "", daoErr.NewDaoError(daoErr.UnknownError, fmt.Errorf("error creating flag"))
}

m.flags = append(m.flags, flag)
return flag.ID, nil
}

func (m *InMemoryMockDao) UpdateFlag(ctx context.Context, flag model.FeatureFlag) daoErr.DaoError {
if ctx.Value("error_update") != nil {
if err, ok := ctx.Value("error_update").(daoErr.DaoErrorCode); ok {
return daoErr.NewDaoError(err, fmt.Errorf("error on update flags"))
}
return daoErr.NewDaoError(daoErr.UnknownError, fmt.Errorf("error on update flags"))
}
for index, f := range m.flags {
if f.ID == flag.ID {
m.flags[index] = flag
Expand All @@ -62,6 +94,13 @@ func (m *InMemoryMockDao) UpdateFlag(ctx context.Context, flag model.FeatureFlag
}

func (m *InMemoryMockDao) DeleteFlagByID(ctx context.Context, id string) daoErr.DaoError {
if ctx.Value("error_delete") != nil {
if err, ok := ctx.Value("error_delete").(daoErr.DaoErrorCode); ok {
return daoErr.NewDaoError(err, fmt.Errorf("error on get flags"))
}
return daoErr.NewDaoError(daoErr.UnknownError, fmt.Errorf("error on get flags"))
}

newInmemoryFlagList := []model.FeatureFlag{}
for _, f := range m.flags {
if f.ID != id {
Expand All @@ -82,3 +121,7 @@ func (m *InMemoryMockDao) Ping() daoErr.DaoError {
func (m *InMemoryMockDao) OnPingReturnError(v bool) {
m.errorOnPing = v
}

func (m *InMemoryMockDao) SetFlags(flags []model.FeatureFlag) {
m.flags = flags
}
1 change: 1 addition & 0 deletions dao/postgres_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"

"github.com/go-feature-flag/app-api/dao/dbmodel"
daoErr "github.com/go-feature-flag/app-api/dao/err"
"github.com/go-feature-flag/app-api/model"
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/go-feature-flag/app-api

go 1.22.7
go 1.23.2

require (
github.com/google/uuid v1.6.0
Expand Down
86 changes: 50 additions & 36 deletions handler/flags.go
Original file line number Diff line number Diff line change
@@ -1,28 +1,37 @@
package handler

import (
"database/sql"
"errors"
"fmt"
daoErr "github.com/go-feature-flag/app-api/dao/err"
"github.com/labstack/echo/v4"
"net/http"
"time"

"github.com/go-feature-flag/app-api/dao"
daoErr "github.com/go-feature-flag/app-api/dao/err"
"github.com/go-feature-flag/app-api/model"
"github.com/go-feature-flag/app-api/util"
"github.com/google/uuid"
"github.com/lib/pq"
"github.com/labstack/echo/v4"
)

type FlagAPIHandlerOptions struct {
Clock util.Clock
}

type FlagAPIHandler struct {
dao dao.Flags
dao dao.Flags
options *FlagAPIHandlerOptions
}

// NewFlagAPIHandler creates a new instance of the FlagAPIHandler handler
// It is a controller class to handle the feature flag configuration logic
func NewFlagAPIHandler(dao dao.Flags) FlagAPIHandler {
return FlagAPIHandler{dao: dao}
func NewFlagAPIHandler(dao dao.Flags, options *FlagAPIHandlerOptions) FlagAPIHandler {
if options == nil {
options = &FlagAPIHandlerOptions{}
}
if options.Clock == nil {
options.Clock = util.DefaultClock{}
}
return FlagAPIHandler{dao: dao, options: options}
}

// GetAllFeatureFlags is returning the list of all the flags
Expand Down Expand Up @@ -87,8 +96,8 @@ func (f FlagAPIHandler) CreateNewFlag(c echo.Context) error {
if flag.ID == "" {
flag.ID = uuid.NewString()
}
flag.CreatedDate = time.Now()
flag.LastUpdatedDate = time.Now()
flag.CreatedDate = f.options.Clock.Now()
flag.LastUpdatedDate = f.options.Clock.Now()
// TODO: remove this line and extract the information from the token
flag.LastModifiedBy = "toto"

Expand All @@ -105,6 +114,9 @@ func (f FlagAPIHandler) CreateNewFlag(c echo.Context) error {

id, err := f.dao.CreateFlag(c.Request().Context(), flag)
if err != nil {
if err.Code() == daoErr.ConversionError {
return c.JSON(model.NewHTTPError(http.StatusBadRequest, err))
}
return c.JSON(model.NewHTTPError(http.StatusInternalServerError, err))
}
flag.ID = id
Expand All @@ -123,8 +135,17 @@ func validateFlag(flag model.FeatureFlag) (int, error) {
return status, err
}

if flag.VariationType == "" {
switch flag.VariationType {
case model.FlagTypeBoolean,
model.FlagTypeDouble,
model.FlagTypeInteger,
model.FlagTypeString,
model.FlagTypeJSON:
break
case "":
return http.StatusBadRequest, errors.New("flag type is required")
default:
return http.StatusBadRequest, fmt.Errorf("flag type %s not supported", flag.VariationType)
}

for _, rule := range flag.GetRules() {
Expand All @@ -137,10 +158,15 @@ func validateFlag(flag model.FeatureFlag) (int, error) {
}

func validateRule(rule *model.Rule, isDefault bool) (int, error) {
if rule == nil ||
(rule.ProgressiveRollout == nil &&
rule.Percentages == nil &&
(rule.VariationResult == nil || *rule.VariationResult == "")) {
if rule == nil || *rule == (model.Rule{}) {
if isDefault {
return http.StatusBadRequest, errors.New("flag default rule is required")
}
return http.StatusBadRequest, errors.New("targeting rule is nil")
}
if rule.ProgressiveRollout == nil &&
rule.Percentages == nil &&
(rule.VariationResult == nil || *rule.VariationResult == "") {
err := fmt.Errorf("invalid rule %s", rule.Name)
if isDefault {
err = errors.New("flag default rule is invalid")
Expand All @@ -150,7 +176,7 @@ func validateRule(rule *model.Rule, isDefault bool) (int, error) {

if !isDefault {
if rule.Query == "" {
return http.StatusBadRequest, errors.New("rule query is required")
return http.StatusBadRequest, errors.New("query is required for targeting rules")
}
}
return http.StatusOK, nil
Expand All @@ -168,8 +194,7 @@ func validateRule(rule *model.Rule, isDefault bool) (int, error) {
// @Failure 500 {object} model.HTTPError "Internal server error"
// @Router /v1/flags/{id} [put]
func (f FlagAPIHandler) UpdateFlagByID(c echo.Context) error {
// check if the flag exists
_, err := f.dao.GetFlagByID(c.Request().Context(), c.Param("id"))
retrievedFlag, err := f.dao.GetFlagByID(c.Request().Context(), c.Param("id"))
if err != nil {
return f.handleDaoError(c, err)
}
Expand All @@ -187,11 +212,12 @@ func (f FlagAPIHandler) UpdateFlagByID(c echo.Context) error {
if flag.ID == "" {
flag.ID = c.Param("id")
}
flag.LastUpdatedDate = time.Now()
flag.LastUpdatedDate = f.options.Clock.Now()
flag.CreatedDate = retrievedFlag.CreatedDate

err = f.dao.UpdateFlag(c.Request().Context(), flag)
if err != nil {
return c.JSON(model.NewHTTPError(http.StatusInternalServerError, err))
return f.handleDaoError(c, err)
}
return c.JSON(http.StatusOK, flag)
}
Expand All @@ -210,19 +236,7 @@ func (f FlagAPIHandler) DeleteFlagByID(c echo.Context) error {
idParam := c.Param("id")
err := f.dao.DeleteFlagByID(c.Request().Context(), idParam)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return c.JSON(model.NewHTTPError(http.StatusNotFound, fmt.Errorf("flag with id %s not found", idParam)))
}
var pgErr *pq.Error
if errors.As(err, &pgErr) {
switch pgErr.Code {
case "22P02":
return c.JSON(model.NewHTTPError(http.StatusBadRequest, fmt.Errorf("invalid UUID format")))
default:
return c.JSON(model.NewHTTPError(http.StatusInternalServerError, err))
}
}
return c.JSON(model.NewHTTPError(http.StatusInternalServerError, err))
return f.handleDaoError(c, err)
}
return c.JSON(http.StatusNoContent, nil)
}
Expand Down Expand Up @@ -251,10 +265,10 @@ func (f FlagAPIHandler) UpdateFeatureFlagStatus(c echo.Context) error {
}

flag.Disable = &statusUpdate.Disable
flag.LastUpdatedDate = time.Now()
flag.LastUpdatedDate = f.options.Clock.Now()
err = f.dao.UpdateFlag(c.Request().Context(), flag)
if err != nil {
return c.JSON(model.NewHTTPError(http.StatusInternalServerError, err))
return f.handleDaoError(c, err)
}
return c.JSON(http.StatusOK, flag)
}
Expand All @@ -263,7 +277,7 @@ func (f FlagAPIHandler) UpdateFeatureFlagStatus(c echo.Context) error {
func (f FlagAPIHandler) handleDaoError(c echo.Context, err daoErr.DaoError) error {
switch err.Code() {
case daoErr.NotFound:
return c.JSON(model.NewHTTPError(http.StatusNotFound, fmt.Errorf("flag with id %s not found", c.Param("id"))))
return c.JSON(model.NewHTTPError(http.StatusNotFound, fmt.Errorf("flag not found")))
case daoErr.InvalidUUID:
return c.JSON(model.NewHTTPError(http.StatusBadRequest, fmt.Errorf("invalid UUID format")))
default:
Expand Down
Loading

0 comments on commit 07a189c

Please sign in to comment.