Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 13 additions & 30 deletions materialize-snowflake/bdec.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package main

import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/md5"
Expand Down Expand Up @@ -416,24 +415,22 @@ func getCipherStream(encryptionKey string, fileName blobFileName) (cipher.Stream
return cipher.NewCTR(block, make([]byte, aes.BlockSize)), nil
}

// reencrypt reads encrypted blob data from r, decrypts it using the
// decryptKey and original file name, and then re-encrypts it using the
// new file name and writes the output to w. The blob metadata is updated in
// place.
// reencrypt reads encrypted blob data from r, decrypts it using the encryption
// key and original file name, and then re-encrypts it using the new file name
// and writes the output to w. The blob metadata is updated in place.
func reencrypt(
r io.Reader,
w io.Writer,
blob *blobMetadata,
decryptKey string,
channel *channel,
encryptionKey string,
newFileName blobFileName,
) error {
decryptStream, err := getCipherStream(decryptKey, blobFileName(blob.Path))
decryptStream, err := getCipherStream(encryptionKey, blobFileName(blob.Path))
if err != nil {
return fmt.Errorf("getting decryptStream: %w", err)
}

encryptStream, err := getCipherStream(channel.EncryptionKey, newFileName)
encryptStream, err := getCipherStream(encryptionKey, newFileName)
if err != nil {
return fmt.Errorf("getting encryptStream: %w", err)
}
Expand All @@ -442,28 +439,21 @@ func reencrypt(
upMd5 := md5.New()

buf := make([]byte, 64*1024)
for {
n, err := r.Read(buf)
if err != nil && !errors.Is(err, io.EOF) {
return fmt.Errorf("reading from r: %w", err)
} else if n == 0 {
break
}

firstRead := true
n, err := io.ReadAtLeast(r, buf, 4)
if err != nil {
return fmt.Errorf("reading from r: %w", err)
}

for n != 0 {
if m, err := downMd5.Write(buf[:n]); err != nil {
return fmt.Errorf("writing to downMd5: %w", err)
} else if m != n {
return fmt.Errorf("written to downMd5 %d != expected bytes %d", m, n)
}

decryptStream.XORKeyStream(buf[:n], buf[:n])

// The stream must be a valid parquet file; if the magic is incorrect
// it is very likely we have the wrong decryption key.
if firstRead && !bytes.Equal(buf[:4], []byte("PAR1")) {
return fmt.Errorf("unexpected magic: %v", buf[:4])
}

encryptStream.XORKeyStream(buf[:n], buf[:n])

if m, err := upMd5.Write(buf[:n]); err != nil {
Expand All @@ -475,12 +465,6 @@ func reencrypt(
} else if written != n {
return fmt.Errorf("written bytes %d != expected bytes %d", written, n)
}

firstRead = false
n, err = r.Read(buf)
if err != nil && !errors.Is(err, io.EOF) {
return fmt.Errorf("reading from r: %w", err)
}
}

downHash := hex.EncodeToString(downMd5.Sum(nil))
Expand All @@ -492,7 +476,6 @@ func reencrypt(
blob.Path = string(newFileName)
blob.MD5 = upHash
blob.Chunks[0].ChunkMD5 = upHash
blob.Chunks[0].EncryptionKeyID = channel.EncryptionKeyId

return nil
}
Expand Down
10 changes: 1 addition & 9 deletions materialize-snowflake/bdec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,6 @@ func TestReencrypt(t *testing.T) {
input := make([]byte, 1024*1024+3)
_, err := rand.Read(input)
require.NoError(t, err)
copy(input, "PAR1")

encryptionKey := "aGVsbG8K"
oldFileName := blobFileName("old")
newFileName := blobFileName("new")
Expand All @@ -133,11 +131,6 @@ func TestReencrypt(t *testing.T) {
sum := md5.Sum(encrypted)
hsh := hex.EncodeToString(sum[:])

channel := &channel{
EncryptionKey: encryptionKey,
EncryptionKeyId: 12345,
}

blob := &blobMetadata{
Path: string(oldFileName),
MD5: hsh,
Expand All @@ -147,8 +140,7 @@ func TestReencrypt(t *testing.T) {
var in = bytes.NewReader(encrypted)
var out bytes.Buffer

err = reencrypt(in, &out, blob, encryptionKey, channel, newFileName)
require.NoError(t, err)
require.NoError(t, reencrypt(in, &out, blob, encryptionKey, newFileName))

// The re-encrypted file matches the original input.
newKey, err := deriveKey(encryptionKey, newFileName)
Expand Down
31 changes: 10 additions & 21 deletions materialize-snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,6 @@ type transactor struct {
// this shard's range spec and version, used to key pipes so they don't collide
_range *pf.RangeSpec
version string

// If this is still the recovery (first after startup) commit, where special
// handling may be needed for registering streaming files.
didRecovery bool
}

func (d *transactor) RecoverCheckpoint(_ context.Context, _ pf.MaterializationSpec, _ pf.RangeSpec) (m.RuntimeCheckpoint, error) {
Expand Down Expand Up @@ -582,14 +578,13 @@ func (d *transactor) pipeExists(ctx context.Context, pipeName string) (bool, err
}

type checkpointItem struct {
Table string
Query string
StagedDir string
StreamBlobs []*blobMetadata
PipeName string
PipeFiles []fileRecord
Version string
EncryptionKey string
Table string
Query string
StagedDir string
StreamBlobs []*blobMetadata
PipeName string
PipeFiles []fileRecord
Version string
}

type checkpoint = map[string]*checkpointItem
Expand Down Expand Up @@ -638,15 +633,14 @@ func (d *transactor) Store(it *m.StoreIterator) (m.StartCommitFunc, error) {

func (d *transactor) buildDriverCheckpoint(ctx context.Context, runtimeCheckpoint *protocol.Checkpoint) (json.RawMessage, error) {
streamBlobs := make(map[int][]*blobMetadata)
keys := make(map[int]string)
if d.streamManager != nil {
// The "base token" only really needs to be sufficiently random that it
// doesn't collide with the prior or next transaction's value. Deriving
// it from the runtime checkpoint is not absolutely necessary, but it's
// convenient to make testing outputs consistent.
if mcp, err := runtimeCheckpoint.Marshal(); err != nil {
return nil, fmt.Errorf("marshalling checkpoint: %w", err)
} else if streamBlobs, keys, err = d.streamManager.flush(fmt.Sprintf("%016x", xxhash.Sum64(mcp))); err != nil {
} else if streamBlobs, err = d.streamManager.flush(fmt.Sprintf("%016x", xxhash.Sum64(mcp))); err != nil {
return nil, fmt.Errorf("flushing stream manager: %w", err)
}
}
Expand All @@ -655,8 +649,7 @@ func (d *transactor) buildDriverCheckpoint(ctx context.Context, runtimeCheckpoin
if b.streaming {
if blobs, ok := streamBlobs[idx]; ok {
d.cp[b.target.StateKey] = &checkpointItem{
StreamBlobs: blobs,
EncryptionKey: keys[idx],
StreamBlobs: blobs,
}
}
continue
Expand Down Expand Up @@ -819,10 +812,6 @@ func (d *transactor) copyHistory(ctx context.Context, tableName string, fileName

// Acknowledge merges data from temporary table to main table
func (d *transactor) Acknowledge(ctx context.Context) (*pf.ConnectorState, error) {
defer func() {
d.didRecovery = true
}()

// Run store queries concurrently, as each independently operates on a separate table.
group, groupCtx := errgroup.WithContext(ctx)
group.SetLimit(MaxConcurrentQueries)
Expand Down Expand Up @@ -885,7 +874,7 @@ func (d *transactor) Acknowledge(ctx context.Context) (*pf.ConnectorState, error
} else if len(item.StreamBlobs) > 0 {
group.Go(func() error {
d.be.StartedResourceCommit(path)
if err := d.streamManager.write(groupCtx, item.StreamBlobs, item.EncryptionKey, !d.didRecovery); err != nil {
if err := d.streamManager.write(groupCtx, item.StreamBlobs); err != nil {
return fmt.Errorf("writing streaming blobs for %s: %w", path, err)
}
d.be.FinishedResourceCommit(path)
Expand Down
129 changes: 48 additions & 81 deletions materialize-snowflake/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"context"
"errors"
"fmt"
"path"
"strconv"
Expand Down Expand Up @@ -169,28 +170,26 @@ func (sm *streamManager) objMetadata() blob.WriterOption {
})
}

func (sm *streamManager) flush(baseToken string) (map[int][]*blobMetadata, map[int]string, error) {
func (sm *streamManager) flush(baseToken string) (map[int][]*blobMetadata, error) {
if sm.bdecWriter != nil {
if err := sm.finishBlob(); err != nil {
return nil, nil, fmt.Errorf("finishBlob: %w", err)
return nil, fmt.Errorf("finishBlob: %w", err)
}
}

out := make(map[int][]*blobMetadata)
keys := make(map[int]string)
for binding, trackedBlobs := range sm.blobStats {
for idx, trackedBlob := range trackedBlobs {
out[binding] = append(out[binding], generateBlobMetadata(
trackedBlob,
sm.tableStreams[binding].channel,
blobToken(baseToken, idx)),
)
keys[binding] = sm.tableStreams[binding].channel.EncryptionKey
}
}
maps.Clear(sm.blobStats)

return out, keys, nil
return out, nil
}

// write registers a series of blobs to the table. All the blobs must be for the
Expand All @@ -209,7 +208,7 @@ func (sm *streamManager) flush(baseToken string) (map[int][]*blobMetadata, map[i
// channel itself persists the last registered token, and this allows us to
// filter out blobs that had previously been registered and don't need
// registered again on a re-attempt of an Acknowledge.
func (sm *streamManager) write(ctx context.Context, blobs []*blobMetadata, decryptKey string, recovery bool) error {
func (sm *streamManager) write(ctx context.Context, blobs []*blobMetadata) error {
if err := validWriteBlobs(blobs); err != nil {
return fmt.Errorf("validWriteBlobs: %w", err)
}
Expand Down Expand Up @@ -248,46 +247,48 @@ func (sm *streamManager) write(ctx context.Context, blobs []*blobMetadata, decry
blob.Chunks[0].Channels[0].ClientSequencer = thisChannel.ClientSequencer
blob.Chunks[0].Channels[0].RowSequencer = thisChannel.RowSequencer + 1

if recovery {
// Handle the case where decryptKey was not in the checkpoint because
// the checkpoint was created before it was added. We will guess that
// its the same as the current key, if its not then the decrypted
// data will not have the parquet magic header.
//
// This check can be removed once all tasks have written a checkpoint.
if decryptKey == "" {
log.WithFields(log.Fields{
"schema": schema,
"table": table,
}).Warn("unknown encryption key for rewrite; using current")
decryptKey = thisChannel.EncryptionKey
if err := sm.c.write(ctx, blob); err != nil {
var apiError *streamingApiError
if errors.As(err, &apiError) && apiError.Code == 38 {
// The "blob has wrong format or extension" error occurs if the
// blob was written not-so-recently; apparently anything older
// than an hour or so is rejected by what seems to be a
// server-side check that examines the name of the file, which
// contains the timestamp it was written.
//
// In these cases, which may arise from re-enabling a disabled
// binding / materialization, or extended outages, we have to
// download the file and re-upload it with an up-to-date name.
//
// This should be quite rare, even more rare than one may think,
// since blob registration tokens are persisted in Snowflake
// rather than exclusively managed by our Acknowledge
// checkpoints. But it is still technically possible and so it
// is handled with this.
if err := sm.maybeInitializeBucket(ctx); err != nil {
return fmt.Errorf("initializing bucket to rename: %w", err)
}
nextName := sm.getNextFileName(time.Now(), fmt.Sprintf("%s_%d", sm.prefix, sm.deploymentId))

ll := log.WithFields(log.Fields{
"oldName": blob.Path,
"newName": nextName,
"token": blobToken,
})
ll.Info("attempting to register blob with malformed name by renaming")

if err := sm.renameBlob(ctx, blob, thisChannel.EncryptionKey, nextName); err != nil {
return fmt.Errorf("renameBlob: %w", err)
}

if err := sm.c.write(ctx, blob); err != nil {
log.WithField("blob", blob).Warn("blob metadata")
return fmt.Errorf("failed to write renamed blob: %w", err)
}
ll.Info("successfully registered renamed blob")
} else {
return fmt.Errorf("write: %w", err)
}

// Always use the rename and re-upload strategy during recovery
// commits. This prevents issues where blobs could be re-registered
// if the underlying channel has changed somehow, perhaps by being
// dropped out-of-band, or its last persisted token changed in some
// other way.
//
// It is also possible the channel encryption key has changed and
// the blob will need to be rewritten with the new encryption key.
//
// This also addresses the "blob has wrong format or extension" error
// which occurs if the blob was written not-so-recently; apparently
// anything older than an hour or so is rejected by what seems to
// be a server-side check that examines the name of the file, which
// contains the timestamp it was written. In this case, which
// may arise from re-enabling a disabled binding / materialization,
// or extended outages, we have to download the file and re-upload
// it with an up-to-date name.
//
// This should be a relatively infrequent occurrence, so the
// performance impact + increased data transfer should be tolerable.
if err := sm.writeRenamed(ctx, decryptKey, thisChannel, blob); err != nil {
return fmt.Errorf("writeRenamed: %w", err)
}
} else if err := sm.c.write(ctx, blob); err != nil {
return fmt.Errorf("write: %w", err)
}

thisChannel.RowSequencer++
Expand All @@ -311,48 +312,14 @@ func (sm *streamManager) write(ctx context.Context, blobs []*blobMetadata, decry
return nil
}

func (sm *streamManager) writeRenamed(ctx context.Context, decryptKey string, channel *channel, blob *blobMetadata) error {
if err := sm.maybeInitializeBucket(ctx); err != nil {
return fmt.Errorf("initializing bucket to rename: %w", err)
}
nextName := sm.getNextFileName(time.Now(), fmt.Sprintf("%s_%d", sm.prefix, sm.deploymentId))

ll := log.WithFields(log.Fields{
"oldName": blob.Path,
"newName": nextName,
"token": blob.Chunks[0].Channels[0].OffsetToken,
"schema": blob.Chunks[0].Schema,
"table": blob.Chunks[0].Table,
})
ll.Info("attempting to register blob by renaming")

if err := sm.renameBlob(ctx, blob, decryptKey, channel, nextName); err != nil {
return fmt.Errorf("renameBlob: %w", err)
}

if err := sm.c.write(ctx, blob); err != nil {
ll.WithField("blob", blob).Warn("blob metadata")
return fmt.Errorf("failed to write renamed blob: %w", err)
}
ll.Info("successfully registered renamed blob")

return nil
}

func (sm *streamManager) renameBlob(
ctx context.Context,
blob *blobMetadata,
decryptKey string,
channel *channel,
newName blobFileName,
) error {
func (sm *streamManager) renameBlob(ctx context.Context, blob *blobMetadata, encryptionKey string, newName blobFileName) error {
r, err := sm.bucket.NewReader(ctx, path.Join(sm.bucketPath, blob.Path))
if err != nil {
return fmt.Errorf("NewReader: %w", err)
}
w := sm.bucket.NewWriter(ctx, path.Join(sm.bucketPath, string(newName)), sm.objMetadata())

if err := reencrypt(r, w, blob, decryptKey, channel, newName); err != nil {
if err := reencrypt(r, w, blob, encryptionKey, newName); err != nil {
return fmt.Errorf("reencrypt: %w", err)
} else if err := r.Close(); err != nil {
return fmt.Errorf("closing r: %w", err)
Expand Down
Loading
Loading