Skip to content

Commit ac842fd

Browse files
authored
Emit workflow task duration information via logging (#1386)
1 parent 8988e5b commit ac842fd

6 files changed

Lines changed: 441 additions & 21 deletions

File tree

temporalio/bridge/worker.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import temporalio.bridge.runtime
2121
import temporalio.bridge.temporal_sdk_bridge
2222
import temporalio.converter
23+
import temporalio.converter._extstore
2324
from temporalio.api.common.v1.message_pb2 import Payload
2425
from temporalio.bridge._visitor import VisitorFunctions
2526
from temporalio.bridge.temporal_sdk_bridge import (
async def decode_activation(
    activation: temporalio.bridge.proto.workflow_activation.WorkflowActivation,
    data_converter: temporalio.converter.DataConverter,
    decode_headers: bool,
) -> temporalio.converter._extstore.StorageOperationMetrics:
    """Decode all payloads in the activation in place.

    Args:
        activation: Workflow activation whose payloads are decoded in place.
        data_converter: Converter whose ``_decode_payload_sequence`` performs
            the actual payload decoding.
        decode_headers: If False, header payloads are left untouched.

    Returns:
        Metrics from any external storage retrieval operations that occurred
        while decoding (all-zero metrics when no external storage was used).
    """
    metrics = temporalio.converter._extstore.StorageOperationMetrics()
    # track() installs these metrics as the ambient context so that any
    # external-storage retrievals triggered during decoding accumulate here.
    with metrics.track():
        await CommandAwarePayloadVisitor(
            skip_search_attributes=True, skip_headers=not decode_headers
        ).visit(_Visitor(data_converter._decode_payload_sequence), activation)
    return metrics
310318

311319

312320
async def encode_completion(
    completion: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion,
    data_converter: temporalio.converter.DataConverter,
    encode_headers: bool,
) -> temporalio.converter._extstore.StorageOperationMetrics:
    """Encode all payloads in the completion in place.

    Args:
        completion: Workflow activation completion whose payloads are encoded
            in place.
        data_converter: Converter whose ``_encode_payload_sequence`` performs
            the actual payload encoding.
        encode_headers: If False, header payloads are left untouched.

    Returns:
        Metrics from any external storage store operations that occurred
        while encoding (all-zero metrics when no external storage was used).
    """
    metrics = temporalio.converter._extstore.StorageOperationMetrics()
    # track() installs these metrics as the ambient context so that any
    # external-storage stores triggered during encoding accumulate here.
    with metrics.track():
        await CommandAwarePayloadVisitor(
            skip_search_attributes=True, skip_headers=not encode_headers
        ).visit(_Visitor(data_converter._encode_payload_sequence), completion)
    return metrics

temporalio/converter/_extstore.py

Lines changed: 80 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,14 @@
55
from __future__ import annotations
66

77
import asyncio
8+
import contextlib
9+
import contextvars
810
import dataclasses
11+
import time
912
from abc import ABC, abstractmethod
10-
from collections.abc import Callable, Coroutine, Mapping, Sequence
13+
from collections.abc import Callable, Coroutine, Generator, Mapping, Sequence
1114
from dataclasses import dataclass
15+
from datetime import timedelta
1216
from typing import Any, ClassVar, TypeVar
1317

1418
from typing_extensions import Self
@@ -25,6 +29,40 @@
2529
_REFERENCE_ENCODING = b"json/external-storage-reference"
2630

2731

32+
@dataclass
33+
class StorageOperationMetrics:
34+
"""Accumulates metrics from external storage operations."""
35+
36+
payload_count: int = 0
37+
"""Number of payloads stored or retrieved externally."""
38+
39+
total_size: int = 0
40+
"""Total size in bytes of externally stored/retrieved payloads."""
41+
42+
total_duration: timedelta = dataclasses.field(default_factory=timedelta)
43+
"""Wall-clock time spent on external storage operations."""
44+
45+
def record_batch(self, count: int, size: int, duration: timedelta) -> None:
46+
"""Record metrics from a batch of storage operations."""
47+
self.payload_count += count
48+
self.total_size += size
49+
self.total_duration += duration
50+
51+
@contextlib.contextmanager
52+
def track(self) -> Generator[Self, None, None]:
53+
"""Set this instance as the current metrics context and reset on exit."""
54+
token = _current_storage_metrics.set(self)
55+
try:
56+
yield self
57+
finally:
58+
_current_storage_metrics.reset(token)
59+
60+
61+
_current_storage_metrics: contextvars.ContextVar[StorageOperationMetrics | None] = (
62+
contextvars.ContextVar("_current_storage_metrics", default=None)
63+
)
64+
65+
2866
async def _gather_cancel_on_error(
2967
coros: Sequence[Coroutine[Any, Any, _T]],
3068
) -> list[_T]:
@@ -255,6 +293,7 @@ def _get_driver_by_name(self, name: str) -> StorageDriver:
255293
return driver
256294

257295
async def _store_payload(self, payload: Payload) -> Payload:
296+
start_time = time.monotonic()
258297
context = StorageDriverStoreContext(serialization_context=self._context)
259298

260299
driver = self._select_driver(context, payload)
@@ -265,6 +304,7 @@ async def _store_payload(self, payload: Payload) -> Payload:
265304

266305
self._validate_claim_length(claims, expected=1, driver=driver)
267306

307+
external_size = payload.ByteSize()
268308
reference = _StorageReference(
269309
driver_name=driver.name(),
270310
driver_claim=claims[0],
@@ -274,7 +314,10 @@ async def _store_payload(self, payload: Payload) -> Payload:
274314
raise ValueError(
275315
f"Failed to serialize storage reference for driver '{driver.name()}'"
276316
)
277-
reference_payload.external_payloads.add().size_bytes = payload.ByteSize()
317+
reference_payload.external_payloads.add().size_bytes = external_size
318+
319+
ExternalStorage._record_metrics(1, external_size, start_time)
320+
278321
return reference_payload
279322

280323
async def _store_payloads(self, payloads: Payloads):
@@ -289,6 +332,8 @@ async def _store_payload_sequence(
289332
if len(payloads) == 1:
290333
return [await self._store_payload(payloads[0])]
291334

335+
start_time = time.monotonic()
336+
292337
results = list(payloads)
293338
context = StorageDriverStoreContext(serialization_context=self._context)
294339

@@ -315,6 +360,8 @@ async def _store_payload_sequence(
315360
]
316361
)
317362

363+
external_count = 0
364+
external_size = 0
318365
for (driver, indexed_payloads), claims in zip(driver_group_list, all_claims):
319366
indices = [idx for idx, _ in indexed_payloads]
320367
sizes = [p.ByteSize() for _, p in indexed_payloads]
@@ -333,13 +380,20 @@ async def _store_payload_sequence(
333380
)
334381
reference_payload.external_payloads.add().size_bytes = sizes[i]
335382
results[indices[i]] = reference_payload
383+
external_size += sizes[i]
384+
385+
external_count += len(claims)
386+
387+
ExternalStorage._record_metrics(external_count, external_size, start_time)
336388

337389
return results
338390

339391
async def _retrieve_payload(self, payload: Payload) -> Payload:
340392
if len(payload.external_payloads) == 0:
341393
return payload
342394

395+
start_time = time.monotonic()
396+
343397
reference = self._claim_converter.from_payload(payload, _StorageReference)
344398
if not isinstance(reference, _StorageReference):
345399
return payload
@@ -351,7 +405,11 @@ async def _retrieve_payload(self, payload: Payload) -> Payload:
351405

352406
self._validate_payload_length(stored_payloads, expected=1, driver=driver)
353407

354-
return stored_payloads[0]
408+
stored_payload = stored_payloads[0]
409+
410+
ExternalStorage._record_metrics(1, stored_payload.ByteSize(), start_time)
411+
412+
return stored_payload
355413

356414
async def _retrieve_payloads(self, payloads: Payloads):
357415
stored_payloads = await self._retrieve_payload_sequence(payloads.payloads)
@@ -362,11 +420,13 @@ async def _retrieve_payload_sequence(
362420
self,
363421
payloads: Sequence[Payload],
364422
) -> list[Payload]:
365-
results = list(payloads)
366-
367423
if len(payloads) == 1:
368424
return [await self._retrieve_payload(payloads[0])]
369425

426+
start_time = time.monotonic()
427+
428+
results = list(payloads)
429+
370430
driver_claims: dict[StorageDriver, list[tuple[int, StorageDriverClaim]]] = {}
371431
for index, payload in enumerate(payloads):
372432
if len(payload.external_payloads) == 0:
@@ -394,6 +454,8 @@ async def _retrieve_payload_sequence(
394454
]
395455
)
396456

457+
external_count = 0
458+
external_size = 0
397459
for (driver, indexed_claims), stored_payloads in zip(
398460
driver_claim_list, all_stored
399461
):
@@ -407,13 +469,18 @@ async def _retrieve_payload_sequence(
407469

408470
for idx, stored_payload in zip(indices, stored_payloads):
409471
stored_by_index[idx] = stored_payload
472+
external_size += stored_payload.ByteSize()
473+
474+
external_count += len(stored_payloads)
410475

411476
retrieve_indices = sorted(stored_by_index.keys())
412477
stored_list = [stored_by_index[idx] for idx in retrieve_indices]
413478

414479
for i, retrieved_payload in enumerate(stored_list):
415480
results[retrieve_indices[i]] = retrieved_payload
416481

482+
ExternalStorage._record_metrics(external_count, external_size, start_time)
483+
417484
return results
418485

419486
def _validate_claim_length(
@@ -431,3 +498,11 @@ def _validate_payload_length(
431498
raise ValueError(
432499
f"Driver '{driver.name()}' returned {len(payloads)} payloads, expected {expected}",
433500
)
501+
502+
@staticmethod
def _record_metrics(count: int, size: int, start_time: float):
    """Fold a completed storage batch into the ambient metrics, if any.

    No-op when no ``StorageOperationMetrics.track()`` context is active.
    ``start_time`` is a ``time.monotonic()`` timestamp taken when the batch
    began; elapsed time until now is recorded as the batch duration.
    """
    active = _current_storage_metrics.get()
    if active is None:
        return
    elapsed = timedelta(seconds=time.monotonic() - start_time)
    active.record_batch(count, size, elapsed)

temporalio/worker/_workflow.py

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99
import os
1010
import sys
1111
import threading
12+
import time
1213
from collections.abc import Awaitable, Callable, MutableMapping, Sequence
1314
from dataclasses import dataclass
14-
from datetime import timezone
15+
from datetime import timedelta, timezone
1516
from types import TracebackType
1617

1718
import temporalio.api.common.v1
@@ -21,6 +22,7 @@
2122
import temporalio.bridge.worker
2223
import temporalio.common
2324
import temporalio.converter
25+
import temporalio.converter._extstore
2426
import temporalio.converter._payload_limits
2527
import temporalio.exceptions
2628
import temporalio.workflow
@@ -255,6 +257,8 @@ async def _handle_activation(
255257
completion.successful.SetInParent()
256258
workflow = None
257259
data_converter = self._data_converter
260+
task_start_time = time.monotonic()
261+
download_metrics = temporalio.converter._extstore.StorageOperationMetrics()
258262
try:
259263
if LOG_PROTOS:
260264
logger.debug("Received workflow activation:\n%s", act)
@@ -291,7 +295,7 @@ async def _handle_activation(
291295
workflow_context=workflow_context,
292296
),
293297
)
294-
await temporalio.bridge.worker.decode_activation(
298+
download_metrics = await temporalio.bridge.worker.decode_activation(
295299
act,
296300
data_converter,
297301
decode_headers=self._encode_headers,
@@ -399,9 +403,10 @@ async def _handle_activation(
399403
),
400404
)
401405

406+
upload_metrics = temporalio.converter._extstore.StorageOperationMetrics()
402407
try:
403408
try:
404-
await temporalio.bridge.worker.encode_completion(
409+
upload_metrics = await temporalio.bridge.worker.encode_completion(
405410
completion,
406411
data_converter,
407412
encode_headers=self._encode_headers,
@@ -429,6 +434,71 @@ async def _handle_activation(
429434
"Failed completing activation on workflow with run ID %s", act.run_id
430435
)
431436

437+
# Log workflow task duration with external storage metrics
438+
self._log_workflow_task_duration(
439+
act, task_start_time, download_metrics, upload_metrics
440+
)
441+
442+
@staticmethod
def _log_workflow_task_duration(
    act: temporalio.bridge.proto.workflow_activation.WorkflowActivation,
    task_start_time: float,
    download_metrics: temporalio.converter._extstore.StorageOperationMetrics,
    upload_metrics: temporalio.converter._extstore.StorageOperationMetrics,
) -> None:
    """Log workflow task duration with optional external-storage metrics.

    Args:
        act: The workflow activation that was just handled; its
            ``history_length`` is logged as the event ID.
        task_start_time: ``time.monotonic()`` timestamp taken when the task
            began.
        download_metrics: Metrics from external-payload retrieval while
            decoding the activation.
        upload_metrics: Metrics from external-payload storage while encoding
            the completion.

    The log level escalates with the task duration: debug up to 5s, info
    above 5s, warning above 10s.
    """
    task_duration = timedelta(seconds=time.monotonic() - task_start_time)

    def _fmt_duration(td: timedelta) -> str:
        # Human-readable: seconds with ms precision, or milliseconds if <1s.
        secs = td.total_seconds()
        if secs >= 1:
            return f"{secs:.3f}s"
        return f"{secs * 1000:.3f}ms"

    # msg_details holds display-formatted values for the message text;
    # extra holds raw values attached to the LogRecord for structured sinks.
    msg_details: dict[str, object] = {
        "event_id": act.history_length,
        "workflow_task_duration": _fmt_duration(task_duration),
    }
    extra: dict[str, object] = {
        "event_id": act.history_length,
        "workflow_task_duration": task_duration,
    }

    def _add_storage_metrics(
        direction: str,
        metrics: temporalio.converter._extstore.StorageOperationMetrics,
    ) -> None:
        # Populate both dicts from one source so the pair can't drift apart.
        if metrics.payload_count <= 0:
            return
        msg_details[f"payload_{direction}_count"] = metrics.payload_count
        msg_details[f"payload_{direction}_size"] = metrics.total_size
        msg_details[f"payload_{direction}_duration"] = _fmt_duration(
            metrics.total_duration
        )
        extra[f"payload_{direction}_count"] = metrics.payload_count
        extra[f"payload_{direction}_size"] = metrics.total_size
        extra[f"payload_{direction}_duration"] = metrics.total_duration

    _add_storage_metrics("download", download_metrics)
    _add_storage_metrics("upload", upload_metrics)

    if task_duration.total_seconds() > 10:
        logger.warning(
            "[TMPRL1104] Workflow task exceeded 10 seconds (%s)",
            msg_details,
            extra=extra,
        )
    elif task_duration.total_seconds() > 5:
        logger.info(
            "[TMPRL1104] Workflow task exceeded 5 seconds (%s)",
            msg_details,
            extra=extra,
        )
    else:
        logger.debug(
            "[TMPRL1104] Workflow task duration information (%s)",
            msg_details,
            extra=extra,
        )
501+
432502
async def _handle_cache_eviction(
433503
self,
434504
act: temporalio.bridge.proto.workflow_activation.WorkflowActivation,

tests/helpers/__init__.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -422,12 +422,12 @@ def __init__(self) -> None:
422422
self.log_queue: queue.Queue[logging.LogRecord] = queue.Queue()
423423

424424
@contextmanager
425-
def logs_captured(self, *loggers: logging.Logger):
425+
def logs_captured(self, *loggers: logging.Logger, level: int = logging.INFO):
426426
handler = logging.handlers.QueueHandler(self.log_queue)
427427

428428
prev_levels = [l.level for l in loggers]
429429
for l in loggers:
430-
l.setLevel(logging.INFO)
430+
l.setLevel(level)
431431
l.addHandler(handler)
432432
try:
433433
yield self
@@ -447,6 +447,15 @@ def find(
447447
return record
448448
return None
449449

450+
def find_all(
    self, pred: Callable[[logging.LogRecord], bool]
) -> list[logging.LogRecord]:
    """Return every captured log record matching *pred*, oldest first."""
    captured = cast(list[logging.LogRecord], self.log_queue.queue)
    return [record for record in captured if pred(record)]
458+
450459

451460
class LogHandler:
452461
@staticmethod

0 commit comments

Comments (0)