Skip to content

Activity worker refactor #860

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
353 changes: 176 additions & 177 deletions temporalio/worker/_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ async def drain_poll_queue(self) -> None:

# Only call this after run()/drain_poll_queue() have returned. This will not
# raise an exception.
# TODO(dan): check accuracy of this comment; I would say it *does* raise an exception.
async def wait_all_completed(self) -> None:
running_tasks = [v.task for v in self._running_activities.values() if v.task]
if running_tasks:
Expand Down Expand Up @@ -281,183 +282,7 @@ async def _run_activity(
task_token=task_token
)
try:
# Find activity or fail
activity_def = self._activities.get(
start.activity_type, self._dynamic_activity
)
if not activity_def:
activity_names = ", ".join(sorted(self._activities.keys()))
raise temporalio.exceptions.ApplicationError(
f"Activity function {start.activity_type} for workflow {start.workflow_execution.workflow_id} "
f"is not registered on this worker, available activities: {activity_names}",
type="NotFoundError",
)

# Create the worker shutdown event if not created
if not self._worker_shutdown_event:
self._worker_shutdown_event = temporalio.activity._CompositeEvent(
thread_event=threading.Event(), async_event=asyncio.Event()
)

# Setup events
sync_non_threaded = False
if not activity_def.is_async:
running_activity.sync = True
# If we're in a thread-pool executor we can use threading events
# otherwise we must use manager events
if isinstance(
self._activity_executor, concurrent.futures.ThreadPoolExecutor
):
running_activity.cancelled_event = (
temporalio.activity._CompositeEvent(
thread_event=threading.Event(),
# No async event
async_event=None,
)
)
if not activity_def.no_thread_cancel_exception:
running_activity.cancel_thread_raiser = _ThreadExceptionRaiser()
else:
sync_non_threaded = True
manager = self._shared_state_manager
# Pre-checked on worker init
assert manager
running_activity.cancelled_event = (
temporalio.activity._CompositeEvent(
thread_event=manager.new_event(),
# No async event
async_event=None,
)
)
# We also must set the worker shutdown thread event to a
# manager event if this is the first sync event. We don't
# want to create if there never is a sync event.
if not self._seen_sync_activity:
self._worker_shutdown_event.thread_event = manager.new_event()
# Say we've seen a sync activity
self._seen_sync_activity = True
else:
# We have to set the async form of events
running_activity.cancelled_event = temporalio.activity._CompositeEvent(
thread_event=threading.Event(),
async_event=asyncio.Event(),
)

# Convert arguments. We use raw value for dynamic. Otherwise, we
# only use arg type hints if they match the input count.
arg_types = activity_def.arg_types
if not activity_def.name:
# Dynamic is just the raw value for each input value
arg_types = [temporalio.common.RawValue] * len(start.input)
elif arg_types is not None and len(arg_types) != len(start.input):
arg_types = None
try:
args = (
[]
if not start.input
else await self._data_converter.decode(
start.input, type_hints=arg_types
)
)
except Exception as err:
raise temporalio.exceptions.ApplicationError(
"Failed decoding arguments"
) from err
# Put the args inside a list if dynamic
if not activity_def.name:
args = [args]

# Convert heartbeat details
# TODO(cretz): Allow some way to configure heartbeat type hinting?
try:
heartbeat_details = (
[]
if not start.heartbeat_details
else await self._data_converter.decode(start.heartbeat_details)
)
except Exception as err:
raise temporalio.exceptions.ApplicationError(
"Failed decoding heartbeat details", non_retryable=True
) from err

# Build info
info = temporalio.activity.Info(
activity_id=start.activity_id,
activity_type=start.activity_type,
attempt=start.attempt,
current_attempt_scheduled_time=_proto_to_datetime(
start.current_attempt_scheduled_time
),
heartbeat_details=heartbeat_details,
heartbeat_timeout=_proto_to_non_zero_timedelta(start.heartbeat_timeout)
if start.HasField("heartbeat_timeout")
else None,
is_local=start.is_local,
schedule_to_close_timeout=_proto_to_non_zero_timedelta(
start.schedule_to_close_timeout
)
if start.HasField("schedule_to_close_timeout")
else None,
scheduled_time=_proto_to_datetime(start.scheduled_time),
start_to_close_timeout=_proto_to_non_zero_timedelta(
start.start_to_close_timeout
)
if start.HasField("start_to_close_timeout")
else None,
started_time=_proto_to_datetime(start.started_time),
task_queue=self._task_queue,
task_token=task_token,
workflow_id=start.workflow_execution.workflow_id,
workflow_namespace=start.workflow_namespace,
workflow_run_id=start.workflow_execution.run_id,
workflow_type=start.workflow_type,
priority=temporalio.common.Priority._from_proto(start.priority),
)
running_activity.info = info
input = ExecuteActivityInput(
fn=activity_def.fn,
args=args,
executor=None if not running_activity.sync else self._activity_executor,
headers=start.header_fields,
)

