Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 18 additions & 28 deletions gofakes3.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package gofakes3

import (
"bytes"
"encoding/base64"
"encoding/hex"
"fmt"
Expand Down Expand Up @@ -54,7 +53,7 @@ func New(backend Backend, options ...Option) *GoFakeS3 {
timeSkew: DefaultSkewLimit,
metadataSizeLimit: DefaultMetadataSizeLimit,
integrityCheck: true,
uploader: newUploader(),
uploader: newUploader(newMemoryTempBlobFactory()),
requestID: 0,
}

Expand Down Expand Up @@ -940,37 +939,28 @@ func (g *GoFakeS3) putMultipartUploadPart(bucket, object string, uploadID Upload
rdr = r.Body
}

var expectedMD5Base64 string
if g.integrityCheck {
md5Base64 := r.Header.Get("Content-MD5")
if _, ok := r.Header[textproto.CanonicalMIMEHeaderKey("Content-MD5")]; ok && md5Base64 == "" {
expectedMD5Base64 = r.Header.Get("Content-MD5")
if _, ok := r.Header[textproto.CanonicalMIMEHeaderKey("Content-MD5")]; ok && expectedMD5Base64 == "" {
return ErrInvalidDigest // Satisfies s3tests
}

if md5Base64 != "" {
var err error
rdr, err = newHashingReader(rdr, md5Base64)
if err != nil {
return err
}
}
}

body, err := ReadAll(rdr, size)
if err != nil {
return err
}
{
rdr, err := newHashingReader(rdr, expectedMD5Base64)
if err != nil {
return err
}

if int64(len(body)) != size {
return ErrIncompleteBody
}
etag, err := upload.AddPart(r.Context(), int(partNumber), g.timeSource.Now(), rdr, size)
if err != nil {
return err
}

etag, err := upload.AddPart(int(partNumber), g.timeSource.Now(), body)
if err != nil {
return err
w.Header().Add("ETag", etag)
return nil
}

w.Header().Add("ETag", etag)
return nil
}

func (g *GoFakeS3) abortMultipartUpload(bucket, object string, uploadID UploadID, w http.ResponseWriter, r *http.Request) error {
Expand All @@ -995,12 +985,12 @@ func (g *GoFakeS3) completeMultipartUpload(bucket, object string, uploadID Uploa
return err
}

fileBody, etag, err := upload.Reassemble(&in)
fileBody, size, err := upload.Reassemble(r.Context(), &in)
if err != nil {
return err
}

result, err := g.storage.PutObject(r.Context(), bucket, object, upload.Meta, bytes.NewReader(fileBody), int64(len(fileBody)))
result, err := g.storage.PutObject(r.Context(), bucket, object, upload.Meta, fileBody, size)
if err != nil {
return err
}
Expand All @@ -1009,7 +999,7 @@ func (g *GoFakeS3) completeMultipartUpload(bucket, object string, uploadID Uploa
}

return g.xmlEncoder(w).Encode(&CompleteMultipartUploadResult{
ETag: etag,
ETag: fmt.Sprintf("%x", fileBody.Sum(nil)),
Bucket: bucket,
Key: object,
})
Expand Down
26 changes: 19 additions & 7 deletions hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,19 @@
// If the expected hash is not empty, once the underlying reader returns EOF,
// the hash is checked.
type hashingReader struct {
inner io.Reader
expected []byte
hash hash.Hash
sum []byte
inner io.Reader
expectedStr string

Check failure on line 19 in hash.go

View workflow job for this annotation

GitHub Actions / lint

field `expectedStr` is unused (unused)
expected []byte
hash hash.Hash
sum []byte
}

func newHashingReader(inner io.Reader, expectedMD5Base64 string) (*hashingReader, error) {
func newHashingReader(inner io.Reader, optExpectedMD5Base64 string) (*hashingReader, error) {
var md5Bytes []byte
var err error

if expectedMD5Base64 != "" {
md5Bytes, err = base64.StdEncoding.DecodeString(expectedMD5Base64)
if optExpectedMD5Base64 != "" {
md5Bytes, err = base64.StdEncoding.DecodeString(optExpectedMD5Base64)
if err != nil {
return nil, ErrInvalidDigest
}
Expand All @@ -42,6 +43,17 @@
}, nil
}

// GetExpectedMD5 returns a copy of the MD5 digest this reader was told
// to expect, or nil when no Content-MD5 was supplied at construction.
// Returning a copy keeps callers from mutating the reader's internal
// state.
func (h *hashingReader) GetExpectedMD5() []byte {
	if h.expected == nil {
		return nil
	}
	return append([]byte(nil), h.expected...)
}

// Sum returns the hash of the data read from the inner reader so far.
// If into is passed, it may be used if the hash needs to be computed.
func (h *hashingReader) Sum(into []byte) []byte {
Expand Down
6 changes: 6 additions & 0 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ func WithV4Auth(authPair map[string]string) Option {
return func(g *GoFakeS3) { g.v4AuthPair = authPair }
}

// WithTempBlobFactory allows you to replace the default in-memory
// backend used to spool multipart upload parts with your own
// TempBlobFactory implementation (for example, one backed by temporary
// files on disk).
func WithTempBlobFactory(factory TempBlobFactory) Option {
	return func(g *GoFakeS3) { g.uploader = newUploader(factory) }
}

// WithTimeSource allows you to substitute the behaviour of time.Now() and
// time.Since() within GoFakeS3. This can be used to trigger time skew errors,
// or to ensure the output of the commands is deterministic.
Expand Down
51 changes: 51 additions & 0 deletions temp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package gofakes3

import (
"bytes"
"context"
"io"
)

// TempBlobFactory produces TempBlob instances used to spool multipart
// upload parts before they are reassembled into the final object.
type TempBlobFactory interface {
	// New returns a TempBlob for one part of a multipart upload.
	// size is the declared part size; expectedMD5 is the expected MD5
	// digest of the part body, or nil when none is known.
	New(bucket, object string, partNumber int, size int64, expectedMD5 []byte) (TempBlob, error)
}

// TempBlob is a temporary holding area for a single uploaded part.
type TempBlob interface {
	// Reader returns a reader over the data written so far.
	Reader(context.Context) io.ReadCloser
	// Writer returns a writer used to spool the part body.
	Writer(context.Context) io.WriteCloser
	// Cleanup releases any resources held by the blob.
	Cleanup(context.Context)
}

// newMemoryTempBlobFactory returns the default TempBlobFactory, which
// buffers every part in process memory.
func newMemoryTempBlobFactory() TempBlobFactory {
	return &memoryTempBlobFactory{}
}

// memoryTempBlobFactory creates in-memory temp blobs.
type memoryTempBlobFactory struct{}

// New implements TempBlobFactory. bucket, object, partNumber and
// expectedMD5 are unused by the in-memory implementation; size is used
// to pre-size the buffer so one allocation can hold the whole part.
func (m *memoryTempBlobFactory) New(bucket, object string, partNumber int, size int64, expectedMD5 []byte) (TempBlob, error) {
	return &memoryTempBlob{
		buf: bytes.NewBuffer(make([]byte, 0, size)),
	}, nil
}

type memoryTempBlob struct {
buf *bytes.Buffer
}

// Cleanup implements TempBlob.
func (m *memoryTempBlob) Cleanup(context.Context) {}

// Reader implements TempBlob.
func (m *memoryTempBlob) Reader(context.Context) io.ReadCloser {
return io.NopCloser(bytes.NewReader(m.buf.Bytes()))
}

// Writer implements TempBlob.
func (m *memoryTempBlob) Writer(context.Context) io.WriteCloser { return &nopWriteCloser{m.buf} }

type nopWriteCloser struct {
io.Writer
}

func (wc *nopWriteCloser) Close() error { return nil }
90 changes: 58 additions & 32 deletions uploader.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package gofakes3

import (
"crypto/md5"
"context"
"encoding/hex"
"fmt"
"io"
"math/big"
"net/url"
"strings"
Expand Down Expand Up @@ -154,12 +155,15 @@

buckets map[string]*bucketUploads
mu sync.Mutex

tempBlobFactory TempBlobFactory
}

func newUploader() *uploader {
func newUploader(tempBlobFactory TempBlobFactory) *uploader {
return &uploader{
buckets: make(map[string]*bucketUploads),
uploadID: new(big.Int),
buckets: make(map[string]*bucketUploads),
uploadID: new(big.Int),
tempBlobFactory: tempBlobFactory,
}
}

Expand All @@ -170,11 +174,12 @@
u.uploadID.Add(u.uploadID, add1)

mpu := &multipartUpload{
ID: UploadID(u.uploadID.String()),
Bucket: bucket,
Object: object,
Meta: meta,
Initiated: initiated,
ID: UploadID(u.uploadID.String()),
Bucket: bucket,
Object: object,
Meta: meta,
Initiated: initiated,
tempBlobFactory: u.tempBlobFactory,
}

// FIXME: make sure the uploader responds to DeleteBucket
Expand Down Expand Up @@ -221,7 +226,7 @@

result.Parts = append(result.Parts, ListMultipartUploadPartItem{
ETag: part.ETag,
Size: int64(len(part.Body)),
Size: part.Size,
PartNumber: partNumber,
LastModified: part.LastModified,
})
Expand Down Expand Up @@ -420,7 +425,8 @@
type multipartUploadPart struct {
PartNumber int
ETag string
Body []byte
TempBlob TempBlob
Size int64
LastModified ContentTime
}

Expand All @@ -443,29 +449,43 @@
// always be nil.
//
// Do not attempt to access parts without locking mu.
parts []*multipartUploadPart
parts []*multipartUploadPart
tempBlobFactory TempBlobFactory

mu sync.Mutex
}

func (mpu *multipartUpload) AddPart(partNumber int, at time.Time, body []byte) (etag string, err error) {
func (mpu *multipartUpload) AddPart(ctx context.Context, partNumber int, at time.Time, body *hashingReader, size int64) (etag string, err error) {
if partNumber > MaxUploadPartNumber {
return "", ErrInvalidPart
}

mpu.mu.Lock()
defer mpu.mu.Unlock()
tempBlob, err := mpu.tempBlobFactory.New(mpu.Bucket, mpu.Object, partNumber, size, body.GetExpectedMD5())
if err != nil {
return "", err
}

w := tempBlob.Writer(ctx)
defer w.Close()

Check failure on line 469 in uploader.go

View workflow job for this annotation

GitHub Actions / lint

Error return value of `w.Close` is not checked (errcheck)

_, err = io.Copy(w, newSizeCheckerReader(body, size))
if err != nil {
tempBlob.Cleanup(ctx)
return
}

// What the ETag actually is is not specified, so let's just invent any old thing
// from guaranteed unique input:
hash := md5.New()
hash.Write(body)
etag = fmt.Sprintf(`"%s"`, hex.EncodeToString(hash.Sum(nil)))
etag = fmt.Sprintf(`"%s"`, hex.EncodeToString(body.Sum(nil)))

mpu.mu.Lock()
defer mpu.mu.Unlock()

part := multipartUploadPart{
PartNumber: partNumber,
Body: body,
TempBlob: tempBlob,
ETag: etag,
Size: size,
LastModified: NewContentTime(at),
}
if partNumber >= len(mpu.parts) {
Expand All @@ -475,7 +495,7 @@
return etag, nil
}

func (mpu *multipartUpload) Reassemble(input *CompleteMultipartUploadRequest) (body []byte, etag string, err error) {
func (mpu *multipartUpload) Reassemble(ctx context.Context, input *CompleteMultipartUploadRequest) (body *hashingReader, size int64, err error) {
mpu.mu.Lock()
defer mpu.mu.Unlock()

Expand All @@ -485,34 +505,40 @@
// end up uploading more parts than you need to assemble, so it should
// probably just ignore that?
if len(input.Parts) > mpuPartsLen {
return nil, "", ErrInvalidPart
return nil, 0, ErrInvalidPart
}

if !input.partsAreSorted() {
return nil, "", ErrInvalidPartOrder
return nil, 0, ErrInvalidPartOrder
}

var size int64

for _, inPart := range input.Parts {
if inPart.PartNumber >= mpuPartsLen || mpu.parts[inPart.PartNumber] == nil {
return nil, "", ErrorMessagef(ErrInvalidPart, "unexpected part number %d in complete request", inPart.PartNumber)
return nil, 0, ErrorMessagef(ErrInvalidPart, "unexpected part number %d in complete request", inPart.PartNumber)
}

upPart := mpu.parts[inPart.PartNumber]
if strings.Trim(inPart.ETag, "\"") != strings.Trim(upPart.ETag, "\"") {
return nil, "", ErrorMessagef(ErrInvalidPart, "unexpected part etag for number %d in complete request", inPart.PartNumber)
return nil, 0, ErrorMessagef(ErrInvalidPart, "unexpected part etag for number %d in complete request", inPart.PartNumber)
}

size += int64(len(upPart.Body))
size += upPart.Size
}

body = make([]byte, 0, size)
for _, part := range input.Parts {
body = append(body, mpu.parts[part.PartNumber].Body...)
readers := make([]io.Reader, len(input.Parts))
for i, inPart := range input.Parts {
tmpBlob := mpu.parts[inPart.PartNumber].TempBlob
reader := tmpBlob.Reader(ctx)
defer reader.Close()

Check failure on line 532 in uploader.go

View workflow job for this annotation

GitHub Actions / lint

Error return value of `reader.Close` is not checked (errcheck)
defer tmpBlob.Cleanup(ctx)

readers[i] = reader
}

hash := fmt.Sprintf("%x", md5.Sum(body))
body, err = newHashingReader(io.MultiReader(readers...), "")
if err != nil {
return nil, 0, err
}

return body, hash, nil
return body, size, nil
}
Loading
Loading