Skip to content

Commit 8ee657b

Browse files
authored
Merge pull request #6 from nodece/main
fix: remove leaderMiddleware
2 parents ed4bdac + 6365fbb commit 8ee657b

File tree

2 files changed

+51
-76
lines changed

2 files changed

+51
-76
lines changed

http/service.go

Lines changed: 40 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@ import (
1212
"strings"
1313
"time"
1414

15+
"golang.org/x/net/http2"
16+
17+
"github.com/hashicorp/raft"
18+
1519
"github.com/hashicorp/go-multierror"
1620

21+
"github.com/go-chi/chi"
1722
jsoniter "github.com/json-iterator/go"
1823
"github.com/pkg/errors"
19-
"golang.org/x/net/http2"
20-
21-
"github.com/go-chi/chi"
2224

2325
"github.com/casbin/hraft-dispatcher/command"
2426

@@ -93,12 +95,12 @@ func NewService(ln net.Listener, tlsConfig *tls.Config, store Store) (*Service,
9395
}
9496

9597
r := chi.NewRouter()
96-
r.With(s.leaderMiddleware).Route("/policies", func(r chi.Router) {
98+
r.Route("/policies", func(r chi.Router) {
9799
r.Put("/add", s.handleAddPolicy)
98100
r.Put("/update", s.handleUpdatePolicy)
99101
r.Put("/remove", s.handleRemovePolicy)
100102
})
101-
r.With(s.leaderMiddleware).Route("/nodes", func(r chi.Router) {
103+
r.Route("/nodes", func(r chi.Router) {
102104
r.Put("/join", s.handleJoinNode)
103105
r.Put("/remove", s.handleRemoveNode)
104106
})
@@ -125,38 +127,17 @@ func NewService(ln net.Listener, tlsConfig *tls.Config, store Store) (*Service,
125127
return s, nil
126128
}
127129

128-
// leaderMiddleware checks whether the current node is the leader.
129-
// If this current node is not a leader, the request is forwarded to the leader node.
130-
func (s *Service) leaderMiddleware(next http.Handler) http.Handler {
131-
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
132-
isLeader, leaderAddr := s.store.Leader()
133-
if !isLeader {
134-
if len(leaderAddr) == 0 {
135-
s.logger.Error("failed to get the leader address")
136-
w.WriteHeader(http.StatusServiceUnavailable)
137-
return
138-
}
139-
redirectURL := s.getRedirectURL(r, leaderAddr)
140-
http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect)
141-
return
142-
}
143-
next.ServeHTTP(w, r)
144-
})
145-
}
146-
147130
// Start starts this service.
148131
// It always returns a non-nil error. After Shutdown or Close, the returned error is http.ErrServerClosed.
149132
func (s *Service) Start() error {
150133
go func() {
151134
var err error
152135
if s.tlsConfig == nil {
153136
s.logger.Info(fmt.Sprintf("listening and serving HTTP on %s", s.ln.Addr()))
154-
err = s.srv.Serve(s.ln)
155137
} else {
156138
s.logger.Info(fmt.Sprintf("listening and serving HTTPS on %s", s.ln.Addr()))
157-
err = s.srv.Serve(s.ln)
158-
err = s.srv.ServeTLS(s.ln, "", "")
159139
}
140+
err = s.srv.Serve(s.ln)
160141
if err != nil && err != http.ErrServerClosed {
161142
s.logger.Error("unable to serve http", zap.Error(err))
162143
}
@@ -208,9 +189,30 @@ func (s *Service) GetScheme() string {
208189
return scheme
209190
}
210191

211-
// handleNodes handles request of nodes.
212-
func (s *Service) handleNodes(w http.ResponseWriter, r *http.Request) {
213-
w.WriteHeader(http.StatusServiceUnavailable)
192+
// handleStoreResponse checks the error returned by store.
193+
// If the error is nil, the server returns http.StatusOK.
194+
// If the error is raft.ErrNotLeader, the server forward the request to the leader node,
195+
// otherwise the server returns http.StatusServiceUnavailable.
196+
func (s *Service) handleStoreResponse(err error, w http.ResponseWriter, r *http.Request) {
197+
if err == nil {
198+
w.WriteHeader(http.StatusOK)
199+
return
200+
}
201+
if err == raft.ErrNotLeader {
202+
isLeader, leaderAddr := s.store.Leader()
203+
if !isLeader {
204+
if len(leaderAddr) == 0 {
205+
s.logger.Error("failed to get the leader address")
206+
w.WriteHeader(http.StatusServiceUnavailable)
207+
return
208+
}
209+
redirectURL := s.getRedirectURL(r, leaderAddr)
210+
http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect)
211+
return
212+
}
213+
} else {
214+
http.Error(w, err.Error(), http.StatusServiceUnavailable)
215+
}
214216
}
215217

216218
// handleAddPolicy handles the request to add a set of rules.
@@ -227,10 +229,7 @@ func (s *Service) handleAddPolicy(w http.ResponseWriter, r *http.Request) {
227229
return
228230
}
229231
err = s.store.AddPolicies(&cmd)
230-
if err != nil {
231-
http.Error(w, err.Error(), http.StatusServiceUnavailable)
232-
return
233-
}
232+
s.handleStoreResponse(err, w, r)
234233
}
235234

236235
// handleRemovePolicy handles the request to remove a set of rules.
@@ -239,10 +238,7 @@ func (s *Service) handleRemovePolicy(w http.ResponseWriter, r *http.Request) {
239238
switch removeType {
240239
case "all":
241240
err := s.store.ClearPolicy()
242-
if err != nil {
243-
http.Error(w, err.Error(), http.StatusBadRequest)
244-
return
245-
}
241+
s.handleStoreResponse(err, w, r)
246242
case "filtered":
247243
data, err := ioutil.ReadAll(r.Body)
248244
if err != nil {
@@ -256,10 +252,7 @@ func (s *Service) handleRemovePolicy(w http.ResponseWriter, r *http.Request) {
256252
return
257253
}
258254
err = s.store.RemoveFilteredPolicy(&cmd)
259-
if err != nil {
260-
http.Error(w, err.Error(), http.StatusBadRequest)
261-
return
262-
}
255+
s.handleStoreResponse(err, w, r)
263256
case "":
264257
data, err := ioutil.ReadAll(r.Body)
265258
if err != nil {
@@ -273,10 +266,7 @@ func (s *Service) handleRemovePolicy(w http.ResponseWriter, r *http.Request) {
273266
return
274267
}
275268
err = s.store.RemovePolicies(&cmd)
276-
if err != nil {
277-
http.Error(w, err.Error(), http.StatusServiceUnavailable)
278-
return
279-
}
269+
s.handleStoreResponse(err, w, r)
280270
default:
281271
w.WriteHeader(http.StatusBadRequest)
282272
}
@@ -299,10 +289,7 @@ func (s *Service) handleUpdatePolicy(w http.ResponseWriter, r *http.Request) {
299289
return
300290
}
301291
err = s.store.UpdatePolicies(&cmd)
302-
if err != nil {
303-
http.Error(w, err.Error(), http.StatusServiceUnavailable)
304-
return
305-
}
292+
s.handleStoreResponse(err, w, r)
306293
case "":
307294
data, err := ioutil.ReadAll(r.Body)
308295
if err != nil {
@@ -316,10 +303,7 @@ func (s *Service) handleUpdatePolicy(w http.ResponseWriter, r *http.Request) {
316303
return
317304
}
318305
err = s.store.UpdatePolicy(&cmd)
319-
if err != nil {
320-
http.Error(w, err.Error(), http.StatusServiceUnavailable)
321-
return
322-
}
306+
s.handleStoreResponse(err, w, r)
323307
default:
324308
w.WriteHeader(http.StatusBadRequest)
325309
}
@@ -338,10 +322,7 @@ func (s *Service) handleJoinNode(w http.ResponseWriter, r *http.Request) {
338322
return
339323
}
340324
err = s.store.JoinNode(cmd.Id, cmd.Address)
341-
if err != nil {
342-
http.Error(w, err.Error(), http.StatusServiceUnavailable)
343-
return
344-
}
325+
s.handleStoreResponse(err, w, r)
345326
}
346327

347328
func (s *Service) handleRemoveNode(w http.ResponseWriter, r *http.Request) {
@@ -357,10 +338,7 @@ func (s *Service) handleRemoveNode(w http.ResponseWriter, r *http.Request) {
357338
return
358339
}
359340
err = s.store.RemoveNode(cmd.Id)
360-
if err != nil {
361-
http.Error(w, err.Error(), http.StatusServiceUnavailable)
362-
return
363-
}
341+
s.handleStoreResponse(err, w, r)
364342
}
365343

366344
func (s *Service) Addr() string {

http/service_test.go

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ import (
1313
"testing"
1414
"time"
1515

16+
"github.com/hashicorp/raft"
17+
"github.com/pkg/errors"
18+
1619
"github.com/casbin/hraft-dispatcher/command"
1720
"github.com/casbin/hraft-dispatcher/http/mocks"
1821
"github.com/golang/mock/gomock"
@@ -48,7 +51,7 @@ func TestRedirect(t *testing.T) {
4851
assert.Equal(t, expectedURL, actualURL)
4952
}
5053

51-
func TestLeader(t *testing.T) {
54+
func TestNotLeaderError(t *testing.T) {
5255
ctl := gomock.NewController(t)
5356
defer ctl.Finish()
5457

@@ -58,17 +61,18 @@ func TestLeader(t *testing.T) {
5861
s, err := NewService(ln, nil, store)
5962
assert.NoError(t, err)
6063

61-
store.EXPECT().Leader().Return(true, "127.0.0.1:6790")
62-
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
63-
r := s.leaderMiddleware(nextHandler)
6464
w := httptest.NewRecorder()
65-
r.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodPut, "https://testing", nil))
65+
s.handleStoreResponse(nil, w, httptest.NewRequest(http.MethodPut, "https://testing", nil))
6666
assert.Equal(t, w.Code, http.StatusOK)
6767

68+
w = httptest.NewRecorder()
69+
s.handleStoreResponse(errors.New("test error"), w, httptest.NewRequest(http.MethodPut, "https://testing", nil))
70+
assert.Equal(t, w.Code, http.StatusServiceUnavailable)
71+
6872
store.EXPECT().Leader().Return(false, "127.0.0.1:6790")
6973
w = httptest.NewRecorder()
70-
r.ServeHTTP(w, httptest.NewRequest(http.MethodPut, "https://testing", nil))
71-
assert.Equal(t, w.Header().Get("Location"), "https://127.0.0.1:6790")
74+
s.handleStoreResponse(raft.ErrNotLeader, w, httptest.NewRequest(http.MethodPut, "https://testing/add", nil))
75+
assert.Equal(t, w.Header().Get("Location"), "https://127.0.0.1:6790/add")
7276
assert.Equal(t, w.Code, http.StatusTemporaryRedirect)
7377
}
7478

@@ -97,7 +101,6 @@ func TestAddPolicy(t *testing.T) {
97101
PType: "p",
98102
Rules: []*command.StringArray{{Items: []string{"role:admin", "/", "*"}}},
99103
}
100-
store.EXPECT().Leader().Return(true, s.Addr())
101104
store.EXPECT().AddPolicies(addPolicyRequest).Return(nil)
102105

103106
b, err := jsoniter.Marshal(addPolicyRequest)
@@ -135,7 +138,6 @@ func TestRemovePolicy(t *testing.T) {
135138
PType: "p",
136139
Rules: []*command.StringArray{{Items: []string{"role:admin", "/", "*"}}},
137140
}
138-
store.EXPECT().Leader().Return(true, s.Addr())
139141
store.EXPECT().RemovePolicies(removePolicyRequest).Return(nil)
140142

141143
b, err := jsoniter.Marshal(removePolicyRequest)
@@ -174,7 +176,6 @@ func TestRemoveFilteredPolicy(t *testing.T) {
174176
FieldIndex: 0,
175177
FieldValues: []string{"role:admin"},
176178
}
177-
store.EXPECT().Leader().Return(true, s.Addr())
178179
store.EXPECT().RemoveFilteredPolicy(removeFilteredPolicyRequest).Return(nil)
179180

180181
b, err := jsoniter.Marshal(removeFilteredPolicyRequest)
@@ -213,7 +214,6 @@ func TestUpdatePolicy(t *testing.T) {
213214
OldRule: []string{"role:admin", "/", "*"},
214215
NewRule: []string{"role:admin", "/admin", "*"},
215216
}
216-
store.EXPECT().Leader().Return(true, s.Addr())
217217
store.EXPECT().UpdatePolicy(updatePolicyRequest).Return(nil)
218218

219219
b, err := jsoniter.Marshal(updatePolicyRequest)
@@ -246,7 +246,6 @@ func TestClearPolicy(t *testing.T) {
246246
assert.NoError(t, err)
247247
defer s.Stop(context.Background())
248248

249-
store.EXPECT().Leader().Return(true, s.Addr())
250249
store.EXPECT().ClearPolicy().Return(nil)
251250

252251
r, err := http.NewRequest(http.MethodPut, fmt.Sprintf("https://%s/policies/remove?type=all", s.Addr()), nil)
@@ -281,7 +280,6 @@ func TestJoinNode(t *testing.T) {
281280
Id: "test-main",
282281
Address: "10.0.7.10",
283282
}
284-
store.EXPECT().Leader().Return(true, s.Addr())
285283
store.EXPECT().JoinNode(addNodeRequest.Id, addNodeRequest.Address).Return(nil)
286284

287285
b, err := jsoniter.Marshal(addNodeRequest)
@@ -342,7 +340,6 @@ func TestRemoveNode(t *testing.T) {
342340
removeNodeRequest := &command.RemoveNodeRequest{
343341
Id: "test-main",
344342
}
345-
store.EXPECT().Leader().Return(true, s.Addr())
346343
store.EXPECT().RemoveNode(removeNodeRequest.Id).Return(nil)
347344

348345
b, err := jsoniter.Marshal(removeNodeRequest)

0 commit comments

Comments
 (0)