Skip to content

Commit 50f0d3e

Browse files
authored
Merge branch 'main' into chore/update-uv-lock-deps
2 parents a92060e + f6e113d commit 50f0d3e

8 files changed

Lines changed: 159 additions & 9 deletions

File tree

temporalio/client.py

Lines changed: 53 additions & 0 deletions
Large diffs are not rendered by default.

temporalio/contrib/aws/s3driver/README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ from temporalio.contrib.aws.s3driver.aioboto3 import new_aioboto3_client
2323
from temporalio.converter import DataConverter, ExternalStorage
2424

2525
session = aioboto3.Session()
26-
# Credentials and region are resolved automatically from the standard AWS credential
27-
# chain e.g. environment variables, ~/.aws/config, IAM instance profile, and so on.
26+
# To see how to set credentials and region via environment, config objects, or configuration files,
27+
# see:
28+
# https://docs.aws.amazon.com/boto3/latest/guide/configuration.html
2829
async with session.client("s3") as s3_client:
2930
driver = S3StorageDriver(
3031
client=new_aioboto3_client(s3_client),

temporalio/contrib/aws/s3driver/_client.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from __future__ import annotations
88

99
from abc import ABC, abstractmethod
10+
from collections.abc import Mapping
1011

1112

1213
class S3StorageDriverClient(ABC):
@@ -30,3 +31,11 @@ async def object_exists(self, *, bucket: str, key: str) -> bool:
3031
@abstractmethod
3132
async def get_object(self, *, bucket: str, key: str) -> bytes:
3233
"""Download and return the bytes stored at the given S3 *bucket* and *key*."""
34+
35+
def describe(self) -> Mapping[str, str]:
36+
"""Return client-specific diagnostic metadata (e.g. region, credentials
37+
source) that the driver appends to error messages. Implementations may
38+
override this to surface configuration that is useful for debugging
39+
common misconfigurations. Returns an empty mapping by default.
40+
"""
41+
return {}

temporalio/contrib/aws/s3driver/_driver.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,20 @@
2626
_T = TypeVar("_T")
2727

2828

29+
def _format_client_context(client: S3StorageDriverClient) -> str:
30+
"""Format the client's ``describe()`` output as ", k=v, k=v" for error
31+
messages. Returns an empty string when the client reports no metadata or
32+
describe itself raises (diagnostic output must never mask the real error).
33+
"""
34+
try:
35+
info = client.describe()
36+
except Exception:
37+
return ""
38+
if not info:
39+
return ""
40+
return "".join(f", {k}={v}" for k, v in info.items())
41+
42+
2943
async def _gather_with_cancellation(
3044
coros: Sequence[Coroutine[Any, Any, _T]],
3145
) -> list[_T]:
@@ -156,7 +170,8 @@ async def _upload(payload: Payload) -> StorageDriverClaim:
156170
)
157171
except Exception as e:
158172
raise RuntimeError(
159-
f"S3StorageDriver store failed [bucket={bucket}, key={key}]"
173+
f"S3StorageDriver store failed [bucket={bucket}, key={key}"
174+
f"{_format_client_context(self._client)}]"
160175
) from e
161176

162177
return StorageDriverClaim(
@@ -185,7 +200,8 @@ async def _download(claim: StorageDriverClaim) -> Payload:
185200
payload_bytes = await self._client.get_object(bucket=bucket, key=key)
186201
except Exception as e:
187202
raise RuntimeError(
188-
f"S3StorageDriver retrieve failed [bucket={bucket}, key={key}]"
203+
f"S3StorageDriver retrieve failed [bucket={bucket}, key={key}"
204+
f"{_format_client_context(self._client)}]"
189205
) from e
190206

191207
hash_algorithm = claim.claim_data.get("hash_algorithm")

temporalio/contrib/aws/s3driver/aioboto3.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from __future__ import annotations
88

99
import io
10+
from collections.abc import Mapping
1011

1112
from botocore.exceptions import ClientError
1213
from types_aiobotocore_s3.client import S3Client
@@ -34,6 +35,13 @@ def __init__(self, client: S3Client) -> None:
3435
"""
3536
self._client = client
3637

38+
def describe(self) -> Mapping[str, str]:
39+
"""Region of the wrapped aioboto3 client, surfaced in driver error
40+
messages to short-circuit the most common silent 403 misconfiguration.
41+
"""
42+
region = self._client.meta.region_name
43+
return {"region": region} if region else {}
44+
3745
async def object_exists(self, *, bucket: str, key: str) -> bool:
3846
"""Check existence via aioboto3's ``head_object``."""
3947
try:

tests/contrib/aws/s3driver/test_s3driver.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
S3StorageDriver,
2727
S3StorageDriverClient,
2828
)
29+
from temporalio.contrib.aws.s3driver._driver import _format_client_context
30+
from temporalio.contrib.aws.s3driver.aioboto3 import _Aioboto3StorageDriverClient
2931
from temporalio.converter import (
3032
JSONPlainPayloadConverter,
3133
StorageDriverActivityInfo,
@@ -34,7 +36,7 @@
3436
StorageDriverStoreContext,
3537
StorageDriverWorkflowInfo,
3638
)
37-
from tests.contrib.aws.s3driver.conftest import BUCKET
39+
from tests.contrib.aws.s3driver.conftest import BUCKET, REGION
3840

3941
_CONVERTER = JSONPlainPayloadConverter()
4042

