
Commit 00efe71

HTTP Fetcher: Disentangle checksum.sri from Digest
When we implemented HTTP fetching, the request had no digest_function field, so we used checksum.sri to choose one. The API now includes an explicit digest_function, so deriving it from checksum.sri is incorrect behaviour: if a client sets digest_function to SHA256 but specifies a SHA512 checksum, we should validate against the SHA512 checksum while returning a SHA256 digest.

This API change may require a wider rework of how checksum validation works. We should probably add a decorator fetcher that handles it, pulling the object from CAS based on the inner fetcher's response and validating that. This would be more correct, but at the expense of additional network I/O.
1 parent f83741a commit 00efe71
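For illustration only (not part of this commit), here is a minimal standalone Go sketch of the behaviour the message describes, using only the standard library. The helper name `fetchAndValidate` and its parameters are hypothetical: the checksum.sri value (SHA-512 in this example) is used purely for validation, while the digest handed back to the client is computed with the digest_function the client requested (SHA-256 here).

```go
package main

// Illustrative sketch only: a hypothetical helper, not the fetcher code from this repository.

import (
	"crypto/sha256"
	"crypto/sha512"
	"encoding/base64"
	"encoding/hex"
	"fmt"
)

// fetchAndValidate validates the downloaded body against a SHA-512 checksum
// (as a client might supply via checksum.sri), then returns a SHA-256 digest
// (as a client might request via digest_function).
func fetchAndValidate(body []byte, sha512Base64 string) (string, error) {
	// Validate against the checksum.sri value (SHA-512 in this example).
	want, err := base64.StdEncoding.DecodeString(sha512Base64)
	if err != nil {
		return "", fmt.Errorf("invalid checksum.sri: %w", err)
	}
	got := sha512.Sum512(body)
	if hex.EncodeToString(got[:]) != hex.EncodeToString(want) {
		return "", fmt.Errorf("checksum.sri mismatch")
	}

	// Return a digest in the function the client requested (SHA-256 here),
	// independent of the algorithm used for validation.
	sum := sha256.Sum256(body)
	return hex.EncodeToString(sum[:]), nil
}

func main() {
	body := []byte("hello")
	sum := sha512.Sum512(body)
	sri := base64.StdEncoding.EncodeToString(sum[:])
	digest, err := fetchAndValidate(body, sri)
	fmt.Println(digest, err)
}
```

A decorator fetcher as proposed above would perform the same validation, but against the object re-read from CAS after the inner fetcher stores it.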


pkg/fetch/http_fetcher.go

Lines changed: 69 additions & 63 deletions
@@ -1,7 +1,6 @@
 package fetch
 
 import (
-	"bytes"
 	"context"
 	"encoding/base64"
 	"encoding/hex"
@@ -52,36 +51,38 @@ func NewHTTPFetcher(httpClient *http.Client,
 }
 
 func (hf *httpFetcher) FetchBlob(ctx context.Context, req *remoteasset.FetchBlobRequest) (*remoteasset.FetchBlobResponse, error) {
-	var err error
-	instanceName, err := bb_digest.NewInstanceName(req.InstanceName)
+	digestFunction, err := getDigestFunction(req.DigestFunction, req.InstanceName)
 	if err != nil {
-		return nil, util.StatusWrapf(err, "Invalid instance name %#v", req.InstanceName)
+		return nil, err
 	}
 
 	// TODO: Address the following fields
 	// timeout := ptypes.Duration(req.timeout)
 	// oldestContentAccepted := ptypes.Timestamp(req.oldestContentAccepted)
-	expectedDigest, digestFunctionEnum, err := getChecksumSri(req.Qualifiers)
+	expectedDigest, checksumFunction, err := getChecksumSri(req.Qualifiers)
 	if err != nil {
 		return nil, err
 	}
-	if digestFunctionEnum == remoteexecution.DigestFunction_UNKNOWN {
-		// Default to SHA256 if no digest is provided.
-		digestFunctionEnum = remoteexecution.DigestFunction_SHA256
-	}
 
 	auth, err := getAuthHeaders(req.Uris, req.Qualifiers)
 	if err != nil {
 		return nil, err
 	}
 
 	for _, uri := range req.Uris {
-		buffer, digest := hf.downloadBlob(ctx, uri, instanceName, expectedDigest, digestFunctionEnum, auth)
+		buffer, digest := hf.downloadBlob(ctx, uri, digestFunction, auth)
 		if _, err = buffer.GetSizeBytes(); err != nil {
 			log.Printf("Error downloading blob with URI %s: %v", uri, err)
 			continue
 		}
 
+		// Check the checksum.sri qualifier, if there's an expected Digest
+		if expectedDigest != "" {
+			if ok, err := validateChecksumSri(buffer, checksumFunction, expectedDigest); !ok {
+				return nil, err
+			}
+		}
+
 		if err = hf.contentAddressableStorage.Put(ctx, digest, buffer); err != nil {
 			log.Printf("Error downloading blob with URI %s: %v", uri, err)
 			return nil, util.StatusWrapWithCode(err, codes.Internal, "Failed to place blob into CAS")
@@ -111,16 +112,41 @@ func (hf *httpFetcher) CheckQualifiers(qualifiers qualifier.Set) qualifier.Set {
 	return qualifier.Difference(qualifiers, toRemove)
 }
 
-func (hf *httpFetcher) downloadBlob(ctx context.Context, uri string, instanceName bb_digest.InstanceName, expectedDigest string, digestFunctionEnum remoteexecution.DigestFunction_Value, auth *AuthHeaders) (buffer.Buffer, bb_digest.Digest) {
+// validateChecksumSri ensures that the checksum of the passed response matches the expected value.
+func validateChecksumSri(buf buffer.Buffer, checksumFunction bb_digest.Function, expectedDigest string) (bool, error) {
+	sizeBytes, err := buf.GetSizeBytes()
+	if err != nil {
+		return false, err
+	}
+	checksumGenerator := checksumFunction.NewGenerator(sizeBytes)
+	written, err := io.Copy(checksumGenerator, buf.ToReader())
+	if err != nil {
+		return false, err
+	}
+	if written != sizeBytes {
+		return false, status.Errorf(codes.Internal, "Failed to hash entire buffer")
+	}
+
+	checksum := checksumGenerator.Sum().GetProto().GetHash()
+	if checksum != expectedDigest {
+		return false, status.Errorf(codes.Internal, "Fetched content did not match checksum.sri qualifier: Expected %s, Got %s", expectedDigest, checksum)
+	}
+
+	return true, nil
+}
+
+// downloadBlob performs the actual blob download, yielding a buffer of the content and its Digest
+func (hf *httpFetcher) downloadBlob(ctx context.Context, uri string, digestFunction bb_digest.Function, auth *AuthHeaders) (buffer.Buffer, bb_digest.Digest) {
+	// Generate the HTTP Request
 	req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri, nil)
 	if err != nil {
 		return buffer.NewBufferFromError(util.StatusWrapWithCode(err, codes.Internal, "Failed to create HTTP request")), bb_digest.BadDigest
 	}
-
 	if auth != nil {
 		auth.ApplyHeaders(uri, req)
 	}
 
+	// Perform the request, check for status
 	resp, err := hf.httpClient.Do(req)
 	if err != nil {
 		log.Printf("Error downloading blob with URI %s: %v", uri, err)
@@ -131,57 +157,24 @@ func (hf *httpFetcher) downloadBlob(ctx context.Context, uri string, instanceNam
 		return buffer.NewBufferFromError(status.Errorf(codes.Internal, "HTTP request failed with status %#v", resp.Status)), bb_digest.BadDigest
 	}
 
-	digestFunction, err := instanceName.GetDigestFunction(digestFunctionEnum, len(expectedDigest))
+	// Compute the Digest
+	bodyBytes, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return buffer.NewBufferFromError(util.StatusWrapfWithCode(err, codes.Internal, "Failed to get digest function for instance: %v", instanceName)), bb_digest.BadDigest
+		return buffer.NewBufferFromError(util.StatusWrapWithCode(err, codes.Internal, "Failed to read response body")), bb_digest.BadDigest
 	}
-
-	// Work out the digest of the downloaded data
-	//
-	// If the HTTP response includes the content length (indicated by the value
-	// of the field being >= 0) and the client has provided an expected hash of
-	// the content, we can avoid holding the contents of the entire file in
-	// memory at one time by creating a new buffer from the response body
-	// directly
-	//
-	// If either one (or both) of these things is not available, we will need to
-	// read the enitre response body into a byte slice in order to be able to
-	// determine the digest
-	length := resp.ContentLength
-	body := resp.Body
-	if length < 0 || expectedDigest == "" {
-		bodyBytes, err := io.ReadAll(resp.Body)
-		if err != nil {
-			return buffer.NewBufferFromError(util.StatusWrapWithCode(err, codes.Internal, "Failed to read response body")), bb_digest.BadDigest
-		}
-		err = resp.Body.Close()
-		if err != nil {
-			return buffer.NewBufferFromError(util.StatusWrapWithCode(err, codes.Internal, "Failed to close response body")), bb_digest.BadDigest
-		}
-		length = int64(len(bodyBytes))
-
-		// If we don't know what the hash should be we will need to work out the
-		// actual hash of the content
-		if expectedDigest == "" {
-			hasher := digestFunction.NewGenerator(length)
-			hasher.Write(bodyBytes)
-			digest := hasher.Sum()
-			expectedDigest = digest.GetHashString()
-		}
-
-		body = io.NopCloser(bytes.NewBuffer(bodyBytes))
-	}
-	digest, err := digestFunction.NewDigest(expectedDigest, length)
+	err = resp.Body.Close()
 	if err != nil {
-		return buffer.NewBufferFromError(util.StatusWrapWithCode(err, codes.Internal, "Digest Creation failed")), bb_digest.BadDigest
+		return buffer.NewBufferFromError(util.StatusWrapWithCode(err, codes.Internal, "Failed to close response body")), bb_digest.BadDigest
 	}
+	hasher := digestFunction.NewGenerator(resp.ContentLength)
+	hasher.Write(bodyBytes)
+	digest := hasher.Sum()
 
-	// An error will be generated down the line if the data does not match the
-	// digest
-	return buffer.NewCASBufferFromReader(digest, body, buffer.UserProvided), digest
+	return buffer.NewCASBufferFromByteSlice(digest, bodyBytes, buffer.UserProvided), digest
 }
 
-func getChecksumSri(qualifiers []*remoteasset.Qualifier) (string, remoteexecution.DigestFunction_Value, error) {
+// getChecksumSri parses the checksum.sri qualifier into an expected digest and a digest function to use
+func getChecksumSri(qualifiers []*remoteasset.Qualifier) (string, bb_digest.Function, error) {
 	hashTypes := map[string]remoteexecution.DigestFunction_Value{
 		"sha256": remoteexecution.DigestFunction_SHA256,
 		"sha1": remoteexecution.DigestFunction_SHA1,
@@ -195,27 +188,40 @@ func getChecksumSri(qualifiers []*remoteasset.Qualifier) (string, remoteexecutio
 	for _, qualifier := range qualifiers {
 		if qualifier.Name == "checksum.sri" {
 			if digestFunctionEnum != remoteexecution.DigestFunction_UNKNOWN {
-				return "", remoteexecution.DigestFunction_UNKNOWN, status.Errorf(codes.InvalidArgument, "Multiple checksum.sri provided")
+				return "", bb_digest.Function{}, status.Errorf(codes.InvalidArgument, "Multiple checksum.sri provided")
 			}
 			parts := strings.SplitN(qualifier.Value, "-", 2)
 			if len(parts) != 2 {
-				return "", remoteexecution.DigestFunction_UNKNOWN, status.Errorf(codes.InvalidArgument, "Bad checksum.sri hash expression: %s", qualifier.Value)
+				return "", bb_digest.Function{}, status.Errorf(codes.InvalidArgument, "Bad checksum.sri hash expression: %s", qualifier.Value)
			}
 			hashName := parts[0]
 			b64hash := parts[1]
-			var ok bool
-			digestFunctionEnum, ok = hashTypes[hashName]
+
+			digestFunctionEnum, ok := hashTypes[hashName]
 			if !ok {
-				return "", remoteexecution.DigestFunction_UNKNOWN, status.Errorf(codes.InvalidArgument, "Unsupported checksum algorithm %s", hashName)
+				return "", bb_digest.Function{}, status.Errorf(codes.InvalidArgument, "Unsupported checksum algorithm %s", hashName)
 			}
+
+			// Convert expected digest to hex
 			decoded, err := base64.StdEncoding.DecodeString(b64hash)
 			if err != nil {
-				return "", remoteexecution.DigestFunction_UNKNOWN, status.Errorf(codes.InvalidArgument, "Failed to decode checksum as base64 encoded %s sum: %s", hashName, err.Error())
+				return "", bb_digest.Function{}, status.Errorf(codes.InvalidArgument, "Failed to decode checksum as base64 encoded %s sum: %s", hashName, err.Error())
 			}
 			expectedDigest = hex.EncodeToString(decoded)
+
+			// Convert to a proper digest function.
+			// Note: The Instance name doesn't matter here, this function is used only
+			// to give us a convenient API when actually checking the checksum.
+			instance := bb_digest.MustNewInstanceName("")
+			checksumFunction, err := instance.GetDigestFunction(digestFunctionEnum, len(expectedDigest))
+			if err != nil {
+				return "", bb_digest.Function{}, status.Errorf(codes.InvalidArgument, "Failed to get checksum function for checksum.sri: %s", err.Error())
+			}
+			return expectedDigest, checksumFunction, nil
 		}
 	}
-	return expectedDigest, digestFunctionEnum, nil
+
+	return "", bb_digest.Function{}, nil
 }
 
 func getAuthHeaders(uris []string, qualifiers []*remoteasset.Qualifier) (*AuthHeaders, error) {
