Skip to content

Commit a8b07bf

Browse files
committed
Change to custom context manager
1 parent 8637904 commit a8b07bf

4 files changed

Lines changed: 23 additions & 24 deletions

File tree

tests/helpers/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from datetime import datetime, timedelta, timezone
1212
from typing import (
1313
Any,
14+
Iterator,
1415
TypeVar,
1516
cast,
1617
)
@@ -440,3 +441,16 @@ def find(
440441
if pred(record):
441442
return record
442443
return None
444+
445+
446+
class LogHandler:
447+
@staticmethod
448+
@contextmanager
449+
def apply(logger: logging.Logger, handler: logging.Handler) -> Iterator[None]:
450+
level = logger.level
451+
logger.addHandler(handler)
452+
try:
453+
yield
454+
finally:
455+
logger.removeHandler(handler)
456+
logger.level = level

tests/test_runtime.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
)
2323
from temporalio.worker import Worker
2424
from tests.helpers import (
25+
LogHandler,
2526
assert_eq_eventually,
2627
assert_eventually,
2728
find_free_port,
@@ -96,8 +97,7 @@ async def log_queue_len() -> int:
9697
)
9798
)
9899

99-
logger.addHandler(handler)
100-
try:
100+
with LogHandler.apply(logger, handler):
101101
# Set capture only info logs
102102
logger.setLevel(logging.INFO)
103103
# Write some logs
@@ -141,8 +141,6 @@ async def log_queue_len() -> int:
141141
assert log_queue_list[2].message.startswith(
142142
"[sdk_core::temporal_sdk_bridge::runtime] info6"
143143
)
144-
finally:
145-
logger.removeHandler(handler)
146144

147145

148146
@workflow.defn
@@ -172,8 +170,7 @@ async def test_runtime_task_fail_log_forwarding(client: Client):
172170
),
173171
)
174172

175-
logger.addHandler(handler)
176-
try:
173+
with LogHandler.apply(logger, handler):
177174
# Start workflow
178175
task_queue = f"task-queue-{uuid.uuid4()}"
179176
async with Worker(client, task_queue=task_queue, workflows=[TaskFailWorkflow]):
@@ -199,8 +196,6 @@ async def has_log() -> bool:
199196
== f"{logger.name}-sdk_core::temporalio_sdk_core::worker::workflow"
200197
)
201198
assert record.temporal_log.fields["run_id"] == handle.result_run_id # type: ignore
202-
finally:
203-
logger.removeHandler(handler)
204199

205200

206201
async def test_prometheus_histogram_bucket_overrides(client: Client):

tests/worker/test_activity.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
Worker,
4747
WorkerConfig,
4848
)
49+
from tests.helpers import LogHandler
4950
from tests.helpers.worker import (
5051
ExternalWorker,
5152
KSAction,
@@ -1019,20 +1020,15 @@ async def say_hello(name: str) -> str:
10191020

10201021
# Create a queue, add handler to logger, call normal activity, then check
10211022
handler = logging.handlers.QueueHandler(queue.Queue())
1022-
activity.logger.base_logger.addHandler(handler)
1023-
prev_level = activity.logger.base_logger.level
1024-
activity.logger.base_logger.setLevel(logging.INFO)
1025-
try:
1023+
with LogHandler.apply(activity.logger.base_logger, handler):
1024+
activity.logger.base_logger.setLevel(logging.INFO)
10261025
result = await _execute_workflow_with_activity(
10271026
client,
10281027
worker,
10291028
say_hello,
10301029
"Temporal",
10311030
shared_state_manager=shared_state_manager,
10321031
)
1033-
finally:
1034-
activity.logger.base_logger.removeHandler(handler)
1035-
activity.logger.base_logger.setLevel(prev_level)
10361032
assert result.result == "Hello, Temporal!"
10371033
records: list[logging.LogRecord] = list(handler.queue.queue) # type: ignore
10381034
assert len(records) > 0
@@ -1671,9 +1667,8 @@ async def raise_error():
16711667
raise RuntimeError("oh no!")
16721668

16731669
handler = CustomLogHandler()
1674-
activity.logger.base_logger.addHandler(handler)
16751670

1676-
try:
1671+
with LogHandler.apply(activity.logger.base_logger, handler):
16771672
with pytest.raises(WorkflowFailureError) as err:
16781673
await _execute_workflow_with_activity(
16791674
client,
@@ -1686,9 +1681,6 @@ async def raise_error():
16861681
)
16871682
assert handler._trace_identifiers == 1
16881683

1689-
finally:
1690-
activity.logger.base_logger.removeHandler(handler)
1691-
16921684

16931685
async def test_activity_heartbeat_context(
16941686
client: Client,

tests/worker/test_workflow.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@
122122
from tests import DEV_SERVER_DOWNLOAD_VERSION
123123
from tests.helpers import (
124124
LogCapturer,
125+
LogHandler,
125126
admitted_update_task,
126127
assert_eq_eventually,
127128
assert_eventually,
@@ -8453,8 +8454,7 @@ async def test_disable_logger_sandbox(
84538454
):
84548455
logger = workflow.logger.logger
84558456
handler = CustomLogHandler()
8456-
logger.addHandler(handler)
8457-
try:
8457+
with LogHandler.apply(logger, handler):
84588458
async with new_worker(
84598459
client,
84608460
DisableLoggerSandbox,
@@ -8485,5 +8485,3 @@ async def test_disable_logger_sandbox(
84858485
run_timeout=timedelta(seconds=1),
84868486
retry_policy=RetryPolicy(maximum_attempts=1),
84878487
)
8488-
finally:
8489-
logger.removeHandler(handler)

0 commit comments

Comments
 (0)