# Set the context early so the logging adapter works and
# interceptors have it
temporalio.activity._Context.set(
temporalio.activity._Context(
info=lambda: info,
heartbeat=None,
cancelled_event=running_activity.cancelled_event,
worker_shutdown_event=self._worker_shutdown_event,
shield_thread_cancel_exception=None
if not running_activity.cancel_thread_raiser
else running_activity.cancel_thread_raiser.shielded,
payload_converter_class_or_instance=self._data_converter.payload_converter,
runtime_metric_meter=None
if sync_non_threaded
else self._metric_meter,
)
)
temporalio.activity.logger.debug("Starting activity")

# Build the interceptors chaining in reverse. We build a context right
# now even though the info() can't be intercepted and heartbeat() will
# fail. The interceptors may want to use the info() during init.
impl: ActivityInboundInterceptor = _ActivityInboundImpl(
self, running_activity
)
for interceptor in reversed(list(self._interceptors)):
impl = interceptor.intercept_activity(impl)
# Init
impl.init(_ActivityOutboundImpl(self, running_activity.info))
# Exec
result = await impl.execute_activity(input)
# Convert result even if none. Since Python essentially only
# supports single result types (even if they are tuples), we will do
# the same.
completion.result.completed.result.CopyFrom(
(await self._data_converter.encode([result]))[0]
)
await self._execute_activity(start, running_activity, completion)
except BaseException as err:
try:
if isinstance(err, temporalio.activity._CompleteAsyncError):
Expand Down Expand Up @@ -532,6 +357,180 @@ async def _run_activity(
except Exception:
temporalio.activity.logger.exception("Failed completing activity task")

async def _execute_activity(
    self,
    start: temporalio.bridge.proto.activity_task.Start,
    running_activity: _RunningActivity,
    completion: temporalio.bridge.proto.ActivityTaskCompletion,
):
    """Resolve, prepare and run one activity, storing its result on ``completion``.

    Steps, in order:

    1. Look up the activity definition by ``start.activity_type``, falling
       back to the dynamic activity.
    2. Lazily create the worker shutdown event and set up the cancellation
       event(s) on ``running_activity`` according to whether the activity
       is async, sync on a thread pool, or sync on another executor.
    3. Decode the arguments and heartbeat details.
    4. Build the activity ``Info`` and set the activity context.
    5. Chain the interceptors, run the activity and encode its result into
       ``completion.result.completed.result``.

    Args:
        start: The start request for the activity task.
        running_activity: Mutable record for this activity; its ``sync``,
            ``cancelled_event``, ``cancel_thread_raiser`` and ``info``
            attributes are populated here.
        completion: Completion message. Its ``task_token`` is read for the
            activity info and the encoded result is copied into it.

    Raises:
        temporalio.exceptions.ApplicationError: If no matching activity is
            registered (type ``"NotFoundError"``), if the arguments cannot
            be decoded, or if the heartbeat details cannot be decoded
            (non-retryable).
    """
    # Resolve the definition; the dynamic activity (if any) is the fallback.
    definition = self._activities.get(start.activity_type, self._dynamic_activity)
    if not definition:
        registered = ", ".join(sorted(self._activities))
        raise temporalio.exceptions.ApplicationError(
            f"Activity function {start.activity_type} for workflow {start.workflow_execution.workflow_id} "
            f"is not registered on this worker, available activities: {registered}",
            type="NotFoundError",
        )

    new_composite_event = temporalio.activity._CompositeEvent

    # The worker shutdown event is created on first use only.
    if not self._worker_shutdown_event:
        self._worker_shutdown_event = new_composite_event(
            thread_event=threading.Event(), async_event=asyncio.Event()
        )

    # Cancellation events depend on how the activity is going to be run.
    sync_non_threaded = False
    if definition.is_async:
        # Async activities get both the thread and the async form.
        running_activity.cancelled_event = new_composite_event(
            thread_event=threading.Event(),
            async_event=asyncio.Event(),
        )
    else:
        running_activity.sync = True
        if isinstance(
            self._activity_executor, concurrent.futures.ThreadPoolExecutor
        ):
            # Thread-pool executor: a plain threading event suffices and
            # there is no async event.
            running_activity.cancelled_event = new_composite_event(
                thread_event=threading.Event(),
                async_event=None,
            )
            if not definition.no_thread_cancel_exception:
                running_activity.cancel_thread_raiser = _ThreadExceptionRaiser()
        else:
            # Any other executor: events must come from the shared state
            # manager.
            sync_non_threaded = True
            state_manager = self._shared_state_manager
            # Pre-checked on worker init
            assert state_manager
            running_activity.cancelled_event = new_composite_event(
                thread_event=state_manager.new_event(),
                async_event=None,
            )
            # On the first sync activity the worker shutdown thread event
            # is swapped for a manager event too. This is deferred until
            # now so nothing is created if a sync activity never shows up.
            if not self._seen_sync_activity:
                self._worker_shutdown_event.thread_event = (
                    state_manager.new_event()
                )
        # Record that a sync activity has been seen.
        self._seen_sync_activity = True

    # Pick the type hints for argument decoding: raw values for a dynamic
    # activity (one with no name); otherwise the declared hints, but only
    # when their count matches the number of inputs.
    hints = definition.arg_types
    if not definition.name:
        hints = [temporalio.common.RawValue] * len(start.input)
    elif hints is not None and len(hints) != len(start.input):
        hints = None

    try:
        if start.input:
            args = await self._data_converter.decode(start.input, type_hints=hints)
        else:
            args = []
    except Exception as err:
        raise temporalio.exceptions.ApplicationError(
            "Failed decoding arguments"
        ) from err
    # A dynamic activity receives all of its arguments as one list.
    if not definition.name:
        args = [args]

    # Decode heartbeat details
    # TODO(cretz): Allow some way to configure heartbeat type hinting?
    try:
        if start.heartbeat_details:
            heartbeat_details = await self._data_converter.decode(
                start.heartbeat_details
            )
        else:
            heartbeat_details = []
    except Exception as err:
        raise temporalio.exceptions.ApplicationError(
            "Failed decoding heartbeat details", non_retryable=True
        ) from err

    def timeout_or_none(field_name: str):
        # Convert the named timeout field only if it is set on the request.
        if start.HasField(field_name):
            return _proto_to_non_zero_timedelta(getattr(start, field_name))
        return None

    # Assemble the activity info exposed to user code and interceptors.
    workflow_execution = start.workflow_execution
    info = temporalio.activity.Info(
        activity_id=start.activity_id,
        activity_type=start.activity_type,
        attempt=start.attempt,
        current_attempt_scheduled_time=_proto_to_datetime(
            start.current_attempt_scheduled_time
        ),
        heartbeat_details=heartbeat_details,
        heartbeat_timeout=timeout_or_none("heartbeat_timeout"),
        is_local=start.is_local,
        schedule_to_close_timeout=timeout_or_none("schedule_to_close_timeout"),
        scheduled_time=_proto_to_datetime(start.scheduled_time),
        start_to_close_timeout=timeout_or_none("start_to_close_timeout"),
        started_time=_proto_to_datetime(start.started_time),
        task_queue=self._task_queue,
        task_token=completion.task_token,
        workflow_id=workflow_execution.workflow_id,
        workflow_namespace=start.workflow_namespace,
        workflow_run_id=workflow_execution.run_id,
        workflow_type=start.workflow_type,
        priority=temporalio.common.Priority._from_proto(start.priority),
    )
    running_activity.info = info

    # The executor is passed along for sync activities only.
    if running_activity.sync:
        executor = self._activity_executor
    else:
        executor = None
    execute_input = ExecuteActivityInput(
        fn=definition.fn,
        args=args,
        executor=executor,
        headers=start.header_fields,
    )

    # Install the context before anything else runs so that the logging
    # adapter works and interceptors can already see it.
    thread_raiser = running_activity.cancel_thread_raiser
    shield = thread_raiser.shielded if thread_raiser else None
    metric_meter = None if sync_non_threaded else self._metric_meter
    temporalio.activity._Context.set(
        temporalio.activity._Context(
            info=lambda: info,
            heartbeat=None,
            cancelled_event=running_activity.cancelled_event,
            worker_shutdown_event=self._worker_shutdown_event,
            shield_thread_cancel_exception=shield,
            payload_converter_class_or_instance=self._data_converter.payload_converter,
            runtime_metric_meter=metric_meter,
        )
    )
    temporalio.activity.logger.debug("Starting activity")

    # Wrap the base implementation with the interceptors, walking them in
    # reverse. The context is already set at this point even though info()
    # can't be intercepted and heartbeat() will fail; interceptors may want
    # to use info() during init.
    inbound: ActivityInboundInterceptor = _ActivityInboundImpl(self, running_activity)
    for wrapper in reversed(list(self._interceptors)):
        inbound = wrapper.intercept_activity(inbound)
    inbound.init(_ActivityOutboundImpl(self, running_activity.info))

    # Run the activity.
    result = await inbound.execute_activity(execute_input)

    # The result is always encoded, even when it is None. Python essentially
    # has single result types (tuples included), so exactly one payload is
    # produced.
    encoded = await self._data_converter.encode([result])
    completion.result.completed.result.CopyFrom(encoded[0])


@dataclass
class _RunningActivity:
Expand Down
Loading