Skip to content

Commit d4b75cf

Browse files
committed
feat: provide maximum count to routing.FindProviders
1 parent c793b6d commit d4b75cf

File tree

3 files changed

+37
-11
lines changed

3 files changed

+37
-11
lines changed

routing/http/client/client_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ import (
2727

2828
type mockContentRouter struct{ mock.Mock }
2929

30-
func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error) {
31-
args := m.Called(ctx, key)
30+
func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid, count int) (iter.ResultIter[types.ProviderResponse], error) {
31+
args := m.Called(ctx, key, count)
3232
return args.Get(0).(iter.ResultIter[types.ProviderResponse]), args.Error(1)
3333
}
3434
func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *server.BitswapWriteProvideRequest) (time.Duration, error) {
@@ -302,7 +302,7 @@ func TestClient_FindProviders(t *testing.T) {
302302

303303
findProvsIter := iter.FromSlice(c.routerProvs)
304304

305-
router.On("FindProviders", mock.Anything, cid).
305+
router.On("FindProviders", mock.Anything, cid, 20).
306306
Return(findProvsIter, c.routerErr)
307307

308308
provsIter, err := client.FindProviders(ctx, cid)

routing/http/server/server.go

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,10 @@ type FindProvidersAsyncResponse struct {
4141
}
4242

4343
type ContentRouter interface {
44-
FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error)
44+
// FindProviders searches for peers who are able to provide a given key. Count
45+
// indicates the maximum amount of providers we are looking for. If count is 0,
46+
// the implementer can return an unbounded number of results.
47+
FindProviders(ctx context.Context, key cid.Cid, count int) (iter.ResultIter[types.ProviderResponse], error)
4548
ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error)
4649
Provide(ctx context.Context, req *WriteProvideRequest) (types.ProviderResponse, error)
4750
}
@@ -69,9 +72,27 @@ func WithStreamingResultsDisabled() Option {
6972
}
7073
}
7174

75+
// WithRecordsCount changes the amount of records asked for non-streaming requests.
76+
// Default is 20.
77+
func WithRecordsCount(count int) Option {
78+
return func(s *server) {
79+
s.recordsCount = count
80+
}
81+
}
82+
83+
// WithStreamingRecordsCount changes the amount of records asked for streaming requests.
84+
// Default is 0 (unbounded).
85+
func WithStreamingRecordsCount(count int) Option {
86+
return func(s *server) {
87+
s.streamingRecordsCount = count
88+
}
89+
}
90+
7291
func Handler(svc ContentRouter, opts ...Option) http.Handler {
7392
server := &server{
74-
svc: svc,
93+
svc: svc,
94+
recordsCount: 20,
95+
streamingRecordsCount: 0,
7596
}
7697

7798
for _, opt := range opts {
@@ -86,8 +107,10 @@ func Handler(svc ContentRouter, opts ...Option) http.Handler {
86107
}
87108

88109
type server struct {
89-
svc ContentRouter
90-
disableNDJSON bool
110+
svc ContentRouter
111+
disableNDJSON bool
112+
recordsCount int
113+
streamingRecordsCount int
91114
}
92115

93116
func (s *server) provide(w http.ResponseWriter, httpReq *http.Request) {
@@ -170,6 +193,7 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) {
170193

171194
var supportsNDJSON bool
172195
var supportsJSON bool
196+
var count int
173197
acceptHeaders := httpReq.Header.Values("Accept")
174198
if len(acceptHeaders) == 0 {
175199
handlerFunc = s.findProvidersJSON
@@ -185,8 +209,10 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) {
185209
switch mediaType {
186210
case mediaTypeJSON, mediaTypeWildcard:
187211
supportsJSON = true
212+
count = s.recordsCount
188213
case mediaTypeNDJSON:
189214
supportsNDJSON = true
215+
count = s.streamingRecordsCount
190216
}
191217
}
192218
}
@@ -201,7 +227,7 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) {
201227
}
202228
}
203229

204-
provIter, err := s.svc.FindProviders(httpReq.Context(), cid)
230+
provIter, err := s.svc.FindProviders(httpReq.Context(), cid, count)
205231
if err != nil {
206232
writeErr(w, "FindProviders", http.StatusInternalServerError, fmt.Errorf("delegate error: %w", err))
207233
return

routing/http/server/server_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func TestHeaders(t *testing.T) {
3333
cb, err := cid.Decode(c)
3434
require.NoError(t, err)
3535

36-
router.On("FindProviders", mock.Anything, cb).
36+
router.On("FindProviders", mock.Anything, cb, 0).
3737
Return(results, nil)
3838

3939
resp, err := http.Get(serverAddr + ProvidePath + c)
@@ -118,8 +118,8 @@ func TestResponse(t *testing.T) {
118118

119119
type mockContentRouter struct{ mock.Mock }
120120

121-
func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error) {
122-
args := m.Called(ctx, key)
121+
func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid, count int) (iter.ResultIter[types.ProviderResponse], error) {
122+
args := m.Called(ctx, key, count)
123123
return args.Get(0).(iter.ResultIter[types.ProviderResponse]), args.Error(1)
124124
}
125125
func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error) {

0 commit comments

Comments
 (0)