Skip to content

Commit d96e912

Browse files
committed
feat: indicate if response will be streamable on routing.FindProviders
1 parent d37890a commit d96e912

File tree

3 files changed

+19
-13
lines changed

3 files changed

+19
-13
lines changed

routing/http/client/client_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ import (
2727

2828
type mockContentRouter struct{ mock.Mock }
2929

30-
func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error) {
31-
args := m.Called(ctx, key)
30+
func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid, stream bool) (iter.ResultIter[types.ProviderResponse], error) {
31+
args := m.Called(ctx, key, stream)
3232
return args.Get(0).(iter.ResultIter[types.ProviderResponse]), args.Error(1)
3333
}
3434
func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *server.BitswapWriteProvideRequest) (time.Duration, error) {
@@ -302,7 +302,7 @@ func TestClient_FindProviders(t *testing.T) {
302302

303303
findProvsIter := iter.FromSlice(c.routerProvs)
304304

305-
router.On("FindProviders", mock.Anything, cid).
305+
router.On("FindProviders", mock.Anything, cid, c.expStreamingResponse).
306306
Return(findProvsIter, c.routerErr)
307307

308308
provsIter, err := client.FindProviders(ctx, cid)

routing/http/server/server.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ type FindProvidersAsyncResponse struct {
4141
}
4242

4343
type ContentRouter interface {
44-
FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error)
44+
// FindProviders searches for peers who are able to provide a given key. Stream
45+
// indicates whether or not this request will be responded as a stream.
46+
FindProviders(ctx context.Context, key cid.Cid, stream bool) (iter.ResultIter[types.ProviderResponse], error)
4547
ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error)
4648
Provide(ctx context.Context, req *WriteProvideRequest) (types.ProviderResponse, error)
4749
}
@@ -170,9 +172,11 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) {
170172

171173
var supportsNDJSON bool
172174
var supportsJSON bool
175+
var streaming bool
173176
acceptHeaders := httpReq.Header.Values("Accept")
174177
if len(acceptHeaders) == 0 {
175178
handlerFunc = s.findProvidersJSON
179+
streaming = false
176180
} else {
177181
for _, acceptHeader := range acceptHeaders {
178182
for _, accept := range strings.Split(acceptHeader, ",") {
@@ -193,15 +197,17 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) {
193197

194198
if supportsNDJSON && !s.disableNDJSON {
195199
handlerFunc = s.findProvidersNDJSON
200+
streaming = true
196201
} else if supportsJSON {
197202
handlerFunc = s.findProvidersJSON
203+
streaming = false
198204
} else {
199205
writeErr(w, "FindProviders", http.StatusBadRequest, errors.New("no supported content types"))
200206
return
201207
}
202208
}
203209

204-
provIter, err := s.svc.FindProviders(httpReq.Context(), cid)
210+
provIter, err := s.svc.FindProviders(httpReq.Context(), cid, streaming)
205211
if err != nil {
206212
writeErr(w, "FindProviders", http.StatusInternalServerError, fmt.Errorf("delegate error: %w", err))
207213
return

routing/http/server/server_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func TestHeaders(t *testing.T) {
3333
cb, err := cid.Decode(c)
3434
require.NoError(t, err)
3535

36-
router.On("FindProviders", mock.Anything, cb).
36+
router.On("FindProviders", mock.Anything, cb, false).
3737
Return(results, nil)
3838

3939
resp, err := http.Get(serverAddr + ProvidePath + c)
@@ -60,7 +60,7 @@ func TestResponse(t *testing.T) {
6060
cid, err := cid.Decode(cidStr)
6161
require.NoError(t, err)
6262

63-
runTest := func(t *testing.T, contentType string, expected string) {
63+
runTest := func(t *testing.T, contentType string, expectedStream bool, expectedBody string) {
6464
t.Parallel()
6565

6666
results := iter.FromSlice([]iter.Result[types.ProviderResponse]{
@@ -76,7 +76,7 @@ func TestResponse(t *testing.T) {
7676
server := httptest.NewServer(Handler(router))
7777
t.Cleanup(server.Close)
7878
serverAddr := "http://" + server.Listener.Addr().String()
79-
router.On("FindProviders", mock.Anything, cid).Return(results, nil)
79+
router.On("FindProviders", mock.Anything, cid, expectedStream).Return(results, nil)
8080
urlStr := serverAddr + ProvidePath + cidStr
8181

8282
req, err := http.NewRequest(http.MethodGet, urlStr, nil)
@@ -92,22 +92,22 @@ func TestResponse(t *testing.T) {
9292
body, err := io.ReadAll(resp.Body)
9393
require.NoError(t, err)
9494

95-
require.Equal(t, string(body), expected)
95+
require.Equal(t, string(body), expectedBody)
9696
}
9797

9898
t.Run("JSON Response", func(t *testing.T) {
99-
runTest(t, mediaTypeJSON, `{"Providers":[{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Addrs":[]}]}`)
99+
runTest(t, mediaTypeJSON, false, `{"Providers":[{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Addrs":[]}]}`)
100100
})
101101

102102
t.Run("NDJSON Response", func(t *testing.T) {
103-
runTest(t, mediaTypeNDJSON, `{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Addrs":[]}`+"\n")
103+
runTest(t, mediaTypeNDJSON, true, `{"Protocol":"transport-bitswap","Schema":"bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Addrs":[]}`+"\n")
104104
})
105105
}
106106

107107
type mockContentRouter struct{ mock.Mock }
108108

109-
func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.ProviderResponse], error) {
110-
args := m.Called(ctx, key)
109+
func (m *mockContentRouter) FindProviders(ctx context.Context, key cid.Cid, stream bool) (iter.ResultIter[types.ProviderResponse], error) {
110+
args := m.Called(ctx, key, stream)
111111
return args.Get(0).(iter.ResultIter[types.ProviderResponse]), args.Error(1)
112112
}
113113
func (m *mockContentRouter) ProvideBitswap(ctx context.Context, req *BitswapWriteProvideRequest) (time.Duration, error) {

0 commit comments

Comments
 (0)