Skip to content

Commit 770b75c

Browse files
committed
refactor: use announcement response schema
1 parent 813b59f commit 770b75c

File tree

9 files changed

+98
-47
lines changed

9 files changed

+98
-47
lines changed

routing/http/client/client.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ func (c *Client) FindProviders(ctx context.Context, key cid.Cid) (providers iter
240240
// Provide publishes [types.AnnouncementRecord]s based on the given [types.AnnouncementRequests].
241241
// This records will be signed by your provided. Therefore, the [Client] must have been configured
242242
// with [WithProviderInfo].
243-
func (c *Client) Provide(ctx context.Context, announcements ...types.AnnouncementRequest) (iter.ResultIter[*types.AnnouncementRecord], error) {
243+
func (c *Client) Provide(ctx context.Context, announcements ...types.AnnouncementRequest) (iter.ResultIter[*types.AnnouncementResponseRecord], error) {
244244
if err := c.canProvide(); err != nil {
245245
return nil, err
246246
}
@@ -291,7 +291,7 @@ func (c *Client) Provide(ctx context.Context, announcements ...types.Announcemen
291291

292292
// ProvideRecords publishes the given [types.AnnouncementRecord]. An error will
293293
// be returned if the records aren't signed or valid.
294-
func (c *Client) ProvideRecords(ctx context.Context, records ...*types.AnnouncementRecord) (iter.ResultIter[*types.AnnouncementRecord], error) {
294+
func (c *Client) ProvideRecords(ctx context.Context, records ...*types.AnnouncementRecord) (iter.ResultIter[*types.AnnouncementResponseRecord], error) {
295295
providerRecords := make([]types.Record, len(records))
296296
for i, record := range records {
297297
if err := record.Verify(); err != nil {
@@ -307,7 +307,7 @@ func (c *Client) ProvideRecords(ctx context.Context, records ...*types.Announcem
307307
return c.provide(ctx, url, req)
308308
}
309309

310-
func (c *Client) provide(ctx context.Context, url string, req interface{}) (iter.ResultIter[*types.AnnouncementRecord], error) {
310+
func (c *Client) provide(ctx context.Context, url string, req interface{}) (iter.ResultIter[*types.AnnouncementResponseRecord], error) {
311311
b, err := drjson.MarshalJSONBytes(req)
312312
if err != nil {
313313
return nil, err
@@ -342,19 +342,19 @@ func (c *Client) provide(ctx context.Context, url string, req interface{}) (iter
342342
}
343343
}()
344344

345-
var it iter.ResultIter[*types.AnnouncementRecord]
345+
var it iter.ResultIter[*types.AnnouncementResponseRecord]
346346
switch mediaType {
347347
case mediaTypeJSON:
348348
parsedResp := &jsontypes.AnnouncePeersResponse{}
349349
err = json.NewDecoder(resp.Body).Decode(parsedResp)
350350
if err != nil {
351351
return nil, err
352352
}
353-
var sliceIt iter.Iter[*types.AnnouncementRecord] = iter.FromSlice(parsedResp.ProvideResults)
353+
var sliceIt iter.Iter[*types.AnnouncementResponseRecord] = iter.FromSlice(parsedResp.ProvideResults)
354354
it = iter.ToResultIter(sliceIt)
355355
case mediaTypeNDJSON:
356356
skipBodyClose = true
357-
it = ndjson.NewAnnouncementRecordsIter(resp.Body)
357+
it = ndjson.NewAnnouncementResponseRecordsIter(resp.Body)
358358
default:
359359
logger.Errorw("unknown media type", "MediaType", mediaType, "ContentType", respContentType)
360360
return nil, errors.New("unknown content type")
@@ -452,7 +452,7 @@ func (c *Client) FindPeers(ctx context.Context, pid peer.ID) (peers iter.ResultI
452452

453453
// ProvidePeer publishes an [types.AnnouncementRecord] with the provider
454454
// information from your peer, configured with [WithProviderInfo].
455-
func (c *Client) ProvidePeer(ctx context.Context, ttl time.Duration, metadata []byte) (iter.ResultIter[*types.AnnouncementRecord], error) {
455+
func (c *Client) ProvidePeer(ctx context.Context, ttl time.Duration, metadata []byte) (iter.ResultIter[*types.AnnouncementResponseRecord], error) {
456456
if err := c.canProvide(); err != nil {
457457
return nil, err
458458
}
@@ -495,7 +495,7 @@ func (c *Client) ProvidePeer(ctx context.Context, ttl time.Duration, metadata []
495495

496496
// ProvidePeerRecords publishes the given [types.AnnouncementRecord]. An error will
497497
// be returned if the records aren't signed or valid.
498-
func (c *Client) ProvidePeerRecords(ctx context.Context, records ...*types.AnnouncementRecord) (iter.ResultIter[*types.AnnouncementRecord], error) {
498+
func (c *Client) ProvidePeerRecords(ctx context.Context, records ...*types.AnnouncementRecord) (iter.ResultIter[*types.AnnouncementResponseRecord], error) {
499499
providerRecords := make([]types.Record, len(records))
500500
for i, record := range records {
501501
if err := record.Verify(); err != nil {

routing/http/client/client_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ func TestClient_Provide(t *testing.T) {
485485
require.Empty(t, results[0].Error)
486486
}
487487

488-
assert.Equal(t, c.expAdvisoryTTL, results[0].Payload.TTL)
488+
assert.Equal(t, c.expAdvisoryTTL, results[0].TTL)
489489
})
490490
}
491491
}
@@ -754,7 +754,7 @@ func TestClient_ProvidePeer(t *testing.T) {
754754
require.Empty(t, results[0].Error)
755755
}
756756

757-
assert.Equal(t, c.expAdvisoryTTL, results[0].Payload.TTL)
757+
assert.Equal(t, c.expAdvisoryTTL, results[0].TTL)
758758
})
759759
}
760760
}

routing/http/contentrouter/contentrouter.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ const ttl = 24 * time.Hour
2525

2626
type Client interface {
2727
FindProviders(ctx context.Context, key cid.Cid) (iter.ResultIter[types.Record], error)
28-
Provide(ctx context.Context, announcements ...types.AnnouncementRequest) (iter.ResultIter[*types.AnnouncementRecord], error)
28+
Provide(ctx context.Context, announcements ...types.AnnouncementRequest) (iter.ResultIter[*types.AnnouncementResponseRecord], error)
2929
FindPeers(ctx context.Context, pid peer.ID) (peers iter.ResultIter[*types.PeerRecord], err error)
3030
GetIPNS(ctx context.Context, name ipns.Name) (*ipns.Record, error)
3131
PutIPNS(ctx context.Context, name ipns.Name, record *ipns.Record) error

routing/http/contentrouter/contentrouter_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ func (m *mockClient) FindProviders(ctx context.Context, key cid.Cid) (iter.Resul
2727
return args.Get(0).(iter.ResultIter[types.Record]), args.Error(1)
2828
}
2929

30-
func (m *mockClient) Provide(ctx context.Context, announcements ...types.AnnouncementRequest) (iter.ResultIter[*types.AnnouncementRecord], error) {
30+
func (m *mockClient) Provide(ctx context.Context, announcements ...types.AnnouncementRequest) (iter.ResultIter[*types.AnnouncementResponseRecord], error) {
3131
args := m.Called(ctx, announcements)
32-
return args.Get(0).(iter.ResultIter[*types.AnnouncementRecord]), args.Error(1)
32+
return args.Get(0).(iter.ResultIter[*types.AnnouncementResponseRecord]), args.Error(1)
3333
}
3434

3535
func (m *mockClient) FindPeers(ctx context.Context, pid peer.ID) (iter.ResultIter[*types.PeerRecord], error) {
@@ -76,10 +76,10 @@ func TestProvide(t *testing.T) {
7676
crc := NewContentRoutingClient(client)
7777

7878
if !c.expNotProvided {
79-
res := []*types.AnnouncementRecord{
80-
{Payload: types.AnnouncementPayload{TTL: time.Minute}},
79+
res := []*types.AnnouncementResponseRecord{
80+
{TTL: time.Minute},
8181
}
82-
client.On("Provide", ctx, []types.AnnouncementRequest{{CID: key, TTL: ttl}}).Return(iter.ToResultIter[*types.AnnouncementRecord](iter.FromSlice(res)), nil)
82+
client.On("Provide", ctx, []types.AnnouncementRequest{{CID: key, TTL: ttl}}).Return(iter.ToResultIter[*types.AnnouncementResponseRecord](iter.FromSlice(res)), nil)
8383
}
8484

8585
err := crc.Provide(ctx, key, c.announce)
@@ -101,10 +101,10 @@ func TestProvideMany(t *testing.T) {
101101
ctx := context.Background()
102102
client := &mockClient{}
103103
crc := NewContentRoutingClient(client)
104-
res := []*types.AnnouncementRecord{
105-
{Payload: types.AnnouncementPayload{TTL: time.Minute}},
104+
res := []*types.AnnouncementResponseRecord{
105+
{TTL: time.Minute},
106106
}
107-
client.On("Provide", ctx, makeBatchAnnouncements(cids, ttl)).Return(iter.ToResultIter[*types.AnnouncementRecord](iter.FromSlice(res)), nil)
107+
client.On("Provide", ctx, makeBatchAnnouncements(cids, ttl)).Return(iter.ToResultIter[*types.AnnouncementResponseRecord](iter.FromSlice(res)), nil)
108108
err := crc.ProvideMany(ctx, mhs)
109109
require.NoError(t, err)
110110
}

routing/http/server/server.go

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -281,9 +281,9 @@ func (s *server) providePeers(w http.ResponseWriter, r *http.Request) {
281281
return
282282
}
283283

284-
responseIter := iter.Map[types.Record, *types.AnnouncementRecord](iter.FromSlice(req.Peers), func(t types.Record) *types.AnnouncementRecord {
285-
resRecord := &types.AnnouncementRecord{
286-
Schema: types.SchemaAnnouncement,
284+
responseIter := iter.Map[types.Record, *types.AnnouncementResponseRecord](iter.FromSlice(req.Peers), func(t types.Record) *types.AnnouncementResponseRecord {
285+
resRecord := &types.AnnouncementResponseRecord{
286+
Schema: types.SchemaAnnouncementResponse,
287287
}
288288

289289
reqRecord, err := s.provideCheckAnnouncement("Provide", t)
@@ -298,8 +298,7 @@ func (s *server) providePeers(w http.ResponseWriter, r *http.Request) {
298298
return resRecord
299299
}
300300

301-
resRecord.Payload.TTL = ttl
302-
resRecord.Payload.ID = reqRecord.Payload.ID
301+
resRecord.TTL = ttl
303302
return resRecord
304303
})
305304

@@ -310,10 +309,10 @@ func (s *server) providePeers(w http.ResponseWriter, r *http.Request) {
310309
}
311310

312311
if mediaType == mediaTypeNDJSON {
313-
writeResultsIterNDJSON[*types.AnnouncementRecord](w, iter.ToResultIter[*types.AnnouncementRecord](responseIter))
312+
writeResultsIterNDJSON[*types.AnnouncementResponseRecord](w, iter.ToResultIter[*types.AnnouncementResponseRecord](responseIter))
314313
} else {
315314
writeJSONResult(w, "ProvidePeers", jsontypes.AnnouncePeersResponse{
316-
ProvideResults: iter.ReadAll[*types.AnnouncementRecord](responseIter),
315+
ProvideResults: iter.ReadAll[*types.AnnouncementResponseRecord](responseIter),
317316
})
318317
}
319318
}
@@ -327,9 +326,9 @@ func (s *server) provide(w http.ResponseWriter, r *http.Request) {
327326
return
328327
}
329328

330-
responseIter := iter.Map[types.Record, *types.AnnouncementRecord](iter.FromSlice(req.Providers), func(t types.Record) *types.AnnouncementRecord {
331-
resRecord := &types.AnnouncementRecord{
332-
Schema: types.SchemaAnnouncement,
329+
responseIter := iter.Map[types.Record, *types.AnnouncementResponseRecord](iter.FromSlice(req.Providers), func(t types.Record) *types.AnnouncementResponseRecord {
330+
resRecord := &types.AnnouncementResponseRecord{
331+
Schema: types.SchemaAnnouncementResponse,
333332
}
334333

335334
reqRecord, err := s.provideCheckAnnouncement("Provide", t)
@@ -344,9 +343,7 @@ func (s *server) provide(w http.ResponseWriter, r *http.Request) {
344343
return resRecord
345344
}
346345

347-
resRecord.Payload.TTL = ttl
348-
resRecord.Payload.CID = reqRecord.Payload.CID
349-
resRecord.Payload.ID = reqRecord.Payload.ID
346+
resRecord.TTL = ttl
350347
return resRecord
351348
})
352349

@@ -357,10 +354,10 @@ func (s *server) provide(w http.ResponseWriter, r *http.Request) {
357354
}
358355

359356
if mediaType == mediaTypeNDJSON {
360-
writeResultsIterNDJSON[*types.AnnouncementRecord](w, iter.ToResultIter[*types.AnnouncementRecord](responseIter))
357+
writeResultsIterNDJSON[*types.AnnouncementResponseRecord](w, iter.ToResultIter[*types.AnnouncementResponseRecord](responseIter))
361358
} else {
362359
writeJSONResult(w, "Provide", jsontypes.AnnounceProvidersResponse{
363-
ProvideResults: iter.ReadAll[*types.AnnouncementRecord](responseIter),
360+
ProvideResults: iter.ReadAll[*types.AnnouncementResponseRecord](responseIter),
364361
})
365362
}
366363
}

routing/http/server/server_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,11 +206,11 @@ func TestProviders(t *testing.T) {
206206
}
207207

208208
t.Run("POST /routing/v1/providers (JSON Response)", func(t *testing.T) {
209-
runPutTest(t, mediaTypeJSON, `{"ProvideResults":[{"Schema":"announcement","Payload":{"CID":"`+cid1Str+`","ID":"`+pid1Str+`","TTL":3600000}},{"Schema":"announcement","Payload":{"CID":"`+cid1Str+`","ID":"`+pid2Str+`","TTL":60000}}]}`)
209+
runPutTest(t, mediaTypeJSON, `{"ProvideResults":[{"Schema":"announcement-response","TTL":3600000},{"Schema":"announcement-response","TTL":60000}]}`)
210210
})
211211

212212
t.Run("POST /routing/v1/providers (NDJSON Response)", func(t *testing.T) {
213-
runPutTest(t, mediaTypeNDJSON, `{"Schema":"announcement","Payload":{"CID":"`+cid1Str+`","ID":"`+pid1Str+`","TTL":3600000}}`+"\n"+`{"Schema":"announcement","Payload":{"CID":"`+cid1Str+`","ID":"`+pid2Str+`","TTL":60000}}`+"\n")
213+
runPutTest(t, mediaTypeNDJSON, `{"Schema":"announcement-response","TTL":3600000}`+"\n"+`{"Schema":"announcement-response","TTL":60000}`+"\n")
214214
})
215215
}
216216

@@ -376,11 +376,11 @@ func TestPeers(t *testing.T) {
376376
}
377377

378378
t.Run("POST /routing/v1/peers (JSON Response)", func(t *testing.T) {
379-
runPutTest(t, mediaTypeJSON, `{"ProvideResults":[{"Schema":"announcement","Payload":{"ID":"`+pid1.String()+`","TTL":3600000}},{"Schema":"announcement","Payload":{"ID":"`+pid2.String()+`","TTL":60000}}]}`)
379+
runPutTest(t, mediaTypeJSON, `{"ProvideResults":[{"Schema":"announcement-response","TTL":3600000},{"Schema":"announcement-response","TTL":60000}]}`)
380380
})
381381

382382
t.Run("POST /routing/v1/peers (NDJSON Response)", func(t *testing.T) {
383-
runPutTest(t, mediaTypeNDJSON, `{"Schema":"announcement","Payload":{"ID":"`+pid1.String()+`","TTL":3600000}}`+"\n"+`{"Schema":"announcement","Payload":{"ID":"`+pid2.String()+`","TTL":60000}}`+"\n")
383+
runPutTest(t, mediaTypeNDJSON, `{"Schema":"announcement-response","TTL":3600000}`+"\n"+`{"Schema":"announcement-response","TTL":60000}`+"\n")
384384
})
385385
}
386386

routing/http/types/json/responses.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@ func (r *RecordsArray) UnmarshalJSON(b []byte) error {
4848
return err
4949
}
5050
*r = append(*r, &prov)
51+
case types.SchemaAnnouncementResponse:
52+
var prov types.AnnouncementResponseRecord
53+
err := json.Unmarshal(provBytes, &prov)
54+
if err != nil {
55+
return err
56+
}
57+
*r = append(*r, &prov)
5158
default:
5259
*r = append(*r, &readProv)
5360
}
@@ -58,7 +65,7 @@ func (r *RecordsArray) UnmarshalJSON(b []byte) error {
5865

5966
// AnnounceProvidersResponse is the result of a POST Providers request.
6067
type AnnounceProvidersResponse struct {
61-
ProvideResults []*types.AnnouncementRecord
68+
ProvideResults []*types.AnnouncementResponseRecord
6269
}
6370

6471
// AnnouncePeersResponse is the result of a POST Peers request.

routing/http/types/ndjson/records.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,11 @@ func NewRecordsIter(r io.Reader) iter.Iter[iter.Result[types.Record]] {
4444
return iter.Map[iter.Result[types.UnknownRecord]](jsonIter, mapFn)
4545
}
4646

47-
// NewAnnouncementRecordsIter returns an iterator that reads [types.AnnouncementRecord]
48-
// from the given [io.Reader]. Records with a different schema are ignored. To read all
49-
// records, use [NewRecordsIter] instead.
50-
func NewAnnouncementRecordsIter(r io.Reader) iter.Iter[iter.Result[*types.AnnouncementRecord]] {
51-
return newFilteredRecords[*types.AnnouncementRecord](r, types.SchemaPeer)
47+
// NewAnnouncementResponseRecordsIter returns an iterator that reads
48+
// [types.AnnouncementResponseRecord] from the given [io.Reader]. Records with
49+
// a different schema are ignored. To read all records, use [NewRecordsIter] instead.
50+
func NewAnnouncementResponseRecordsIter(r io.Reader) iter.Iter[iter.Result[*types.AnnouncementResponseRecord]] {
51+
return newFilteredRecords[*types.AnnouncementResponseRecord](r, types.SchemaPeer)
5252
}
5353

5454
// NewPeerRecordsIter returns an iterator that reads [types.PeerRecord] from the given

routing/http/types/record_announcement.go

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,12 @@ import (
1919
"github.com/multiformats/go-multibase"
2020
)
2121

22-
const SchemaAnnouncement = "announcement"
23-
const announcementSignaturePrefix = "routing-record:"
22+
const (
23+
SchemaAnnouncement = "announcement"
24+
SchemaAnnouncementResponse = "announcement-response"
25+
26+
announcementSignaturePrefix = "routing-record:"
27+
)
2428

2529
var _ Record = &AnnouncementRecord{}
2630

@@ -195,7 +199,6 @@ func (ap *AnnouncementPayload) UnmarshalJSON(b []byte) error {
195199
// AnnouncementRecord is a [Record] of [SchemaAnnouncement].
196200
type AnnouncementRecord struct {
197201
Schema string
198-
Error string `json:",omitempty"`
199202
Payload AnnouncementPayload
200203
Signature string `json:",omitempty"`
201204
}
@@ -327,3 +330,47 @@ func makeIPLDMap(mp map[string]ipld.Node) (datamodel.Node, error) {
327330

328331
return nd.Build(), nil
329332
}
333+
334+
var _ Record = &AnnouncementResponseRecord{}
335+
336+
// AnnouncementRecord is a [Record] of [SchemaAnnouncementResponse].
337+
type AnnouncementResponseRecord struct {
338+
Schema string
339+
Error string
340+
TTL time.Duration
341+
}
342+
343+
func (ar *AnnouncementResponseRecord) GetSchema() string {
344+
return ar.Schema
345+
}
346+
347+
func (ar AnnouncementResponseRecord) MarshalJSON() ([]byte, error) {
348+
v := struct {
349+
Schema string
350+
Error string `json:",omitempty"`
351+
TTL int64 `json:",omitempty"`
352+
}{
353+
Schema: ar.Schema,
354+
Error: ar.Error,
355+
TTL: ar.TTL.Milliseconds(),
356+
}
357+
358+
return drjson.MarshalJSONBytes(v)
359+
}
360+
361+
func (ar *AnnouncementResponseRecord) UnmarshalJSON(b []byte) error {
362+
v := struct {
363+
Schema string
364+
Error string
365+
TTL int64
366+
}{}
367+
err := json.Unmarshal(b, &v)
368+
if err != nil {
369+
return err
370+
}
371+
372+
ar.Schema = v.Schema
373+
ar.Error = v.Error
374+
ar.TTL = time.Duration(v.TTL) * time.Millisecond
375+
return nil
376+
}

0 commit comments

Comments
 (0)