11package fetch
22
33import (
4- "bytes"
54 "context"
65 "encoding/base64"
76 "encoding/hex"
@@ -52,36 +51,38 @@ func NewHTTPFetcher(httpClient *http.Client,
5251}
5352
5453func (hf * httpFetcher ) FetchBlob (ctx context.Context , req * remoteasset.FetchBlobRequest ) (* remoteasset.FetchBlobResponse , error ) {
55- var err error
56- instanceName , err := bb_digest .NewInstanceName (req .InstanceName )
54+ digestFunction , err := getDigestFunction (req .DigestFunction , req .InstanceName )
5755 if err != nil {
58- return nil , util . StatusWrapf ( err , "Invalid instance name %#v" , req . InstanceName )
56+ return nil , err
5957 }
6058
6159 // TODO: Address the following fields
6260 // timeout := ptypes.Duration(req.timeout)
6361 // oldestContentAccepted := ptypes.Timestamp(req.oldestContentAccepted)
64- expectedDigest , digestFunctionEnum , err := getChecksumSri (req .Qualifiers )
62+ expectedDigest , checksumFunction , err := getChecksumSri (req .Qualifiers )
6563 if err != nil {
6664 return nil , err
6765 }
68- if digestFunctionEnum == remoteexecution .DigestFunction_UNKNOWN {
69- // Default to SHA256 if no digest is provided.
70- digestFunctionEnum = remoteexecution .DigestFunction_SHA256
71- }
7266
7367 auth , err := getAuthHeaders (req .Uris , req .Qualifiers )
7468 if err != nil {
7569 return nil , err
7670 }
7771
7872 for _ , uri := range req .Uris {
79- buffer , digest := hf .downloadBlob (ctx , uri , instanceName , expectedDigest , digestFunctionEnum , auth )
73+ buffer , digest := hf .downloadBlob (ctx , uri , digestFunction , auth )
8074 if _ , err = buffer .GetSizeBytes (); err != nil {
8175 log .Printf ("Error downloading blob with URI %s: %v" , uri , err )
8276 continue
8377 }
8478
79+ // Check the checksum.sri qualifier, if there's an expected Digest
80+ if expectedDigest != "" {
81+ if ok , err := validateChecksumSri (buffer , checksumFunction , expectedDigest ); ! ok {
82+ return nil , err
83+ }
84+ }
85+
8586 if err = hf .contentAddressableStorage .Put (ctx , digest , buffer ); err != nil {
8687 log .Printf ("Error downloading blob with URI %s: %v" , uri , err )
8788 return nil , util .StatusWrapWithCode (err , codes .Internal , "Failed to place blob into CAS" )
@@ -111,16 +112,41 @@ func (hf *httpFetcher) CheckQualifiers(qualifiers qualifier.Set) qualifier.Set {
111112 return qualifier .Difference (qualifiers , toRemove )
112113}
113114
114- func (hf * httpFetcher ) downloadBlob (ctx context.Context , uri string , instanceName bb_digest.InstanceName , expectedDigest string , digestFunctionEnum remoteexecution.DigestFunction_Value , auth * AuthHeaders ) (buffer.Buffer , bb_digest.Digest ) {
115+ // validateChecksumSri ensures that the checksum of the passed response matches the expected value.
116+ func validateChecksumSri (buf buffer.Buffer , checksumFunction bb_digest.Function , expectedDigest string ) (bool , error ) {
117+ sizeBytes , err := buf .GetSizeBytes ()
118+ if err != nil {
119+ return false , err
120+ }
121+ checksumGenerator := checksumFunction .NewGenerator (sizeBytes )
122+ written , err := io .Copy (checksumGenerator , buf .ToReader ())
123+ if err != nil {
124+ return false , err
125+ }
126+ if written != sizeBytes {
127+ return false , status .Errorf (codes .Internal , "Failed to hash entire buffer" )
128+ }
129+
130+ checksum := checksumGenerator .Sum ().GetProto ().GetHash ()
131+ if checksum != expectedDigest {
132+ return false , status .Errorf (codes .Internal , "Fetched content did not match checksum.sri qualifier: Expected %s, Got %s" , expectedDigest , checksum )
133+ }
134+
135+ return true , nil
136+ }
137+
138+ // downloadBlob performs the actual blob download, yielding a buffer of the content and its Digest
139+ func (hf * httpFetcher ) downloadBlob (ctx context.Context , uri string , digestFunction bb_digest.Function , auth * AuthHeaders ) (buffer.Buffer , bb_digest.Digest ) {
140+ // Generate the HTTP Request
115141 req , err := http .NewRequestWithContext (ctx , http .MethodGet , uri , nil )
116142 if err != nil {
117143 return buffer .NewBufferFromError (util .StatusWrapWithCode (err , codes .Internal , "Failed to create HTTP request" )), bb_digest .BadDigest
118144 }
119-
120145 if auth != nil {
121146 auth .ApplyHeaders (uri , req )
122147 }
123148
149+ // Perform the request, check for status
124150 resp , err := hf .httpClient .Do (req )
125151 if err != nil {
126152 log .Printf ("Error downloading blob with URI %s: %v" , uri , err )
@@ -131,57 +157,24 @@ func (hf *httpFetcher) downloadBlob(ctx context.Context, uri string, instanceNam
131157 return buffer .NewBufferFromError (status .Errorf (codes .Internal , "HTTP request failed with status %#v" , resp .Status )), bb_digest .BadDigest
132158 }
133159
134- digestFunction , err := instanceName .GetDigestFunction (digestFunctionEnum , len (expectedDigest ))
160+ // Compute the Digest
161+ bodyBytes , err := io .ReadAll (resp .Body )
135162 if err != nil {
136- return buffer .NewBufferFromError (util .StatusWrapfWithCode (err , codes .Internal , "Failed to get digest function for instance: %v" , instanceName )), bb_digest .BadDigest
163+ return buffer .NewBufferFromError (util .StatusWrapWithCode (err , codes .Internal , "Failed to read response body" )), bb_digest .BadDigest
137164 }
138-
139- // Work out the digest of the downloaded data
140- //
141- // If the HTTP response includes the content length (indicated by the value
142- // of the field being >= 0) and the client has provided an expected hash of
143- // the content, we can avoid holding the contents of the entire file in
144- // memory at one time by creating a new buffer from the response body
145- // directly
146- //
147- // If either one (or both) of these things is not available, we will need to
148- // read the enitre response body into a byte slice in order to be able to
149- // determine the digest
150- length := resp .ContentLength
151- body := resp .Body
152- if length < 0 || expectedDigest == "" {
153- bodyBytes , err := io .ReadAll (resp .Body )
154- if err != nil {
155- return buffer .NewBufferFromError (util .StatusWrapWithCode (err , codes .Internal , "Failed to read response body" )), bb_digest .BadDigest
156- }
157- err = resp .Body .Close ()
158- if err != nil {
159- return buffer .NewBufferFromError (util .StatusWrapWithCode (err , codes .Internal , "Failed to close response body" )), bb_digest .BadDigest
160- }
161- length = int64 (len (bodyBytes ))
162-
163- // If we don't know what the hash should be we will need to work out the
164- // actual hash of the content
165- if expectedDigest == "" {
166- hasher := digestFunction .NewGenerator (length )
167- hasher .Write (bodyBytes )
168- digest := hasher .Sum ()
169- expectedDigest = digest .GetHashString ()
170- }
171-
172- body = io .NopCloser (bytes .NewBuffer (bodyBytes ))
173- }
174- digest , err := digestFunction .NewDigest (expectedDigest , length )
165+ err = resp .Body .Close ()
175166 if err != nil {
176- return buffer .NewBufferFromError (util .StatusWrapWithCode (err , codes .Internal , "Digest Creation failed " )), bb_digest .BadDigest
167+ return buffer .NewBufferFromError (util .StatusWrapWithCode (err , codes .Internal , "Failed to close response body " )), bb_digest .BadDigest
177168 }
169+ hasher := digestFunction .NewGenerator (resp .ContentLength )
170+ hasher .Write (bodyBytes )
171+ digest := hasher .Sum ()
178172
179- // An error will be generated down the line if the data does not match the
180- // digest
181- return buffer .NewCASBufferFromReader (digest , body , buffer .UserProvided ), digest
173+ return buffer .NewCASBufferFromByteSlice (digest , bodyBytes , buffer .UserProvided ), digest
182174}
183175
184- func getChecksumSri (qualifiers []* remoteasset.Qualifier ) (string , remoteexecution.DigestFunction_Value , error ) {
176+ // getChecksumSri parses the checksum.sri qualifier into an expected digest and a digest function to use
177+ func getChecksumSri (qualifiers []* remoteasset.Qualifier ) (string , bb_digest.Function , error ) {
185178 hashTypes := map [string ]remoteexecution.DigestFunction_Value {
186179 "sha256" : remoteexecution .DigestFunction_SHA256 ,
187180 "sha1" : remoteexecution .DigestFunction_SHA1 ,
@@ -195,27 +188,40 @@ func getChecksumSri(qualifiers []*remoteasset.Qualifier) (string, remoteexecutio
195188 for _ , qualifier := range qualifiers {
196189 if qualifier .Name == "checksum.sri" {
197190 if digestFunctionEnum != remoteexecution .DigestFunction_UNKNOWN {
198- return "" , remoteexecution . DigestFunction_UNKNOWN , status .Errorf (codes .InvalidArgument , "Multiple checksum.sri provided" )
191+ return "" , bb_digest. Function {} , status .Errorf (codes .InvalidArgument , "Multiple checksum.sri provided" )
199192 }
200193 parts := strings .SplitN (qualifier .Value , "-" , 2 )
201194 if len (parts ) != 2 {
202- return "" , remoteexecution . DigestFunction_UNKNOWN , status .Errorf (codes .InvalidArgument , "Bad checksum.sri hash expression: %s" , qualifier .Value )
195+ return "" , bb_digest. Function {} , status .Errorf (codes .InvalidArgument , "Bad checksum.sri hash expression: %s" , qualifier .Value )
203196 }
204197 hashName := parts [0 ]
205198 b64hash := parts [1 ]
206- var ok bool
207- digestFunctionEnum , ok = hashTypes [hashName ]
199+
200+ digestFunctionEnum , ok : = hashTypes [hashName ]
208201 if ! ok {
209- return "" , remoteexecution . DigestFunction_UNKNOWN , status .Errorf (codes .InvalidArgument , "Unsupported checksum algorithm %s" , hashName )
202+ return "" , bb_digest. Function {} , status .Errorf (codes .InvalidArgument , "Unsupported checksum algorithm %s" , hashName )
210203 }
204+
205+ // Convert expected digest to hex
211206 decoded , err := base64 .StdEncoding .DecodeString (b64hash )
212207 if err != nil {
213- return "" , remoteexecution . DigestFunction_UNKNOWN , status .Errorf (codes .InvalidArgument , "Failed to decode checksum as base64 encoded %s sum: %s" , hashName , err .Error ())
208+ return "" , bb_digest. Function {} , status .Errorf (codes .InvalidArgument , "Failed to decode checksum as base64 encoded %s sum: %s" , hashName , err .Error ())
214209 }
215210 expectedDigest = hex .EncodeToString (decoded )
211+
212+ // Convert to a proper digest function.
213+ // Note: The Instance name doesn't matter here, this function is used only
214+ // to give us a convenient API when actually checking the checksum.
215+ instance := bb_digest .MustNewInstanceName ("" )
216+ checksumFunction , err := instance .GetDigestFunction (digestFunctionEnum , len (expectedDigest ))
217+ if err != nil {
218+ return "" , bb_digest.Function {}, status .Errorf (codes .InvalidArgument , "Failed to get checksum function for checksum.sri: %s" , err .Error ())
219+ }
220+ return expectedDigest , checksumFunction , nil
216221 }
217222 }
218- return expectedDigest , digestFunctionEnum , nil
223+
224+ return "" , bb_digest.Function {}, nil
219225}
220226
221227func getAuthHeaders (uris []string , qualifiers []* remoteasset.Qualifier ) (* AuthHeaders , error ) {
0 commit comments