Skip to content

Commit bc43d80

Browse files
committed
Replace _StorageReference with ExternalStorageReference proto
1 parent 776bae2 commit bc43d80

4 files changed

Lines changed: 205 additions & 89 deletions

File tree

temporalio/converter/_data_converter.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from __future__ import annotations
44

55
import dataclasses
6-
import warnings
76
from collections.abc import Mapping, Sequence
87
from dataclasses import dataclass
98
from logging import getLogger
@@ -18,7 +17,6 @@
1817
_REFERENCE_ENCODING,
1918
ExternalStorage,
2019
StorageDriverStoreContext,
21-
StorageWarning,
2220
)
2321
from temporalio.converter._failure_converter import (
2422
FailureConverter,
@@ -41,6 +39,20 @@
4139
WithSerializationContext,
4240
)
4341

42+
43+
def _is_reference_payload(p: temporalio.api.common.v1.Payload) -> bool:
44+
"""Return True if *p* is an external-storage reference payload.
45+
46+
Covers both the legacy ``json/external-storage-reference`` encoding and the
47+
current proto-based format, which uses ``json/protobuf`` encoding with the
48+
``external_payloads`` repeated field set.
49+
"""
50+
return (
51+
p.metadata.get("encoding") == _REFERENCE_ENCODING
52+
or len(p.external_payloads) > 0
53+
)
54+
55+
4456
# Import defaults from public API to avoid pydoctor cross-reference issues
4557
if TYPE_CHECKING:
4658
from temporalio.converter import DefaultFailureConverter, DefaultPayloadConverter
@@ -307,13 +319,9 @@ async def _transform_inbound_payloads(
307319
if self.external_storage:
308320
await self.external_storage._retrieve_payloads(payloads)
309321
else:
310-
if any(
311-
p.metadata.get("encoding") == _REFERENCE_ENCODING
312-
for p in payloads.payloads
313-
):
314-
warnings.warn(
315-
"[TMPRL1105] Detected externally stored payload(s) but external storage is not configured.",
316-
StorageWarning,
322+
if any(_is_reference_payload(p) for p in payloads.payloads):
323+
raise RuntimeError(
324+
"[TMPRL1105] Detected externally stored payload(s) but external storage is not configured."
317325
)
318326
if self.payload_codec:
319327
await self.payload_codec.decode_wrapper(payloads)
@@ -348,13 +356,9 @@ async def _external_retrieve_payload_sequence(
348356
retrieved_payloads
349357
)
350358
else:
351-
if any(
352-
p.metadata.get("encoding") == _REFERENCE_ENCODING
353-
for p in retrieved_payloads
354-
):
355-
warnings.warn(
356-
"[TMPRL1105] Detected externally stored payload(s) but external storage is not configured.",
357-
StorageWarning,
359+
if any(_is_reference_payload(p) for p in retrieved_payloads):
360+
raise RuntimeError(
361+
"[TMPRL1105] Detected externally stored payload(s) but external storage is not configured."
358362
)
359363
return retrieved_payloads
360364

temporalio/converter/_extstore.py

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@
1818
from typing_extensions import Self
1919

2020
from temporalio.api.common.v1 import Payload, Payloads
21-
from temporalio.converter._payload_converter import JSONPlainPayloadConverter
21+
from temporalio.api.sdk.v1.external_storage_pb2 import ExternalStorageReference
22+
from temporalio.converter._payload_converter import (
23+
JSONPlainPayloadConverter,
24+
JSONProtoPayloadConverter,
25+
)
2226

2327
_T = TypeVar("_T")
2428

@@ -225,6 +229,11 @@ class StorageWarning(RuntimeWarning):
225229

226230
@dataclass(frozen=True)
227231
class _StorageReference:
232+
"""Legacy external storage reference used only on the retrieval path as a
233+
fallback for in-flight workflows that were written before the
234+
ExternalStorageReference proto was introduced.
235+
"""
236+
228237
driver_name: str
229238
driver_claim: StorageDriverClaim
230239

@@ -278,8 +287,9 @@ class ExternalStorage:
278287
)
279288
"""Store context bound to this instance via :meth:`_with_store_context`."""
280289

281-
_claim_converter: ClassVar[JSONPlainPayloadConverter] = JSONPlainPayloadConverter(
282-
encoding=_REFERENCE_ENCODING.decode()
290+
_claim_converter: ClassVar[JSONProtoPayloadConverter] = JSONProtoPayloadConverter()
291+
_legacy_claim_converter: ClassVar[JSONPlainPayloadConverter] = (
292+
JSONPlainPayloadConverter(encoding=_REFERENCE_ENCODING.decode())
283293
)
284294

285295
def __post_init__(self) -> None:
@@ -357,9 +367,9 @@ async def _store_payload(self, payload: Payload) -> Payload:
357367
self._validate_claim_length(claims, expected=1, driver=driver)
358368

359369
external_size = payload.ByteSize()
360-
reference = _StorageReference(
370+
reference = ExternalStorageReference(
361371
driver_name=driver.name(),
362-
driver_claim=claims[0],
372+
claim_data=claims[0].claim_data,
363373
)
364374
reference_payload = self._claim_converter.to_payload(reference)
365375
if reference_payload is None:
@@ -421,9 +431,9 @@ async def _store_payload_sequence(
421431
self._validate_claim_length(claims, expected=len(indices), driver=driver)
422432

423433
for i, claim in enumerate(claims):
424-
reference = _StorageReference(
434+
reference = ExternalStorageReference(
425435
driver_name=driver.name(),
426-
driver_claim=claim,
436+
claim_data=claim.claim_data,
427437
)
428438
reference_payload = self._claim_converter.to_payload(reference)
429439
if reference_payload is None:
@@ -443,20 +453,35 @@ async def _store_payload_sequence(
443453

444454
return results
445455

446-
async def _retrieve_payload(self, payload: Payload) -> Payload:
456+
def _decode_reference(self, payload: Payload) -> ExternalStorageReference | None:
457+
"""Decode an external storage reference from a payload."""
447458
if len(payload.external_payloads) == 0:
448-
return payload
449-
450-
start_time = time.monotonic()
459+
return None
460+
encoding = payload.metadata.get("encoding", b"")
461+
if encoding == _REFERENCE_ENCODING:
462+
legacy = self._legacy_claim_converter.from_payload(
463+
payload, _StorageReference
464+
)
465+
if not isinstance(legacy, _StorageReference):
466+
return None
467+
return ExternalStorageReference(
468+
driver_name=legacy.driver_name,
469+
claim_data=legacy.driver_claim.claim_data,
470+
)
471+
ref = self._claim_converter.from_payload(payload, ExternalStorageReference)
472+
return ref if isinstance(ref, ExternalStorageReference) else None
451473

452-
reference = self._claim_converter.from_payload(payload, _StorageReference)
453-
if not isinstance(reference, _StorageReference):
474+
async def _retrieve_payload(self, payload: Payload) -> Payload:
475+
ref = self._decode_reference(payload)
476+
if ref is None:
454477
return payload
455478

456-
driver = self._get_driver_by_name(reference.driver_name)
479+
start_time = time.monotonic()
480+
driver = self._get_driver_by_name(ref.driver_name)
457481
context = StorageDriverRetrieveContext()
482+
claim = StorageDriverClaim(claim_data=dict(ref.claim_data))
458483

459-
stored_payloads = await driver.retrieve(context, [reference.driver_claim])
484+
stored_payloads = await driver.retrieve(context, [claim])
460485

461486
self._validate_payload_length(stored_payloads, expected=1, driver=driver)
462487

@@ -486,15 +511,12 @@ async def _retrieve_payload_sequence(
486511

487512
driver_claims: dict[StorageDriver, list[tuple[int, StorageDriverClaim]]] = {}
488513
for index, payload in enumerate(payloads):
489-
if len(payload.external_payloads) == 0:
514+
ref = self._decode_reference(payload)
515+
if ref is None:
490516
continue
491-
492-
reference = self._claim_converter.from_payload(payload, _StorageReference)
493-
if not isinstance(reference, _StorageReference):
494-
continue
495-
496-
driver = self._get_driver_by_name(reference.driver_name)
497-
driver_claims.setdefault(driver, []).append((index, reference.driver_claim))
517+
driver = self._get_driver_by_name(ref.driver_name)
518+
claim = StorageDriverClaim(claim_data=dict(ref.claim_data))
519+
driver_claims.setdefault(driver, []).append((index, claim))
498520

499521
if not driver_claims:
500522
return results

tests/test_extstore.py

Lines changed: 124 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pytest
77

88
from temporalio.api.common.v1 import Payload
9+
from temporalio.api.sdk.v1.external_storage_pb2 import ExternalStorageReference
910
from temporalio.converter import (
1011
DataConverter,
1112
ExternalStorage,
@@ -16,9 +17,26 @@
1617
StorageDriverRetrieveContext,
1718
StorageDriverStoreContext,
1819
)
19-
from temporalio.converter._extstore import _StorageReference
20+
from temporalio.converter._extstore import _REFERENCE_ENCODING, _StorageReference
21+
from temporalio.converter._payload_converter import JSONProtoPayloadConverter
2022
from temporalio.exceptions import ApplicationError
2123

24+
_legacy_ref_converter = JSONPlainPayloadConverter(encoding=_REFERENCE_ENCODING.decode())
25+
26+
27+
def _make_legacy_payload(
28+
driver_name: str, claim_data: dict[str, str], size_bytes: int
29+
) -> Payload:
30+
"""Build a reference payload in the legacy ``json/external-storage-reference`` format."""
31+
ref = _StorageReference(
32+
driver_name=driver_name,
33+
driver_claim=StorageDriverClaim(claim_data=claim_data),
34+
)
35+
payload = _legacy_ref_converter.to_payload(ref)
36+
assert payload is not None
37+
payload.external_payloads.add().size_bytes = size_bytes
38+
return payload
39+
2240

2341
class InMemoryTestDriver(StorageDriver):
2442
"""In-memory storage driver for testing."""
@@ -115,33 +133,27 @@ async def test_extstore_encode_decode(self):
115133
assert driver._retrieve_calls == 1
116134

117135
async def test_extstore_reference_structure(self):
118-
"""Test that external storage creates proper reference structure."""
136+
"""Externalized payloads are written as ExternalStorageReference proto (json/protobuf encoding)."""
119137
converter = DataConverter(
120138
external_storage=ExternalStorage(
121139
drivers=[InMemoryTestDriver("test-driver")],
122140
payload_size_threshold=50,
123141
)
124142
)
125143

126-
# Create large payload
127144
large_value = "x" * 100
128145
encoded = await converter.encode([large_value])
129146

130-
# Verify reference structure
131147
reference_payload = encoded[0]
132148
assert len(reference_payload.external_payloads) > 0
149+
assert reference_payload.metadata.get("encoding") == b"json/protobuf"
133150

134-
# The payload should contain a serialized _ExternalStorageReference
135-
# Deserialize it to verify structure using the same encoding
136-
claim_converter = JSONPlainPayloadConverter(
137-
encoding="json/external-storage-reference"
151+
reference = JSONProtoPayloadConverter().from_payload(
152+
reference_payload, ExternalStorageReference
138153
)
139-
reference = claim_converter.from_payload(reference_payload, _StorageReference)
140-
141-
assert isinstance(reference, _StorageReference)
142-
assert "test-driver" == reference.driver_name
143-
assert isinstance(reference.driver_claim, StorageDriverClaim)
144-
assert "key" in reference.driver_claim.claim_data
154+
assert isinstance(reference, ExternalStorageReference)
155+
assert reference.driver_name == "test-driver"
156+
assert "key" in reference.claim_data
145157

146158
async def test_extstore_composite_conditional(self):
147159
"""Test using multiple drivers based on size."""
@@ -482,9 +494,10 @@ async def test_selector_always_first_driver_handles_all_stores(self):
482494
assert second._store_calls == 0
483495

484496
# The reference in history names the first driver.
485-
ref = JSONPlainPayloadConverter(
486-
encoding="json/external-storage-reference"
487-
).from_payload(encoded[0], _StorageReference)
497+
ref = JSONProtoPayloadConverter().from_payload(
498+
encoded[0], ExternalStorageReference
499+
)
500+
assert isinstance(ref, ExternalStorageReference)
488501
assert ref.driver_name == "driver-first"
489502

490503
# Retrieval also goes to the first driver.
@@ -694,5 +707,99 @@ def test_negative_payload_size_threshold_raises(self, threshold: int):
694707
)
695708

696709

710+
class TestBackwardCompat:
711+
"""Tests that the retrieval path handles the legacy ``json/external-storage-reference``
712+
format for in-flight workflows written before the ExternalStorageReference proto."""
713+
714+
async def test_legacy_format_single_payload_decode(self):
715+
"""A single payload in the legacy reference format is retrieved correctly."""
716+
driver = InMemoryTestDriver()
717+
718+
inner_payload = (await DataConverter().encode(["x" * 200]))[0]
719+
stored_key = "payload-0"
720+
driver._storage[stored_key] = inner_payload.SerializeToString()
721+
722+
legacy_payload = _make_legacy_payload(
723+
driver_name=driver.name(),
724+
claim_data={"key": stored_key},
725+
size_bytes=inner_payload.ByteSize(),
726+
)
727+
728+
converter = DataConverter(
729+
external_storage=ExternalStorage(
730+
drivers=[driver],
731+
payload_size_threshold=100,
732+
)
733+
)
734+
decoded = await converter.decode([legacy_payload], [str])
735+
assert decoded[0] == "x" * 200
736+
assert driver._retrieve_calls == 1
737+
738+
async def test_legacy_and_new_format_mixed_batch_decode(self):
739+
"""A batch containing legacy-format, new proto-format, and inline payloads
740+
all decode correctly in a single call."""
741+
driver = InMemoryTestDriver()
742+
converter = DataConverter(
743+
external_storage=ExternalStorage(
744+
drivers=[driver],
745+
payload_size_threshold=50,
746+
)
747+
)
748+
749+
new_value = "new-format-value" * 20
750+
inline_value = "small"
751+
encoded = await converter.encode([new_value, inline_value])
752+
new_format_payload = encoded[0]
753+
inline_payload = encoded[1]
754+
assert driver._store_calls == 1
755+
756+
legacy_value = "legacy-format-value" * 20
757+
legacy_inner = (await DataConverter().encode([legacy_value]))[0]
758+
stored_key = f"payload-{len(driver._storage)}"
759+
driver._storage[stored_key] = legacy_inner.SerializeToString()
760+
legacy_payload = _make_legacy_payload(
761+
driver_name=driver.name(),
762+
claim_data={"key": stored_key},
763+
size_bytes=legacy_inner.ByteSize(),
764+
)
765+
766+
decoded = await converter.decode(
767+
[legacy_payload, new_format_payload, inline_payload], [str, str, str]
768+
)
769+
assert decoded[0] == legacy_value
770+
assert decoded[1] == new_value
771+
assert decoded[2] == inline_value
772+
# Both external payloads share the same driver and are batched into one retrieve call.
773+
assert driver._retrieve_calls == 1
774+
775+
async def test_new_format_encode_round_trips(self):
776+
"""Payloads written with the new ExternalStorageReference format round-trip
777+
correctly and carry the expected proto encoding."""
778+
driver = InMemoryTestDriver()
779+
converter = DataConverter(
780+
external_storage=ExternalStorage(
781+
drivers=[driver],
782+
payload_size_threshold=50,
783+
)
784+
)
785+
786+
value = "round-trip-value" * 20
787+
encoded = await converter.encode([value])
788+
ref_payload = encoded[0]
789+
790+
assert ref_payload.metadata.get("encoding") == b"json/protobuf"
791+
assert len(ref_payload.external_payloads) > 0
792+
793+
ref = JSONProtoPayloadConverter().from_payload(
794+
ref_payload, ExternalStorageReference
795+
)
796+
assert isinstance(ref, ExternalStorageReference)
797+
assert ref.driver_name == driver.name()
798+
assert "key" in ref.claim_data
799+
800+
decoded = await converter.decode(encoded, [str])
801+
assert decoded[0] == value
802+
803+
697804
if __name__ == "__main__":
698805
pytest.main([__file__, "-v"])

0 commit comments

Comments
 (0)