Skip to content

Commit bb13a69

Browse files
committed
Replace storageReference with ExternalStorageReference proto
1 parent 9d32461 commit bb13a69

9 files changed

Lines changed: 497 additions & 58 deletions

File tree

converter/payload_handler_test.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,14 @@ type appendCodec struct {
3131
func (c *appendCodec) Encode(payloads []*commonpb.Payload) ([]*commonpb.Payload, error) {
3232
result := make([]*commonpb.Payload, len(payloads))
3333
for i, p := range payloads {
34-
enc := string(p.GetMetadata()[converter.MetadataEncoding]) + c.encodingSuffix
35-
data := append(append([]byte(nil), p.GetData()...), c.marker)
34+
meta := make(map[string][]byte, len(p.GetMetadata()))
35+
for k, v := range p.GetMetadata() {
36+
meta[k] = v
37+
}
38+
meta[converter.MetadataEncoding] = []byte(string(p.GetMetadata()[converter.MetadataEncoding]) + c.encodingSuffix)
3639
result[i] = &commonpb.Payload{
37-
Metadata: map[string][]byte{converter.MetadataEncoding: []byte(enc)},
38-
Data: data,
40+
Metadata: meta,
41+
Data: append(append([]byte(nil), p.GetData()...), c.marker),
3942
}
4043
}
4144
return result, nil
@@ -52,8 +55,13 @@ func (c *appendCodec) Decode(payloads []*commonpb.Payload) ([]*commonpb.Payload,
5255
if len(data) == 0 || data[len(data)-1] != c.marker {
5356
return nil, fmt.Errorf("appendCodec.Decode: expected trailing marker byte %d", c.marker)
5457
}
58+
meta := make(map[string][]byte, len(p.GetMetadata()))
59+
for k, v := range p.GetMetadata() {
60+
meta[k] = v
61+
}
62+
meta[converter.MetadataEncoding] = []byte(strings.TrimSuffix(enc, c.encodingSuffix))
5563
result[i] = &commonpb.Payload{
56-
Metadata: map[string][]byte{converter.MetadataEncoding: []byte(strings.TrimSuffix(enc, c.encodingSuffix))},
64+
Metadata: meta,
5765
Data: data[:len(data)-1],
5866
}
5967
}

internal/extstore/extstore_test.go

Lines changed: 187 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/stretchr/testify/require"
1414
commonpb "go.temporal.io/api/common/v1"
1515
"go.temporal.io/api/proxy"
16+
sdkpb "go.temporal.io/sdk/internal/temporalapi/sdk/v1"
1617
"google.golang.org/protobuf/encoding/protojson"
1718
"google.golang.org/protobuf/proto"
1819
)
@@ -210,20 +211,20 @@ func TestExternalStorageToParams_SingleDriverSynthesizesSelector(t *testing.T) {
210211
// ---------------------------------------------------------------------------
211212

212213
func TestStorageReferenceRoundTrip(t *testing.T) {
213-
ref := storageReference{
214-
DriverName: "mydriver",
215-
DriverClaim: StorageDriverClaim{ClaimData: map[string]string{"key": "abc123"}},
214+
ref := &sdkpb.ExternalStorageReference{
215+
DriverName: "mydriver",
216+
ClaimData: map[string]string{"key": "abc123"},
216217
}
217218
p, err := storageReferenceToPayload(ref, 512)
218219
require.NoError(t, err)
219-
require.Equal(t, metadataEncodingStorageRef, string(p.Metadata[metadataEncoding]))
220+
require.Equal(t, metadataEncodingProtoJSON, string(p.Metadata[metadataEncoding]))
220221
require.Len(t, p.ExternalPayloads, 1)
221222
require.Equal(t, int64(512), p.ExternalPayloads[0].SizeBytes)
222223

223224
decoded, err := payloadToStorageReference(p)
224225
require.NoError(t, err)
225226
require.Equal(t, ref.DriverName, decoded.DriverName)
226-
require.Equal(t, ref.DriverClaim.ClaimData, decoded.DriverClaim.ClaimData)
227+
require.Equal(t, ref.ClaimData, decoded.ClaimData)
227228
}
228229

229230
func TestPayloadToStorageReference_WrongEncoding(t *testing.T) {
@@ -237,13 +238,92 @@ func TestPayloadToStorageReference_WrongEncoding(t *testing.T) {
237238

238239
func TestPayloadToStorageReference_CorruptJSON(t *testing.T) {
239240
p := &commonpb.Payload{
240-
Metadata: map[string][]byte{metadataEncoding: []byte(metadataEncodingStorageRef)},
241+
Metadata: map[string][]byte{metadataEncoding: []byte(metadataEncodingStorageRefLegacy)},
241242
Data: []byte(`not json`),
242243
}
243244
_, err := payloadToStorageReference(p)
244245
require.Error(t, err)
245246
}
246247

248+
func TestPayloadToStorageReference_ProtoJSONFormat(t *testing.T) {
249+
want := &sdkpb.ExternalStorageReference{
250+
DriverName: "mydriver",
251+
ClaimData: map[string]string{"bucket": "b", "key": "k"},
252+
}
253+
p, err := storageReferenceToPayload(want, 0)
254+
require.NoError(t, err)
255+
require.Equal(t, metadataEncodingProtoJSON, string(p.Metadata[metadataEncoding]))
256+
require.Equal(t, externalStorageReferenceMessageType, string(p.Metadata[metadataMessageType]))
257+
258+
got, err := payloadToStorageReference(p)
259+
require.NoError(t, err)
260+
require.Equal(t, want.DriverName, got.DriverName)
261+
require.Equal(t, want.ClaimData, got.ClaimData)
262+
}
263+
264+
func TestPayloadToStorageReference_LegacyFormat(t *testing.T) {
265+
legacyData, err := json.Marshal(legacyStorageReference{
266+
DriverName: "mydriver",
267+
DriverClaim: StorageDriverClaim{ClaimData: map[string]string{"bucket": "b", "key": "k"}},
268+
})
269+
require.NoError(t, err)
270+
p := &commonpb.Payload{
271+
Metadata: map[string][]byte{metadataEncoding: []byte(metadataEncodingStorageRefLegacy)},
272+
Data: legacyData,
273+
}
274+
275+
got, err := payloadToStorageReference(p)
276+
require.NoError(t, err)
277+
require.Equal(t, "mydriver", got.DriverName)
278+
require.Equal(t, map[string]string{"bucket": "b", "key": "k"}, got.ClaimData)
279+
}
280+
281+
func TestPayloadToStorageReference_ProtoJSON_WrongMessageType(t *testing.T) {
282+
p := &commonpb.Payload{
283+
Metadata: map[string][]byte{
284+
metadataEncoding: []byte(metadataEncodingProtoJSON),
285+
metadataMessageType: []byte("some.other.MessageType"),
286+
},
287+
Data: []byte(`{}`),
288+
}
289+
_, err := payloadToStorageReference(p)
290+
require.Error(t, err)
291+
require.Contains(t, err.Error(), "some.other.MessageType")
292+
}
293+
294+
// ---------------------------------------------------------------------------
295+
// IsStorageReference
296+
// ---------------------------------------------------------------------------
297+
298+
func TestIsStorageReference_ProtoJSONFormat(t *testing.T) {
299+
p, err := storageReferenceToPayload(&sdkpb.ExternalStorageReference{DriverName: "d"}, 0)
300+
require.NoError(t, err)
301+
require.True(t, IsStorageReference(p))
302+
}
303+
304+
func TestIsStorageReference_LegacyFormat(t *testing.T) {
305+
p := &commonpb.Payload{
306+
Metadata: map[string][]byte{metadataEncoding: []byte(metadataEncodingStorageRefLegacy)},
307+
Data: []byte(`{"driver_name":"d","driver_claim":{}}`),
308+
}
309+
require.True(t, IsStorageReference(p))
310+
}
311+
312+
func TestIsStorageReference_ProtoJSON_WrongMessageType(t *testing.T) {
313+
p := &commonpb.Payload{
314+
Metadata: map[string][]byte{
315+
metadataEncoding: []byte(metadataEncodingProtoJSON),
316+
metadataMessageType: []byte("some.other.MessageType"),
317+
},
318+
}
319+
require.False(t, IsStorageReference(p))
320+
}
321+
322+
func TestIsStorageReference_NotStorageReference(t *testing.T) {
323+
require.False(t, IsStorageReference(makePayload(t, "hello")))
324+
require.False(t, IsStorageReference(&commonpb.Payload{}))
325+
}
326+
247327
// ---------------------------------------------------------------------------
248328
// storageStoreVisitor
249329
// ---------------------------------------------------------------------------
@@ -291,7 +371,7 @@ func TestStoreVisitor_AtThreshold_Stored(t *testing.T) {
291371

292372
result, err := visitPayloads(context.Background(), visitor, []*commonpb.Payload{p})
293373
require.NoError(t, err)
294-
require.Equal(t, metadataEncodingStorageRef, string(result[0].Metadata[metadataEncoding]))
374+
require.Equal(t, metadataEncodingProtoJSON, string(result[0].Metadata[metadataEncoding]))
295375
require.Equal(t, 1, driver.storeCount)
296376
}
297377

@@ -308,7 +388,7 @@ func TestStoreVisitor_AboveThreshold_Stored(t *testing.T) {
308388
p := makeOversizedPayload(t, threshold+1)
309389
result, err := visitPayloads(context.Background(), visitor, []*commonpb.Payload{p})
310390
require.NoError(t, err)
311-
require.Equal(t, metadataEncodingStorageRef, string(result[0].Metadata[metadataEncoding]))
391+
require.Equal(t, metadataEncodingProtoJSON, string(result[0].Metadata[metadataEncoding]))
312392
require.Equal(t, 1, driver.storeCount)
313393
}
314394

@@ -329,8 +409,8 @@ func TestStoreVisitor_MultiplePayloads_Batched(t *testing.T) {
329409
result, err := visitPayloads(context.Background(), visitor, []*commonpb.Payload{big1, small, big2})
330410
require.NoError(t, err)
331411
require.Len(t, result, 3)
332-
require.Equal(t, metadataEncodingStorageRef, string(result[0].Metadata[metadataEncoding]))
333-
require.Equal(t, metadataEncodingStorageRef, string(result[2].Metadata[metadataEncoding]))
412+
require.Equal(t, metadataEncodingProtoJSON, string(result[0].Metadata[metadataEncoding]))
413+
require.Equal(t, metadataEncodingProtoJSON, string(result[2].Metadata[metadataEncoding]))
334414
// small payload is inline
335415
require.Empty(t, result[1].ExternalPayloads)
336416
// both big payloads batched in a single Store call
@@ -663,9 +743,9 @@ func TestRetrievalVisitor_UnknownDriver(t *testing.T) {
663743
require.NoError(t, err)
664744
visitor := NewExternalRetrievalVisitor(params)
665745

666-
ref, err := storageReferenceToPayload(storageReference{
667-
DriverName: "unregistered-driver",
668-
DriverClaim: StorageDriverClaim{ClaimData: map[string]string{"key": "k"}},
746+
ref, err := storageReferenceToPayload(&sdkpb.ExternalStorageReference{
747+
DriverName: "unregistered-driver",
748+
ClaimData: map[string]string{"key": "k"},
669749
}, 10)
670750
require.NoError(t, err)
671751

@@ -702,9 +782,9 @@ func TestRetrievalVisitor_RetrievePanic(t *testing.T) {
702782
require.NoError(t, err)
703783
visitor := NewExternalRetrievalVisitor(params)
704784

705-
ref, err := storageReferenceToPayload(storageReference{
706-
DriverName: "my-panic-retrieve-driver",
707-
DriverClaim: StorageDriverClaim{ClaimData: map[string]string{"key": "k"}},
785+
ref, err := storageReferenceToPayload(&sdkpb.ExternalStorageReference{
786+
DriverName: "my-panic-retrieve-driver",
787+
ClaimData: map[string]string{"key": "k"},
708788
}, 10)
709789
require.NoError(t, err)
710790

@@ -732,9 +812,9 @@ func TestRetrievalVisitor_CancelOnError(t *testing.T) {
732812
require.NoError(t, err)
733813
errRef := refs[0]
734814

735-
blockRef, err := storageReferenceToPayload(storageReference{
736-
DriverName: "block-driver",
737-
DriverClaim: StorageDriverClaim{ClaimData: map[string]string{"key": "k"}},
815+
blockRef, err := storageReferenceToPayload(&sdkpb.ExternalStorageReference{
816+
DriverName: "block-driver",
817+
ClaimData: map[string]string{"key": "k"},
738818
}, 10)
739819
require.NoError(t, err)
740820

@@ -766,9 +846,9 @@ func TestRetrievalVisitor_WrongPayloadCount(t *testing.T) {
766846
require.NoError(t, err)
767847
visitor := NewExternalRetrievalVisitor(params)
768848

769-
ref, err := storageReferenceToPayload(storageReference{
770-
DriverName: "my-bad-count-driver",
771-
DriverClaim: StorageDriverClaim{ClaimData: map[string]string{"key": "k"}},
849+
ref, err := storageReferenceToPayload(&sdkpb.ExternalStorageReference{
850+
DriverName: "my-bad-count-driver",
851+
ClaimData: map[string]string{"key": "k"},
772852
}, 10)
773853
require.NoError(t, err)
774854

@@ -825,13 +905,52 @@ func TestRetrievalVisitor_Callback_ExternalCountOnly(t *testing.T) {
825905
}
826906

827907
// ---------------------------------------------------------------------------
828-
// Claim Compatibility: fixed claim JSON produced by another SDK
908+
// Claim Compatibility: legacy-format reference payload
909+
// ---------------------------------------------------------------------------
910+
911+
// TestRetrievalVisitor_LegacyFormat verifies that the retrieval visitor can
912+
// resolve a payload written in the legacy json/external-storage-reference
913+
// format (as written by earlier prerelease SDK versions).
914+
func TestRetrievalVisitor_LegacyFormat(t *testing.T) {
915+
driver := newTestDriver("d")
916+
params, err := ExternalStorageToParams(ExternalStorage{
917+
Drivers: []StorageDriver{driver},
918+
PayloadSizeThreshold: 1,
919+
})
920+
require.NoError(t, err)
921+
922+
// Store a payload to get a real claim key in the driver.
923+
original := makePayload(t, "legacy-compat-value")
924+
refs, err := visitPayloads(context.Background(), NewExternalStorageVisitor(params), []*commonpb.Payload{original})
925+
require.NoError(t, err)
926+
927+
// Extract the claim data from the new-format reference and rebuild it as a
928+
// legacy-format payload (encoding=json/external-storage-reference, old JSON structure).
929+
newRef, err := payloadToStorageReference(refs[0])
930+
require.NoError(t, err)
931+
legacyData, err := json.Marshal(legacyStorageReference{
932+
DriverName: newRef.DriverName,
933+
DriverClaim: StorageDriverClaim{ClaimData: newRef.ClaimData},
934+
})
935+
require.NoError(t, err)
936+
legacyPayload := &commonpb.Payload{
937+
Metadata: map[string][]byte{metadataEncoding: []byte(metadataEncodingStorageRefLegacy)},
938+
Data: legacyData,
939+
}
940+
941+
result, err := visitPayloads(context.Background(), NewExternalRetrievalVisitor(params), []*commonpb.Payload{legacyPayload})
942+
require.NoError(t, err)
943+
require.True(t, proto.Equal(original, result[0]))
944+
}
945+
946+
// ---------------------------------------------------------------------------
947+
// Claim Compatibility: legacy-format fixed claim JSON produced by another SDK
829948
// ---------------------------------------------------------------------------
830949

831-
// TestClaimDeserialization verifies that a full proto-JSON payload
832-
// produced by another language SDK (e.g. Python) is correctly parsed by
950+
// TestClaimDeserialization_PlainJson_OtherSdk verifies that a full plain JSON
951+
// payload produced by another language SDK (e.g. Python) is correctly parsed by
833952
// the Go SDK's payloadToStorageReference function.
834-
func TestClaimDeserialization(t *testing.T) {
953+
func TestClaimDeserialization_PlainJson_OtherSdk(t *testing.T) {
835954
// Full proto-JSON representation of a storage-reference payload as another
836955
// SDK would serialize it onto the wire.
837956
const rawPayloadJSON = `{
@@ -851,12 +970,50 @@ func TestClaimDeserialization(t *testing.T) {
851970

852971
ref, err := payloadToStorageReference(refPayload)
853972
require.NoError(t, err)
854-
require.Equal(t, metadataEncodingStorageRef, string(refPayload.GetMetadata()[metadataEncoding]))
973+
require.Equal(t, metadataEncodingStorageRefLegacy, string(refPayload.GetMetadata()[metadataEncoding]))
855974
require.Equal(t, 1, len(refPayload.ExternalPayloads))
856975
require.Equal(t, int64(385), refPayload.ExternalPayloads[0].SizeBytes)
857976
require.Equal(t, "temporalio:driver:s3", ref.DriverName)
858-
require.Equal(t, "test-bucket", ref.DriverClaim.ClaimData["bucket"])
859-
require.Equal(t, "/ns/default/wi/13f3d9cf-1705-4ce1-b3cb-370974a482c7/d/sha256/6ca22c34560cf35ac24427dc7619c9ab472a82cf18f286f27871649a2b5608c8", ref.DriverClaim.ClaimData["key"])
977+
require.Equal(t, "test-bucket", ref.ClaimData["bucket"])
978+
require.Equal(t, "/ns/default/wi/13f3d9cf-1705-4ce1-b3cb-370974a482c7/d/sha256/6ca22c34560cf35ac24427dc7619c9ab472a82cf18f286f27871649a2b5608c8", ref.ClaimData["key"])
979+
}
980+
981+
// ---------------------------------------------------------------------------
982+
// Claim Compatibility: current-format fixed claim JSON produced by another SDK
983+
// ---------------------------------------------------------------------------
984+
985+
// TestClaimDeserialization_OtherSdk_ProtoJson verifies that a full proto-JSON
986+
// payload produced by another language SDK (e.g. Python) is correctly parsed by
987+
// the Go SDK's payloadToStorageReference function.
988+
func TestClaimDeserialization_OtherSdk_ProtoJson(t *testing.T) {
989+
// Full proto-JSON representation of a storage-reference payload as another
990+
// SDK would serialize it onto the wire.
991+
const rawPayloadJSON = `{
992+
"metadata": {
993+
"encoding": "anNvbi9wcm90b2J1Zg==",
994+
"messageType": "dGVtcG9yYWwuYXBpLnNkay52MS5FeHRlcm5hbFN0b3JhZ2VSZWZlcmVuY2U="
995+
},
996+
"data": "eyJjbGFpbURhdGEiOnsiYnVja2V0IjoidGVzdC1idWNrZXQiLCJoYXNoX2FsZ29yaXRobSI6InNoYTI1NiIsImhhc2hfdmFsdWUiOiI2Y2EyMmMzNDU2MGNmMzVhYzI0NDI3ZGM3NjE5YzlhYjQ3MmE4MmNmMThmMjg2ZjI3ODcxNjQ5YTJiNTYwOGM4Iiwia2V5IjoidjAvbnMvZGVmYXVsdC93dC9MYXJnZUlPV29ya2Zsb3cvd2kvZjFkMmE0YWMtZjhjYi00NWQzLTkwOGMtOTNhMGYzM2FiMjQ1L3JpL251bGwvZC9zaGEyNTYvNmNhMjJjMzQ1NjBjZjM1YWMyNDQyN2RjNzYxOWM5YWI0NzJhODJjZjE4ZjI4NmYyNzg3MTY0OWEyYjU2MDhjOCJ9LCJkcml2ZXJOYW1lIjoiYXdzLnMzZHJpdmVyIn0=",
997+
"externalPayloads": [
998+
{
999+
"sizeBytes": "385"
1000+
}
1001+
]
1002+
}`
1003+
1004+
externalStorageReferenceMessageType = string((*sdkpb.ExternalStorageReference)(nil).ProtoReflect().Descriptor().FullName())
1005+
1006+
refPayload := &commonpb.Payload{}
1007+
require.NoError(t, protojson.Unmarshal([]byte(rawPayloadJSON), refPayload))
1008+
1009+
ref, err := payloadToStorageReference(refPayload)
1010+
require.NoError(t, err)
1011+
require.Equal(t, metadataEncodingProtoJSON, string(refPayload.GetMetadata()[metadataEncoding]))
1012+
require.Equal(t, externalStorageReferenceMessageType, string(refPayload.GetMetadata()[metadataMessageType]))
1013+
require.Equal(t, 1, len(refPayload.ExternalPayloads))
1014+
require.Equal(t, int64(385), refPayload.ExternalPayloads[0].SizeBytes)
1015+
require.Equal(t, "aws.s3driver", ref.DriverName)
1016+
require.Equal(t, "v0/ns/default/wt/LargeIOWorkflow/wi/f1d2a4ac-f8cb-45d3-908c-93a0f33ab245/ri/null/d/sha256/6ca22c34560cf35ac24427dc7619c9ab472a82cf18f286f27871649a2b5608c8", ref.ClaimData["key"])
8601017
}
8611018

8621019
// ---------------------------------------------------------------------------
@@ -876,7 +1033,7 @@ func TestStoreRetrieveRoundTrip_Single(t *testing.T) {
8761033
original := makePayload(t, "round-trip value")
8771034
refs, err := visitPayloads(context.Background(), storeVisitor, []*commonpb.Payload{original})
8781035
require.NoError(t, err)
879-
require.Equal(t, metadataEncodingStorageRef, string(refs[0].Metadata[metadataEncoding]))
1036+
require.Equal(t, metadataEncodingProtoJSON, string(refs[0].Metadata[metadataEncoding]))
8801037

8811038
restored, err := visitPayloads(context.Background(), retrieveVisitor, refs)
8821039
require.NoError(t, err)

0 commit comments

Comments
 (0)