Skip to content

Commit e83ecf8

Browse files
committed
add route customizer to gmux adapter
1 parent eb497ee commit e83ecf8

3 files changed

Lines changed: 97 additions & 3 deletions

File tree

adapters/humamux/humagmux_test.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,82 @@ package humamux
22

33
import (
44
"context"
5+
"fmt"
56
"net/http"
67
"net/http/httptest"
78
"strings"
89
"testing"
910
"time"
1011

1112
"github.com/danielgtaylor/huma/v2"
13+
"github.com/danielgtaylor/huma/v2/humatest"
1214
"github.com/gorilla/mux"
15+
"github.com/stretchr/testify/assert"
1316
)
1417

1518
var lastModified = time.Now()
1619

20+
type TestInput struct {
21+
Group string `path:"group"`
22+
Verbose bool `query:"verbose"`
23+
Auth string `header:"Authorization"`
24+
TestHeader string `header:"TestHeader"`
25+
Body struct {
26+
Name string `json:"name"`
27+
Email string `json:"email"`
28+
}
29+
}
30+
31+
// Test outputs (headers, body).
32+
type TestOutput struct {
33+
MyHeader string `header:"MyHeader"`
34+
TestHeader string `header:"TestHeader"`
35+
Body struct {
36+
Message string `json:"message"`
37+
}
38+
}
39+
40+
func testHandler(ctx context.Context, input *TestInput) (*TestOutput, error) {
41+
resp := &TestOutput{}
42+
resp.MyHeader = "my-value"
43+
resp.TestHeader = input.TestHeader
44+
resp.Body.Message = fmt.Sprintf("Hello, %s <%s>! (%s, %v, %s)", input.Body.Name, input.Body.Email, input.Group, input.Verbose, input.Auth)
45+
return resp, nil
46+
}
47+
48+
func TestCustomMiddleware(t *testing.T) {
49+
mw1 := func(next http.Handler) http.Handler {
50+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
51+
r.Header.Set("TestHeader", "test-value")
52+
next.ServeHTTP(w, r)
53+
})
54+
}
55+
56+
r := mux.NewRouter()
57+
api := New(r, huma.DefaultConfig("Test", "1.0.0"),
58+
WithRouteCustomizer(func(op *huma.Operation, r *mux.Route) {
59+
r.Handler(mw1(r.GetHandler()))
60+
}))
61+
62+
huma.Register(api, huma.Operation{
63+
OperationID: "test",
64+
Method: http.MethodGet,
65+
Path: "/{group}",
66+
}, testHandler)
67+
68+
testAPI := humatest.Wrap(t, api)
69+
resp := testAPI.Do(http.MethodGet, "/foo",
70+
"Host: localhost",
71+
"Authorization: Bearer abc123",
72+
strings.NewReader(`{"name": "Daniel", "email": "daniel@example.com"}`),
73+
)
74+
75+
assert.Equal(t, http.StatusOK, resp.Code)
76+
assert.Equal(t, "my-value", resp.Header().Get("MyHeader"))
77+
assert.Equal(t, "test-value", resp.Header().Get("TestHeader"))
78+
79+
}
80+
1781
func BenchmarkHumaGorillaMux(b *testing.B) {
1882
type GreetingInput struct {
1983
ID string `path:"id"`

adapters/humamux/humamux.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,23 +132,27 @@ func (c *gmuxContext) BodyWriter() io.Writer {
132132
}
133133

134134
type gMux struct {
135+
options
135136
router *mux.Router
136137
}
137138

138139
func (a *gMux) Handle(op *huma.Operation, handler func(huma.Context)) {
139-
a.router.
140+
route := a.router.
140141
NewRoute().
141142
Path(op.Path).
142143
Methods(op.Method).
143144
HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
144145
handler(&gmuxContext{op: op, r: r, w: w})
145146
})
147+
if a.routeCustomizer != nil {
148+
a.routeCustomizer(op, route)
149+
}
146150
}
147151

148152
func (a *gMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
149153
a.router.ServeHTTP(w, r)
150154
}
151155

152-
func New(r *mux.Router, config huma.Config) huma.API {
153-
return huma.NewAPI(config, &gMux{router: r})
156+
func New(r *mux.Router, config huma.Config, options ...Option) huma.API {
157+
return huma.NewAPI(config, &gMux{router: r, options: parseOptions(options)})
154158
}

adapters/humamux/options.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package humamux
2+
3+
import (
4+
"github.com/danielgtaylor/huma/v2"
5+
"github.com/gorilla/mux"
6+
)
7+
8+
type Option func(*options)
9+
10+
func WithRouteCustomizer(f func(op *huma.Operation, r *mux.Route)) Option {
11+
return func(o *options) {
12+
o.routeCustomizer = f
13+
}
14+
}
15+
16+
func parseOptions(optionList []Option) options {
17+
var optns options
18+
for _, opt := range optionList {
19+
opt(&optns)
20+
}
21+
return optns
22+
}
23+
24+
type options struct {
25+
routeCustomizer func(op *huma.Operation, r *mux.Route)
26+
}

0 commit comments

Comments
 (0)