diff --git a/client/retrieval/connection.go b/client/retrieval/connection.go index d0dee24..9b8dea9 100644 --- a/client/retrieval/connection.go +++ b/client/retrieval/connection.go @@ -102,6 +102,32 @@ func (c *Connection) Hasher() hash.Hash { return c.hasher() } +// ExecutionOption is an option configuring a retrieval execution. +type ExecutionOption func(cfg *execConfig) + +type execConfig struct { + allowPublicRetrieval bool +} + +// WithPublicRetrieval configures the client to allow retrievals from public +// buckets. +// +// This means that responses that do not include an X-Agent-Message header will +// be treated as valid rather than errors. It is up to the caller to inspect the +// response data to determine if it is acceptable. +// +// When this option is set and the response does not contain an X-Agent-Message +// header, the [client.ExecutionResponse] returned by the call to [Execute] will +// be nil. +// +// Note: this does not prevent the client from sending authorized requests, it +// only affects how responses are interpreted. +func WithPublicRetrieval() ExecutionOption { + return func(cfg *execConfig) { + cfg.allowPublicRetrieval = true + } +} + // Execute performs a UCAN invocation using the headercar transport, // implementing a "probe and retry" pattern to handle HTTP header size // limitations when the invocation is too large to fit. @@ -128,7 +154,12 @@ func (c *Connection) Hasher() hash.Hash { // // Returns the execution response, the final HTTP response, and any error // encountered. -func Execute(ctx context.Context, inv invocation.Invocation, conn client.Connection) (client.ExecutionResponse, transport.HTTPResponse, error) { +func Execute(ctx context.Context, inv invocation.Invocation, conn client.Connection, options ...ExecutionOption) (client.ExecutionResponse, transport.HTTPResponse, error) { + cfg := execConfig{} + for _, o := range options { + o(&cfg) + } + input, err := message.Build([]invocation.Invocation{inv}, nil) if err != nil { return nil, nil, fmt.Errorf("building message: %w", err) @@ -170,6 +201,9 @@ func Execute(ctx context.Context, inv invocation.Invocation, conn client.Connect output, err := conn.Codec().Decode(response) if err != nil { + if cfg.allowPublicRetrieval && errors.Is(err, hcmsg.ErrMissingHeader) { + return nil, response, nil + } return nil, nil, fmt.Errorf("decoding message: %w", err) } diff --git a/client/retrieval/connection_test.go b/client/retrieval/connection_test.go index e4284b0..a5c955f 100644 --- a/client/retrieval/connection_test.go +++ b/client/retrieval/connection_test.go @@ -8,6 +8,8 @@ import ( "net/http" "net/http/httptest" "net/url" + "strconv" + "strings" "testing" prime "github.com/ipld/go-ipld-prime" @@ -306,3 +308,56 @@ func TestExecute(t *testing.T) { }) } } + +func TestExecuteWithPublicRetrieval(t *testing.T) { + data := helpers.RandomBytes(512) + digest, err := multihash.Sum(data, multihash.SHA2_256, -1) + require.NoError(t, err) + + httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rangeStr := r.Header.Get("Range") + byteRange := strings.Split(strings.TrimPrefix(rangeStr, "bytes="), "-") + start, err := strconv.Atoi(byteRange[0]) + require.NoError(t, err) + end, err := strconv.Atoi(byteRange[1]) + require.NoError(t, err) + w.Write(data[start : end+1]) + })) + defer httpServer.Close() + + // specify the byte range we want to receive (inclusive) + contentRange := []int{100, 200} + + url, err := url.Parse(httpServer.URL) + require.NoError(t, err) + + headers := http.Header{} + headers.Set("Range", fmt.Sprintf("bytes=%d-%d", contentRange[0], contentRange[1])) + + conn, err := NewConnection( + fixtures.Service, + url.JoinPath("blob", "z"+digest.B58String()), + WithHeaders(headers), + ) + require.NoError(t, err) + + dlg := mkDelegationChain(t, fixtures.Service, fixtures.Alice, serve.Can(), 1) + inv, err := serve.Invoke( + fixtures.Alice, + fixtures.Service, + fixtures.Service.DID().String(), + serveCaveats{Digest: digest, Range: contentRange}, + delegation.WithProof(delegation.FromDelegation(dlg)), + ) + require.NoError(t, err) + + // allow public retrieval + xRes, hRes, err := Execute(t.Context(), inv, conn, WithPublicRetrieval()) + require.NoError(t, err) + require.Nil(t, xRes) + require.NotNil(t, hRes) + + body, err := io.ReadAll(hRes.Body()) + require.NoError(t, err) + require.Equal(t, data[contentRange[0]:contentRange[1]+1], body) +} diff --git a/transport/headercar/message/header.go b/transport/headercar/message/header.go index ea872ae..a3fbaf5 100644 --- a/transport/headercar/message/header.go +++ b/transport/headercar/message/header.go @@ -21,7 +21,10 @@ const ( MaxHeaderSizeBytes = 4 * 1024 ) -var ErrHeaderTooLarge = errors.New("maximum agent message header size exceeded") +var ( + ErrHeaderTooLarge = errors.New("maximum agent message header size exceeded") + ErrMissingHeader = fmt.Errorf("missing %s header", HeaderName) +) type encodeConfig struct { maxSize int diff --git a/transport/headercar/response/response.go b/transport/headercar/response/response.go index 53811a2..2c240b0 100644 --- a/transport/headercar/response/response.go +++ b/transport/headercar/response/response.go @@ -23,7 +23,7 @@ func Encode(msg message.AgentMessage) (transport.HTTPResponse, error) { func Decode(response transport.HTTPResponse) (message.AgentMessage, error) { msgHdr := response.Headers().Get(hcmsg.HeaderName) if msgHdr == "" { - return nil, fmt.Errorf("missing %s header in response", hcmsg.HeaderName) + return nil, hcmsg.ErrMissingHeader } return hcmsg.DecodeHeader(msgHdr) }