Skip to content

Commit db05736

Browse files
authored
Merge pull request #171 from kaleido-io/fix-customsubsystem-name
[ffapi] [config] application/x-www-form-urlencoded Params and Fixing Usability Bugs
2 parents 3e3a001 + 5e09223 commit db05736

File tree

8 files changed

+214
-38
lines changed

8 files changed

+214
-38
lines changed

pkg/config/config.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,13 +162,19 @@ func RootConfigReset(setServiceDefaults ...func()) {
162162
i18n.SetLang(viper.GetString(string(Lang)))
163163
}
164164

165+
var envPrefix = "firefly"
166+
167+
func SetEnvPrefix(prefix string) {
168+
envPrefix = prefix
169+
}
170+
165171
// ReadConfig initializes the config
166172
func ReadConfig(cfgSuffix, cfgFile string) error {
167173
keysMutex.Lock() // must only call viper directly here (as we already hold the lock)
168174
defer keysMutex.Unlock()
169175

170176
// Set precedence order for reading config location
171-
viper.SetEnvPrefix("firefly")
177+
viper.SetEnvPrefix(envPrefix)
172178
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
173179
viper.AutomaticEnv()
174180
viper.SetConfigType("yaml")

pkg/config/config_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package config
1919
import (
2020
"context"
2121
"fmt"
22+
"github.com/stretchr/testify/require"
2223
"os"
2324
"path"
2425
"strings"
@@ -491,3 +492,24 @@ func TestConfigWatchE2E(t *testing.T) {
491492
}()
492493

493494
}
495+
496+
func TestSetEnvPrefix(t *testing.T) {
497+
tmpDir := t.TempDir()
498+
configPath := path.Join(tmpDir, "test.yaml")
499+
500+
// Create the file
501+
os.WriteFile(configPath, []byte(`{"conf": "one"}`), 0664)
502+
503+
os.Setenv("TEST_UT_CONF", "two")
504+
SetEnvPrefix("test")
505+
506+
RootConfigReset()
507+
cfg := RootSection("ut")
508+
cfg.AddKnownKey("conf")
509+
510+
err := ReadConfig("yaml", configPath)
511+
require.NoError(t, err)
512+
assert.Equal(t, "test", viper.GetViper().GetEnvPrefix())
513+
514+
assert.Equal(t, "two", cfg.GetString("conf"))
515+
}

pkg/ffapi/apiserver.go

Lines changed: 55 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ const APIServerMetricsSubSystemName = "api_server_rest"
4646
type APIServer interface {
4747
Serve(ctx context.Context) error
4848
Started() <-chan struct{}
49-
MuxRouter(ctx context.Context) *mux.Router
49+
MuxRouter(ctx context.Context) (*mux.Router, error)
5050
APIPublicURL() string // valid to call after server is successfully started
5151
}
5252

@@ -67,6 +67,7 @@ type apiServer[T any] struct {
6767
livenessPath string
6868
monitoringPublicURL string
6969
mux *mux.Router
70+
oah *OpenAPIHandlerFactory
7071

7172
APIServerOptions[T]
7273
}
@@ -97,6 +98,14 @@ type APIServerRouteExt[T any] struct {
9798
// NewAPIServer makes a new server, with the specified configuration, and
9899
// the supplied wrapper function - which will inject
99100
func NewAPIServer[T any](ctx context.Context, options APIServerOptions[T]) APIServer {
101+
if options.APIConfig == nil {
102+
panic("APIConfig is required")
103+
}
104+
105+
if options.MonitoringConfig == nil {
106+
panic("MonitoringConfig is required")
107+
}
108+
100109
as := &apiServer[T]{
101110
defaultFilterLimit: options.APIConfig.GetUint64(ConfAPIDefaultFilterLimit),
102111
maxFilterLimit: options.APIConfig.GetUint64(ConfAPIMaxFilterLimit),
@@ -119,27 +128,34 @@ func NewAPIServer[T any](ctx context.Context, options APIServerOptions[T]) APISe
119128
as.FavIcon32 = ffLogo16
120129
}
121130

122-
metricsSubsystemName := APIServerMetricsSubSystemName
123-
if options.MetricsSubsystemName != "" {
124-
metricsSubsystemName = options.MetricsSubsystemName
125-
}
126-
127131
_ = as.MetricsRegistry.NewHTTPMetricsInstrumentationsForSubsystem(
128132
ctx,
129-
metricsSubsystemName,
133+
as.metricsSubsystemName(),
130134
true,
131135
prometheus.DefBuckets,
132136
map[string]string{},
133137
)
134138
return as
135139
}
136140

141+
func (as *apiServer[T]) metricsSubsystemName() string {
142+
metricsSubsystemName := APIServerMetricsSubSystemName
143+
if as.MetricsSubsystemName != "" {
144+
metricsSubsystemName = as.MetricsSubsystemName
145+
}
146+
return metricsSubsystemName
147+
}
148+
137149
// Can be called before Serve, but MUST use the background context if so
138-
func (as *apiServer[T]) MuxRouter(ctx context.Context) *mux.Router {
150+
func (as *apiServer[T]) MuxRouter(ctx context.Context) (*mux.Router, error) {
139151
if as.mux == nil {
140-
as.mux = as.createMuxRouter(ctx)
152+
var err error
153+
if as.mux, err = as.createMuxRouter(ctx); err != nil {
154+
return nil, err
155+
}
156+
141157
}
142-
return as.mux
158+
return as.mux, nil
143159
}
144160

145161
// Serve is the main entry point for the API Server
@@ -155,17 +171,26 @@ func (as *apiServer[T]) Serve(ctx context.Context) (err error) {
155171
httpErrChan := make(chan error)
156172
monitoringErrChan := make(chan error)
157173

158-
apiHTTPServer, err := httpserver.NewHTTPServer(ctx, "api", as.MuxRouter(ctx), httpErrChan, as.APIConfig, as.CORSConfig, &httpserver.ServerOptions{
174+
apiMux, err := as.MuxRouter(ctx)
175+
if err != nil {
176+
return err
177+
}
178+
apiHTTPServer, err := httpserver.NewHTTPServer(ctx, "api", apiMux, httpErrChan, as.APIConfig, as.CORSConfig, &httpserver.ServerOptions{
159179
MaximumRequestTimeout: as.requestMaxTimeout,
160180
})
161181
if err != nil {
162182
return err
163183
}
164184
as.apiPublicURL = buildPublicURL(as.APIConfig, apiHTTPServer.Addr())
185+
as.oah.StaticPublicURL = as.apiPublicURL
165186
go apiHTTPServer.ServeHTTP(ctx)
166187

167188
if as.monitoringEnabled {
168-
monitoringHTTPServer, err := httpserver.NewHTTPServer(ctx, "monitoring", as.createMonitoringMuxRouter(ctx), monitoringErrChan, as.MonitoringConfig, as.CORSConfig, &httpserver.ServerOptions{
189+
monitoringMux, err := as.createMonitoringMuxRouter(ctx)
190+
if err != nil {
191+
return err
192+
}
193+
monitoringHTTPServer, err := httpserver.NewHTTPServer(ctx, "monitoring", monitoringMux, monitoringErrChan, as.MonitoringConfig, as.CORSConfig, &httpserver.ServerOptions{
169194
MaximumRequestTimeout: as.requestMaxTimeout,
170195
})
171196
if err != nil {
@@ -249,12 +274,12 @@ func (as *apiServer[T]) handlerFactory() *HandlerFactory {
249274
}
250275
}
251276

252-
func (as *apiServer[T]) createMuxRouter(ctx context.Context) *mux.Router {
277+
func (as *apiServer[T]) createMuxRouter(ctx context.Context) (*mux.Router, error) {
253278
r := mux.NewRouter().UseEncodedPath()
254279
hf := as.handlerFactory()
255280

256281
if as.monitoringEnabled {
257-
h, _ := as.MetricsRegistry.GetHTTPMetricsInstrumentationsMiddlewareForSubsystem(ctx, APIServerMetricsSubSystemName)
282+
h, _ := as.MetricsRegistry.GetHTTPMetricsInstrumentationsMiddlewareForSubsystem(ctx, as.metricsSubsystemName())
258283
r.Use(h)
259284
}
260285

@@ -273,30 +298,33 @@ func (as *apiServer[T]) createMuxRouter(ctx context.Context) *mux.Router {
273298
}
274299
}
275300
if ce.JSONHandler != nil || ce.UploadHandler != nil || ce.StreamHandler != nil {
301+
if strings.HasPrefix(route.Path, "/") {
302+
return nil, fmt.Errorf("route path '%s' must not start with '/'", route.Path)
303+
}
276304
r.HandleFunc(fmt.Sprintf("/api/v1/%s", route.Path), as.routeHandler(hf, route)).
277305
Methods(route.Method)
278306
}
279307
}
280308

281-
oah := &OpenAPIHandlerFactory{
309+
as.oah = &OpenAPIHandlerFactory{
282310
BaseSwaggerGenOptions: SwaggerGenOptions{
283311
Title: as.Description,
284312
Version: "1.0",
285313
PanicOnMissingDescription: as.PanicOnMissingDescription,
286314
DefaultRequestTimeout: as.requestTimeout,
287315
SupportFieldRedaction: as.SupportFieldRedaction,
288316
},
289-
StaticPublicURL: as.apiPublicURL,
317+
StaticPublicURL: as.apiPublicURL, // this is most likely not yet set, we'll ensure its set later on
290318
}
291-
r.HandleFunc(`/api/swagger.yaml`, hf.APIWrapper(oah.OpenAPIHandler(`/api/v1`, OpenAPIFormatYAML, as.Routes)))
292-
r.HandleFunc(`/api/swagger.json`, hf.APIWrapper(oah.OpenAPIHandler(`/api/v1`, OpenAPIFormatJSON, as.Routes)))
293-
r.HandleFunc(`/api/openapi.yaml`, hf.APIWrapper(oah.OpenAPIHandler(`/api/v1`, OpenAPIFormatYAML, as.Routes)))
294-
r.HandleFunc(`/api/openapi.json`, hf.APIWrapper(oah.OpenAPIHandler(`/api/v1`, OpenAPIFormatJSON, as.Routes)))
295-
r.HandleFunc(`/api`, hf.APIWrapper(oah.SwaggerUIHandler(`/api/openapi.yaml`)))
319+
r.HandleFunc(`/api/swagger.yaml`, hf.APIWrapper(as.oah.OpenAPIHandler(`/api/v1`, OpenAPIFormatYAML, as.Routes)))
320+
r.HandleFunc(`/api/swagger.json`, hf.APIWrapper(as.oah.OpenAPIHandler(`/api/v1`, OpenAPIFormatJSON, as.Routes)))
321+
r.HandleFunc(`/api/openapi.yaml`, hf.APIWrapper(as.oah.OpenAPIHandler(`/api/v1`, OpenAPIFormatYAML, as.Routes)))
322+
r.HandleFunc(`/api/openapi.json`, hf.APIWrapper(as.oah.OpenAPIHandler(`/api/v1`, OpenAPIFormatJSON, as.Routes)))
323+
r.HandleFunc(`/api`, hf.APIWrapper(as.oah.SwaggerUIHandler(`/api/openapi.yaml`)))
296324
r.HandleFunc(`/favicon{any:.*}.png`, favIconsHandler(as.FavIcon16, as.FavIcon32))
297325

298326
r.NotFoundHandler = hf.APIWrapper(as.notFoundHandler)
299-
return r
327+
return r, nil
300328
}
301329

302330
func (as *apiServer[T]) notFoundHandler(res http.ResponseWriter, req *http.Request) (status int, err error) {
@@ -309,7 +337,7 @@ func (as *apiServer[T]) emptyJSONHandler(res http.ResponseWriter, _ *http.Reques
309337
return 200, nil
310338
}
311339

312-
func (as *apiServer[T]) createMonitoringMuxRouter(ctx context.Context) *mux.Router {
340+
func (as *apiServer[T]) createMonitoringMuxRouter(ctx context.Context) (*mux.Router, error) {
313341
r := mux.NewRouter().UseEncodedPath()
314342
hf := as.handlerFactory() // TODO separate factory for monitoring ??
315343

@@ -322,12 +350,12 @@ func (as *apiServer[T]) createMonitoringMuxRouter(ctx context.Context) *mux.Rout
322350

323351
for _, route := range as.MonitoringRoutes {
324352
path := route.Path
325-
if !strings.HasPrefix(route.Path, "/") {
326-
path = fmt.Sprintf("/%s", route.Path)
353+
if strings.HasPrefix(route.Path, "/") {
354+
return nil, fmt.Errorf("route path '%s' must not start with '/'", route.Path)
327355
}
328-
r.HandleFunc(path, as.routeHandler(hf, route)).Methods(route.Method)
356+
r.HandleFunc("/"+path, as.routeHandler(hf, route)).Methods(route.Method)
329357
}
330358

331359
r.NotFoundHandler = hf.APIWrapper(as.notFoundHandler)
332-
return r
360+
return r, nil
333361
}

pkg/ffapi/apiserver_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,11 @@ func newTestAPIServer(t *testing.T, start bool) (*utManager, *apiServer[*utManag
135135
// request and that's the "T" on the APIServer
136136
return um, um.mockEnrichErr
137137
},
138-
Description: "unit testing",
139-
APIConfig: apiConfig,
140-
MonitoringConfig: monitoringConfig,
141-
CORSConfig: corsConfig,
138+
Description: "unit testing",
139+
APIConfig: apiConfig,
140+
MonitoringConfig: monitoringConfig,
141+
CORSConfig: corsConfig,
142+
MetricsSubsystemName: "apiserver_ut",
142143
})
143144
done := make(chan struct{})
144145
if start {

pkg/ffapi/handler.go

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"bytes"
2121
"context"
2222
"encoding/json"
23+
"fmt"
2324
"io"
2425
"mime/multipart"
2526
"net/http"
@@ -71,7 +72,6 @@ type multipartState struct {
7172
}
7273

7374
func (hs *HandlerFactory) getFilePart(req *http.Request) (*multipartState, error) {
74-
7575
formParams := make(map[string]string)
7676
ctx := req.Context()
7777
l := log.L(ctx)
@@ -104,6 +104,25 @@ func (hs *HandlerFactory) getFilePart(req *http.Request) (*multipartState, error
104104
}
105105
}
106106

107+
func (hs *HandlerFactory) getFormParams(req *http.Request) (map[string]string, error) {
108+
if err := req.ParseForm(); err != nil {
109+
return nil, i18n.WrapError(req.Context(), err, i18n.MsgParseFormError)
110+
}
111+
form := make(map[string]string, len(req.Form))
112+
for k, v := range req.Form {
113+
if len(v) < 1 {
114+
continue
115+
}
116+
117+
if len(v) > 1 {
118+
return nil, i18n.WrapError(req.Context(), fmt.Errorf("multi-value form parameters for '%s' are not currently supported", k), i18n.MsgParseFormError)
119+
}
120+
121+
form[k] = v[0]
122+
}
123+
return form, nil
124+
}
125+
107126
func (hs *HandlerFactory) getParams(req *http.Request, route *Route) (queryParams, pathParams map[string]string, queryArrayParams map[string][]string) {
108127
queryParams = make(map[string]string)
109128
pathParams = make(map[string]string)
@@ -154,7 +173,7 @@ func (hs *HandlerFactory) RouteHandler(route *Route) http.HandlerFunc {
154173
if route.JSONInputValue != nil {
155174
jsonInput = route.JSONInputValue()
156175
}
157-
var queryParams, pathParams map[string]string
176+
var queryParams, pathParams, formParams map[string]string
158177
var queryArrayParams map[string][]string
159178
var multipart *multipartState
160179
contentType := req.Header.Get("Content-Type")
@@ -185,6 +204,11 @@ func (hs *HandlerFactory) RouteHandler(route *Route) http.HandlerFunc {
185204
d.UseNumber()
186205
err = d.Decode(&jsonInput)
187206
}
207+
case strings.HasPrefix(strings.ToLower(contentType), "application/x-www-form-urlencoded"):
208+
formParams, err = hs.getFormParams(req)
209+
if err != nil {
210+
return 400, err
211+
}
188212
case strings.HasPrefix(strings.ToLower(contentType), "text/plain"):
189213
default:
190214
return 415, i18n.NewError(req.Context(), i18n.MsgInvalidContentType, contentType)
@@ -226,6 +250,9 @@ func (hs *HandlerFactory) RouteHandler(route *Route) http.HandlerFunc {
226250
output, err = route.FormUploadHandler(r)
227251
case route.StreamHandler != nil:
228252
output, err = route.StreamHandler(r)
253+
case formParams != nil:
254+
r.FP = formParams
255+
fallthrough
229256
default:
230257
output, err = route.JSONHandler(r)
231258
}

pkg/ffapi/handler_test.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"mime/multipart"
2626
"net/http"
2727
"net/http/httptest"
28+
"net/url"
2829
"strings"
2930
"testing"
3031
"time"
@@ -650,3 +651,60 @@ func TestBasePathParameters(t *testing.T) {
650651
assert.NoError(t, err)
651652
assert.Equal(t, 201, res.StatusCode)
652653
}
654+
655+
func TestPOSTFormParams(t *testing.T) {
656+
s, _, done := newTestServer(t, []*Route{{
657+
Name: "testRoute",
658+
Path: "/test",
659+
Method: http.MethodPost,
660+
FormParams: []*FormParam{
661+
{Name: "foo"},
662+
},
663+
JSONInputValue: nil,
664+
JSONOutputValue: func() interface{} { return make(map[string]interface{}) },
665+
JSONOutputCodes: []int{201},
666+
JSONHandler: func(r *APIRequest) (output interface{}, err error) {
667+
assert.Equal(t, "baz", r.PP["param"])
668+
assert.Equal(t, "bar", r.FP["foo"])
669+
return map[string]interface{}{}, nil
670+
},
671+
}}, "/base-path/{param}", []*PathParam{
672+
{Name: "param"},
673+
})
674+
defer done()
675+
676+
val := url.Values{
677+
"foo": {"bar"},
678+
}
679+
res, err := http.PostForm(fmt.Sprintf("http://%s/base-path/baz/test", s.Addr()), val)
680+
assert.NoError(t, err)
681+
assert.Equal(t, 201, res.StatusCode)
682+
}
683+
684+
func TestPOSTFormParamsMultiValueUnsupported(t *testing.T) {
685+
s, _, done := newTestServer(t, []*Route{{
686+
Name: "testRoute",
687+
Path: "/test",
688+
Method: http.MethodPost,
689+
FormParams: []*FormParam{
690+
{Name: "foo"},
691+
},
692+
JSONInputValue: nil,
693+
JSONOutputValue: func() interface{} { return make(map[string]interface{}) },
694+
JSONOutputCodes: []int{201},
695+
JSONHandler: func(r *APIRequest) (output interface{}, err error) {
696+
t.Fail() // we shouldn't get here
697+
return map[string]interface{}{}, nil
698+
},
699+
}}, "/base-path/{param}", []*PathParam{
700+
{Name: "param"},
701+
})
702+
defer done()
703+
704+
val := url.Values{
705+
"foo": {"bar", "foo2"},
706+
}
707+
res, err := http.PostForm(fmt.Sprintf("http://%s/base-path/baz/test", s.Addr()), val)
708+
assert.NoError(t, err)
709+
assert.Equal(t, 400, res.StatusCode)
710+
}

0 commit comments

Comments
 (0)