@@ -618,7 +620,7 @@ async def test_store_nonexistent_bucket_raises(
618620
await driver.store(make_store_context(), [payload])
619621
assert (
620622
str(exc_info.value)
621-
== f"S3StorageDriver store failed [bucket={bucket}, key={expected_key}]"
623+
== f"S3StorageDriver store failed [bucket={bucket}, key={expected_key}, region={REGION}]"
622624
)
623625
assert isinstance(exc_info.value.__cause__, ClientError)
624626
assert (
@@ -636,7 +638,7 @@ async def test_retrieve_nonexistent_key_raises(
636638
await driver.retrieve(StorageDriverRetrieveContext(), [claim])
637639
assert (
638640
str(exc_info.value)
639-
== f"S3StorageDriver retrieve failed [bucket={BUCKET}, key={key}]"
641+
== f"S3StorageDriver retrieve failed [bucket={BUCKET}, key={key}, region={REGION}]"
640642
)
641643
assert isinstance(exc_info.value.__cause__, ClientError)
642644
assert (
@@ -655,7 +657,7 @@ async def test_retrieve_nonexistent_bucket_raises(
655657
await driver.retrieve(StorageDriverRetrieveContext(), [claim])
656658
assert (
657659
str(exc_info.value)
658-
== f"S3StorageDriver retrieve failed [bucket={bucket}, key={key}]"
660+
== f"S3StorageDriver retrieve failed [bucket={bucket}, key={key}, region={REGION}]"
659661
)
660662
assert isinstance(exc_info.value.__cause__, ClientError)
661663
assert (
@@ -839,3 +841,49 @@ async def test_retrieve_cancels_remaining_on_failure(
839841
assert (
840842
len(faulty_client.cancelled) == 2
841843
), "Expected 2 remaining tasks to be cancelled"
844+
845+
846+
# ---------------------------------------------------------------------------
847+
# TestAioboto3StorageDriverClientDescribe
848+
# ---------------------------------------------------------------------------
849+
850+
851+
class TestAioboto3StorageDriverClientDescribe:
852+
def _make_client(self, region: str | None) -> _Aioboto3StorageDriverClient:
853+
mock_s3 = MagicMock()
854+
mock_s3.meta.region_name = region
855+
return _Aioboto3StorageDriverClient(mock_s3)
856+
857+
def test_returns_region(self) -> None:
858+
client = self._make_client(region="ap-southeast-1")
859+
assert client.describe() == {"region": "ap-southeast-1"}
860+
861+
def test_omits_region_when_none(self) -> None:
862+
client = self._make_client(region=None)
863+
assert client.describe() == {}
864+
865+
def test_omits_region_when_empty_string(self) -> None:
866+
client = self._make_client(region="")
867+
assert client.describe() == {}
868+
869+
870+
# ---------------------------------------------------------------------------
871+
# TestFormatClientContext
872+
# ---------------------------------------------------------------------------
873+
874+
875+
class TestFormatClientContext:
876+
def test_formats_entry(self) -> None:
877+
client = MagicMock(spec=S3StorageDriverClient)
878+
client.describe.return_value = {"region": "us-east-1"}
879+
assert _format_client_context(client) == ", region=us-east-1"
880+
881+
def test_returns_empty_string_for_empty_describe(self) -> None:
882+
client = MagicMock(spec=S3StorageDriverClient)
883+
client.describe.return_value = {}
884+
assert _format_client_context(client) == ""
885+
886+
def test_returns_empty_string_when_describe_raises(self) -> None:
887+
client = MagicMock(spec=S3StorageDriverClient)
888+
client.describe.side_effect = RuntimeError("oops")
889+
assert _format_client_context(client) == ""

tests/contrib/aws/s3driver/test_s3driver_worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,4 +506,4 @@ async def test_s3_store_failure_surfaces_in_workflow_history(
506506
msg = app_error.message
507507
assert f"S3StorageDriver store failed [bucket={bad_bucket}, key=" in msg
508508
assert f"/wt/LargeOutputNoRetryWorkflow/wi/{workflow_id}/ri/" in msg
509-
assert f"/d/sha256/{expected_hash}]" in msg
509+
assert f"/d/sha256/{expected_hash}, region={REGION}]" in msg

tests/test_activity.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,20 +220,23 @@ async def test_start_activity_calls_interceptor(
220220

221221
activity_id = str(uuid.uuid4())
222222
task_queue = str(uuid.uuid4())
223+
start_delay = timedelta(seconds=3)
223224

224225
await intercepted_client.start_activity(
225226
increment,
226227
args=(1,),
227228
id=activity_id,
228229
task_queue=task_queue,
229230
start_to_close_timeout=timedelta(seconds=5),
231+
start_delay=start_delay,
230232
)
231233

232234
assert len(interceptor.start_activity_calls) == 1
233235
call = interceptor.start_activity_calls[0]
234236
assert call.id == activity_id
235237
assert call.task_queue == task_queue
236238
assert call.activity_type == "increment"
239+
assert call.start_delay == start_delay
237240

238241

239242
async def test_describe_activity_calls_interceptor(
@@ -413,6 +416,18 @@ async def test_count_activities_calls_interceptor(
413416
assert call.query == query
414417

415418

419+
async def test_start_activity_rejects_negative_start_delay(client: Client):
420+
with pytest.raises(ValueError, match="start_delay must be non-negative"):
421+
await client.start_activity(
422+
increment,
423+
args=(1,),
424+
id=str(uuid.uuid4()),
425+
task_queue=str(uuid.uuid4()),
426+
start_to_close_timeout=timedelta(seconds=5),
427+
start_delay=timedelta(seconds=-1),
428+
)
429+
430+
416431
async def test_get_result(client: Client, env: WorkflowEnvironment):
417432
if env.supports_time_skipping:
418433
pytest.skip(

0 commit comments

Comments
 (0)