Skip to content

Commit 8c6fab6

Browse files
committed
Merge remote-tracking branch 'origin/main' into fix-swallowed-cancel
2 parents 2cd3c5c + 58f2a68 commit 8c6fab6

19 files changed

Lines changed: 555 additions & 89 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
__pycache__
33
/build
44
/dist
5+
temporalio/bridge/libtemporal_sdk_bridge.dylib.dSYM/
56
temporalio/bridge/target/
67
temporalio/bridge/temporal_sdk_bridge*
78
/tests/helpers/golangserver/golangserver

README.md

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2059,28 +2059,33 @@ The environment is now ready to develop in.
20592059

20602060
#### Testing
20612061

2062-
To execute tests:
2062+
To execute tests (in parallel if possible):
20632063

20642064
```bash
20652065
poe test
20662066
```
20672067

2068-
`poe test` spreads tests across multiple worker processes by default. If you
2069-
need a serial run for debugging, invoke pytest directly:
2068+
To execute tests serially:
20702069

20712070
```bash
20722071
uv run pytest
20732072
```
20742073

2075-
This runs against [Temporalite](https://github.com/temporalio/temporalite). To run against the time-skipping test
2076-
server, pass `--workflow-environment time-skipping`. To run against the `default` namespace of an already-running
2077-
server, pass the `host:port` to `--workflow-environment`. Can also use regular pytest arguments. For example, here's how
2078-
to run a single test with debug logs on the console:
2074+
To execute a single test:
20792075

20802076
```bash
20812077
poe test -s --log-cli-level=DEBUG -k test_sync_activity_thread_cancel_caught
20822078
```
20832079

2080+
**Temporal Server**
2081+
2082+
- Tests that use the workflow test environment run against the [Temporal CLI dev server](https://docs.temporal.io/cli#start-dev-server).
2083+
- By default, workflow-environment tests automatically start a local dev server.
2084+
- On first run, the dev server binary may be downloaded, so network access is required if no server is currently running.
2085+
- To run workflow-environment tests against the time-skipping test server, pass `--workflow-environment time-skipping`.
2086+
- To run workflow-environment tests against the `default` namespace of an already-running server, pass the `host:port` to `--workflow-environment`.
2087+
- Unit tests that do not use the workflow environment do not start a dev server.
2088+
20842089
#### Proto Generation and Testing
20852090

20862091
If you have docker available, run

temporalio/contrib/google_adk_agents/_model.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import AsyncGenerator
1+
from collections.abc import AsyncGenerator, Callable
22
from datetime import timedelta
33

44
from google.adk.models import BaseLlm, LLMRegistry
@@ -40,20 +40,37 @@ class TemporalModel(BaseLlm):
4040
"""A Temporal-based LLM model that executes model invocations as activities."""
4141

4242
def __init__(
43-
self, model_name: str, activity_config: ActivityConfig | None = None
43+
self,
44+
model_name: str,
45+
activity_config: ActivityConfig | None = None,
46+
*,
47+
summary_fn: Callable[[LlmRequest], str | None] | None = None,
4448
) -> None:
4549
"""Initialize the TemporalModel.
4650
4751
Args:
4852
model_name: The name of the model to use.
4953
activity_config: Configuration options for the activity execution.
54+
summary_fn: Optional callable that receives the LlmRequest and
55+
returns a summary string (or None) for the activity. Must be
56+
deterministic as it is called during workflow execution. If
57+
the callable raises, the exception will propagate and fail
58+
the workflow task.
59+
60+
Raises:
61+
ValueError: If both ``ActivityConfig["summary"]`` and ``summary_fn`` are set.
5062
"""
5163
super().__init__(model=model_name)
5264
self._model_name = model_name
65+
self._summary_fn = summary_fn
5366
self._activity_config = ActivityConfig(
5467
start_to_close_timeout=timedelta(seconds=60)
5568
)
56-
if activity_config:
69+
if activity_config is not None:
70+
if summary_fn is not None and activity_config.get("summary") is not None:
71+
raise ValueError(
72+
"Cannot specify both ActivityConfig 'summary' and 'summary_fn'"
73+
)
5774
self._activity_config.update(activity_config)
5875

5976
async def generate_content_async(
@@ -76,10 +93,20 @@ async def generate_content_async(
7693
yield response
7794
return
7895

96+
config = self._activity_config.copy()
97+
if self._summary_fn is not None:
98+
summary = self._summary_fn(llm_request)
99+
if summary is not None:
100+
config["summary"] = summary
101+
elif "summary" not in config:
102+
if llm_request.config and llm_request.config.labels:
103+
agent_name = llm_request.config.labels.get("adk_agent_name")
104+
if agent_name:
105+
config["summary"] = agent_name
79106
responses = await workflow.execute_activity(
80107
invoke_model,
81108
args=[llm_request],
82-
**self._activity_config,
109+
**config,
83110
)
84111
for response in responses:
85112
yield response

temporalio/contrib/opentelemetry/_interceptor.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,60 @@ async def start_update_with_start_workflow(
341341

342342
return await super().start_update_with_start_workflow(input)
343343

344+
async def start_activity(
345+
self, input: temporalio.client.StartActivityInput
346+
) -> temporalio.client.ActivityHandle[Any]:
347+
with self.root._start_as_current_span(
348+
f"StartActivity:{input.activity_type}",
349+
attributes={
350+
"temporalActivityID": input.id,
351+
"temporalActivityType": input.activity_type,
352+
},
353+
input_with_headers=input,
354+
kind=opentelemetry.trace.SpanKind.CLIENT,
355+
):
356+
return await super().start_activity(input)
357+
358+
async def cancel_activity(
359+
self, input: temporalio.client.CancelActivityInput
360+
) -> None:
361+
with self.root._start_as_current_span(
362+
"CancelActivity",
363+
attributes={"temporalActivityID": input.activity_id},
364+
kind=opentelemetry.trace.SpanKind.CLIENT,
365+
):
366+
return await super().cancel_activity(input)
367+
368+
async def terminate_activity(
369+
self, input: temporalio.client.TerminateActivityInput
370+
) -> None:
371+
with self.root._start_as_current_span(
372+
"TerminateActivity",
373+
attributes={"temporalActivityID": input.activity_id},
374+
kind=opentelemetry.trace.SpanKind.CLIENT,
375+
):
376+
return await super().terminate_activity(input)
377+
378+
async def describe_activity(
379+
self, input: temporalio.client.DescribeActivityInput
380+
) -> temporalio.client.ActivityExecutionDescription:
381+
with self.root._start_as_current_span(
382+
"DescribeActivity",
383+
attributes={"temporalActivityID": input.activity_id},
384+
kind=opentelemetry.trace.SpanKind.CLIENT,
385+
):
386+
return await super().describe_activity(input)
387+
388+
async def count_activities(
389+
self, input: temporalio.client.CountActivitiesInput
390+
) -> temporalio.client.ActivityExecutionCount:
391+
with self.root._start_as_current_span(
392+
"CountActivities",
393+
attributes={},
394+
kind=opentelemetry.trace.SpanKind.CLIENT,
395+
):
396+
return await super().count_activities(input)
397+
344398

345399
class _TracingActivityInboundInterceptor(temporalio.worker.ActivityInboundInterceptor):
346400
def __init__(

temporalio/contrib/opentelemetry/_otel_interceptor.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,70 @@ async def start_update_with_start_workflow(
303303
)
304304
return await super().start_update_with_start_workflow(input)
305305

306+
async def start_activity(
307+
self, input: temporalio.client.StartActivityInput
308+
) -> temporalio.client.ActivityHandle[Any]:
309+
with _maybe_span(
310+
get_tracer(__name__),
311+
f"StartActivity:{input.activity_type}",
312+
add_temporal_spans=self._add_temporal_spans,
313+
attributes={
314+
"temporalActivityID": input.id,
315+
"temporalActivityType": input.activity_type,
316+
},
317+
kind=opentelemetry.trace.SpanKind.CLIENT,
318+
):
319+
input.headers = _context_to_headers(input.headers)
320+
return await super().start_activity(input)
321+
322+
async def cancel_activity(
323+
self, input: temporalio.client.CancelActivityInput
324+
) -> None:
325+
with _maybe_span(
326+
get_tracer(__name__),
327+
"CancelActivity",
328+
add_temporal_spans=self._add_temporal_spans,
329+
attributes={"temporalActivityID": input.activity_id},
330+
kind=opentelemetry.trace.SpanKind.CLIENT,
331+
):
332+
return await super().cancel_activity(input)
333+
334+
async def terminate_activity(
335+
self, input: temporalio.client.TerminateActivityInput
336+
) -> None:
337+
with _maybe_span(
338+
get_tracer(__name__),
339+
"TerminateActivity",
340+
add_temporal_spans=self._add_temporal_spans,
341+
attributes={"temporalActivityID": input.activity_id},
342+
kind=opentelemetry.trace.SpanKind.CLIENT,
343+
):
344+
return await super().terminate_activity(input)
345+
346+
async def describe_activity(
347+
self, input: temporalio.client.DescribeActivityInput
348+
) -> temporalio.client.ActivityExecutionDescription:
349+
with _maybe_span(
350+
get_tracer(__name__),
351+
"DescribeActivity",
352+
add_temporal_spans=self._add_temporal_spans,
353+
attributes={"temporalActivityID": input.activity_id},
354+
kind=opentelemetry.trace.SpanKind.CLIENT,
355+
):
356+
return await super().describe_activity(input)
357+
358+
async def count_activities(
359+
self, input: temporalio.client.CountActivitiesInput
360+
) -> temporalio.client.ActivityExecutionCount:
361+
with _maybe_span(
362+
get_tracer(__name__),
363+
"CountActivities",
364+
add_temporal_spans=self._add_temporal_spans,
365+
attributes={},
366+
kind=opentelemetry.trace.SpanKind.CLIENT,
367+
):
368+
return await super().count_activities(input)
369+
306370

307371
class _TracingActivityInboundInterceptor(temporalio.worker.ActivityInboundInterceptor):
308372
def __init__(

temporalio/converter/_extstore.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,17 @@ class StorageOperationMetrics:
3838
total_duration: timedelta = dataclasses.field(default_factory=timedelta)
3939
"""Wall-clock time spent on external storage operations."""
4040

41-
def record_batch(self, count: int, size: int, duration: timedelta) -> None:
41+
driver_names: set[str] = dataclasses.field(default_factory=set)
42+
"""Names of the drivers that participated in the operations."""
43+
44+
def record_batch(
45+
self, count: int, size: int, duration: timedelta, driver_names: set[str]
46+
) -> None:
4247
"""Record metrics from a batch of storage operations."""
4348
self.payload_count += count
4449
self.total_size += size
4550
self.total_duration += duration
51+
self.driver_names.update(driver_names)
4652

4753
@contextlib.contextmanager
4854
def track(self) -> Generator[Self, None, None]:
@@ -362,7 +368,7 @@ async def _store_payload(self, payload: Payload) -> Payload:
362368
)
363369
reference_payload.external_payloads.add().size_bytes = external_size
364370

365-
ExternalStorage._record_metrics(1, external_size, start_time)
371+
ExternalStorage._record_metrics(1, external_size, start_time, {driver.name()})
366372

367373
return reference_payload
368374

@@ -407,6 +413,7 @@ async def _store_payload_sequence(
407413

408414
external_count = 0
409415
external_size = 0
416+
driver_names: set[str] = set()
410417
for (driver, indexed_payloads), claims in zip(driver_group_list, all_claims):
411418
indices = [idx for idx, _ in indexed_payloads]
412419
sizes = [p.ByteSize() for _, p in indexed_payloads]
@@ -428,8 +435,11 @@ async def _store_payload_sequence(
428435
external_size += sizes[i]
429436

430437
external_count += len(claims)
438+
driver_names.add(driver.name())
431439

432-
ExternalStorage._record_metrics(external_count, external_size, start_time)
440+
ExternalStorage._record_metrics(
441+
external_count, external_size, start_time, driver_names
442+
)
433443

434444
return results
435445

@@ -452,7 +462,9 @@ async def _retrieve_payload(self, payload: Payload) -> Payload:
452462

453463
stored_payload = stored_payloads[0]
454464

455-
ExternalStorage._record_metrics(1, stored_payload.ByteSize(), start_time)
465+
ExternalStorage._record_metrics(
466+
1, stored_payload.ByteSize(), start_time, {driver.name()}
467+
)
456468

457469
return stored_payload
458470

@@ -501,6 +513,7 @@ async def _retrieve_payload_sequence(
501513

502514
external_count = 0
503515
external_size = 0
516+
driver_names: set[str] = set()
504517
for (driver, indexed_claims), stored_payloads in zip(
505518
driver_claim_list, all_stored
506519
):
@@ -517,14 +530,17 @@ async def _retrieve_payload_sequence(
517530
external_size += stored_payload.ByteSize()
518531

519532
external_count += len(stored_payloads)
533+
driver_names.add(driver.name())
520534

521535
retrieve_indices = sorted(stored_by_index.keys())
522536
stored_list = [stored_by_index[idx] for idx in retrieve_indices]
523537

524538
for i, retrieved_payload in enumerate(stored_list):
525539
results[retrieve_indices[i]] = retrieved_payload
526540

527-
ExternalStorage._record_metrics(external_count, external_size, start_time)
541+
ExternalStorage._record_metrics(
542+
external_count, external_size, start_time, driver_names
543+
)
528544

529545
return results
530546

@@ -545,9 +561,14 @@ def _validate_payload_length(
545561
)
546562

547563
@staticmethod
548-
def _record_metrics(count: int, size: int, start_time: float):
564+
def _record_metrics(
565+
count: int, size: int, start_time: float, driver_names: set[str]
566+
):
549567
metrics = _current_storage_metrics.get()
550568
if metrics is not None:
551569
metrics.record_batch(
552-
count, size, timedelta(seconds=time.monotonic() - start_time)
570+
count,
571+
size,
572+
timedelta(seconds=time.monotonic() - start_time),
573+
driver_names,
553574
)

0 commit comments

Comments
 (0)