Skip to content

Commit 06f2d96

Browse files
committed
refactor: change FindProviders to use "limit int" instead of "stream bool"
1 parent 848098f commit 06f2d96

File tree

3 files changed

+49
-19
lines changed

3 files changed

+49
-19
lines changed

routing/http/client/client_test.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ import (
2727

2828
type mockContentRouter struct{ mock.Mock }
2929

30-
func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid, stream bool) (iter.ResultIter[types.ProviderResponse], error) {
31-
args := m.Called(ctx, key, stream)
30+
func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid, limit int) (iter.ResultIter[types.ProviderResponse], error) {
31+
args := m.Called(ctx, key, limit)
3232
return args.Get(0).(iter.ResultIter[types.ProviderResponse]), args.Error(1)
3333
}
3434
func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *server.BitswapWriteProvideRequest) (time.Duration, error) {
@@ -302,8 +302,11 @@ func TestClient_FindProviders(t *testing.T) {
302302

303303
findProvsIter := iter.FromSlice(c.routerProvs)
304304

305-
router.On("FindProviders", mock.Anything, cid, c.expStreamingResponse).
306-
Return(findProvsIter, c.routerErr)
305+
if c.expStreamingResponse {
306+
router.On("FindProviders", mock.Anything, cid, 0).Return(findProvsIter, c.routerErr)
307+
} else {
308+
router.On("FindProviders", mock.Anything, cid, 20).Return(findProvsIter, c.routerErr)
309+
}
307310

308311
provsIter, err := client.FindProviders(ctx, cid)
309312

routing/http/server/server.go

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ const (
2828
mediaTypeJSON = "application/json"
2929
mediaTypeNDJSON = "application/x-ndjson"
3030
mediaTypeWildcard = "*/*"
31+
32+
DefaultRecordsLimit = 20
33+
DefaultStreamingRecordsLimit = 0
3134
)
3235

3336
var logger = logging.Logger("service/server/delegatedrouting")
@@ -41,9 +44,9 @@ type FindProvidersAsyncResponse struct {
4144
}
4245

4346
type ContentRouter interface {
44-
// FindProviders searches for peers who are able to provide a given key. Stream
45-
// indicates whether or not this request will be responded as a stream.
46-
FindProviders(ctx context.Context, key cid.Cid, stream bool) (iter.ResultIter[types.ProviderResponse], error)
47+
// FindProviders searches for peers who are able to provide a given key. Limit
48+
// indicates the maximum amount of results to return. 0 means unbounded.
49+
FindProviders(ctx context.Context, key cid.Cid, limit int) (iter.ResultIter[types.ProviderResponse], error)
4750
ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error)
4851
Provide(ctx context.Context, req *WriteProvideRequest) (types.ProviderResponse, error)
4952
}
@@ -71,9 +74,27 @@ func WithStreamingResultsDisabled() Option {
7174
}
7275
}
7376

77+
// WithRecordsLimit sets a limit that will be passed to ContentRouter.FindProviders
78+
// for non-streaming requests (application/json). Default is DefaultRecordsLimit.
79+
func WithRecordsLimit(limit int) Option {
80+
return func(s *server) {
81+
s.recordsLimit = limit
82+
}
83+
}
84+
85+
// WithStreamingRecordsLimit sets a limit that will be passed to ContentRouter.FindProviders
86+
// for streaming requests (application/x-ndjson). Default is DefaultStreamingRecordsLimit.
87+
func WithStreamingRecordsLimit(limit int) Option {
88+
return func(s *server) {
89+
s.streamingRecordsLimit = limit
90+
}
91+
}
92+
7493
func Handler(svc ContentRouter, opts ...Option) http.Handler {
7594
server := &server{
76-
svc: svc,
95+
svc: svc,
96+
recordsLimit: DefaultRecordsLimit,
97+
streamingRecordsLimit: DefaultStreamingRecordsLimit,
7798
}
7899

79100
for _, opt := range opts {
@@ -88,8 +109,10 @@ func Handler(svc ContentRouter, opts ...Option) http.Handler {
88109
}
89110

90111
type server struct {
91-
svc ContentRouter
92-
disableNDJSON bool
112+
svc ContentRouter
113+
disableNDJSON bool
114+
recordsLimit int
115+
streamingRecordsLimit int
93116
}
94117

95118
func (s *server) provide(w http.ResponseWriter, httpReq *http.Request) {
@@ -172,11 +195,11 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) {
172195

173196
var supportsNDJSON bool
174197
var supportsJSON bool
175-
var streaming bool
198+
var recordsLimit int
176199
acceptHeaders := httpReq.Header.Values("Accept")
177200
if len(acceptHeaders) == 0 {
178201
handlerFunc = s.findProvidersJSON
179-
streaming = false
202+
recordsLimit = s.recordsLimit
180203
} else {
181204
for _, acceptHeader := range acceptHeaders {
182205
for _, accept := range strings.Split(acceptHeader, ",") {
@@ -197,17 +220,17 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) {
197220

198221
if supportsNDJSON && !s.disableNDJSON {
199222
handlerFunc = s.findProvidersNDJSON
200-
streaming = true
223+
recordsLimit = s.streamingRecordsLimit
201224
} else if supportsJSON {
202225
handlerFunc = s.findProvidersJSON
203-
streaming = false
226+
recordsLimit = s.recordsLimit
204227
} else {
205228
writeErr(w, "FindProviders", http.StatusBadRequest, errors.New("no supported content types"))
206229
return
207230
}
208231
}
209232

210-
provIter, err := s.svc.FindProviders(httpReq.Context(), cid, streaming)
233+
provIter, err := s.svc.FindProviders(httpReq.Context(), cid, recordsLimit)
211234
if err != nil {
212235
writeErr(w, "FindProviders", http.StatusInternalServerError, fmt.Errorf("delegate error: %w", err))
213236
return

routing/http/server/server_test.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func TestHeaders(t *testing.T) {
3333
cb, err := cid.Decode(c)
3434
require.NoError(t, err)
3535

36-
router.On("FindProviders", mock.Anything, cb, false).
36+
router.On("FindProviders", mock.Anything, cb, DefaultRecordsLimit).
3737
Return(results, nil)
3838

3939
resp, err := http.Get(serverAddr + ProvidePath + c)
@@ -85,7 +85,11 @@ func TestResponse(t *testing.T) {
8585
server := httptest.NewServer(Handler(router))
8686
t.Cleanup(server.Close)
8787
serverAddr := "http://" + server.Listener.Addr().String()
88-
router.On("FindProviders", mock.Anything, cid, expectedStream).Return(results, nil)
88+
limit := DefaultRecordsLimit
89+
if expectedStream {
90+
limit = DefaultStreamingRecordsLimit
91+
}
92+
router.On("FindProviders", mock.Anything, cid, limit).Return(results, nil)
8993
urlStr := serverAddr + ProvidePath + cidStr
9094

9195
req, err := http.NewRequest(http.MethodGet, urlStr, nil)
@@ -115,8 +119,8 @@ func TestResponse(t *testing.T) {
115119

116120
type mockContentRouter struct{ mock.Mock }
117121

118-
func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid, stream bool) (iter.ResultIter[types.ProviderResponse], error) {
119-
args := m.Called(ctx, key, stream)
122+
func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid, limit int) (iter.ResultIter[types.ProviderResponse], error) {
123+
args := m.Called(ctx, key, limit)
120124
return args.Get(0).(iter.ResultIter[types.ProviderResponse]), args.Error(1)
121125
}
122126
func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error) {

0 commit comments

Comments
 (0)