Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Taskfile.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ tasks:
cmds:
- cd frontend/app/ && pnpm run lint:fix && pnpm run prettier:check
- cd frontend/docs/ && pnpm run lint:fix && pnpm run prettier:check
- cd sdks/python/ && bash ./generate.sh
- pre-commit run --all-files || pre-commit run --all-files
docs:
cmds:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
from hatchet_sdk.config import ClientConfig
from hatchet_sdk.connection import new_conn
from hatchet_sdk.contracts.v1.dispatcher_pb2 import (
CreateDurableEventLogRequest,
DurableEvent,
GetDurableEventLogRequest,
GetDurableEventLogResponse,
ListenForDurableEventRequest,
)
from hatchet_sdk.contracts.v1.dispatcher_pb2 import (
Expand Down Expand Up @@ -121,6 +124,45 @@ def register_durable_event(

return True

def get_durable_event_log(
self, external_id: str, key: str
) -> GetDurableEventLogResponse:
conn = new_conn(self.config, True)
client = V1DispatcherStub(conn)

get_durable_event_log = tenacity_retry(
client.GetDurableEventLog, self.config.tenacity
)

resp: GetDurableEventLogResponse = get_durable_event_log(
GetDurableEventLogRequest(
external_id=external_id,
key=key,
),
timeout=5,
metadata=get_metadata(self.token),
)

return resp

def create_durable_event_log(self, external_id: str, key: str, data: bytes) -> None:
conn = new_conn(self.config, True)
client = V1DispatcherStub(conn)

create_durable_event_log = tenacity_retry(
client.CreateDurableEventLog, self.config.tenacity
)

create_durable_event_log(
CreateDurableEventLogRequest(
external_id=external_id,
key=key,
data=data,
),
timeout=5,
metadata=get_metadata(self.token),
)

async def result(self, task_id: str, signal_key: str) -> dict[str, Any]:
key = self._generate_key(task_id, signal_key)

Expand Down
69 changes: 68 additions & 1 deletion sdks/python/hatchet_sdk/context/context.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
import hashlib
import json
from collections.abc import Awaitable, Callable
from datetime import timedelta
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, TypeVar, cast
from warnings import warn

from hatchet_sdk.clients.admin import AdminClient
Expand All @@ -28,11 +30,20 @@
from hatchet_sdk.utils.typing import JSONSerializableMapping, LogLevel
from hatchet_sdk.worker.runner.utils.capture_logs import AsyncLogSender, LogRecord

TMemo = TypeVar("TMemo")

if TYPE_CHECKING:
from hatchet_sdk.runnables.task import Task
from hatchet_sdk.runnables.types import R, TWorkflowInput


def _compute_memo_key(step_name: str, deps: list[Any]) -> str:
h = hashlib.sha256()
h.update(step_name.encode())
h.update(json.dumps(deps, default=str).encode())
return h.hexdigest()


class Context:
def __init__(
self,
Expand Down Expand Up @@ -520,3 +531,59 @@ async def aio_sleep_for(self, duration: Duration) -> dict[str, Any]:
f"sleep:{timedelta_to_expr(duration)}-{wait_index}",
SleepCondition(duration=duration),
)

def memo(self, fn: Callable[[], TMemo], deps: list[Any]) -> TMemo:
if self.durable_event_listener is None:
raise ValueError("Durable event listener is not available")

key = _compute_memo_key(self.action.action_id, deps)

resp = self.durable_event_listener.get_durable_event_log(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this needs to be wrapped in asyncio.to_thread so it doesn't block

external_id=self.workflow_run_id,
key=key,
)

if resp.found:
return cast(TMemo, json.loads(resp.data))

result = fn()

data = json.dumps(result).encode()

self.durable_event_listener.create_durable_event_log(
external_id=self.workflow_run_id,
key=key,
data=data,
)
Comment on lines +551 to +557
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and to_thread for these ones :)


return result

async def aio_memo(
self, fn: Callable[[], Awaitable[TMemo]], deps: list[Any]
) -> TMemo:
if self.durable_event_listener is None:
raise ValueError("Durable event listener is not available")

key = _compute_memo_key(self.action.action_id, deps)

resp = await asyncio.to_thread(
self.durable_event_listener.get_durable_event_log,
external_id=self.workflow_run_id,
key=key,
)

if resp.found:
return await asyncio.to_thread(json.loads, resp.data)

result = await fn()

data = (await asyncio.to_thread(json.dumps, result)).encode()

await asyncio.to_thread(
self.durable_event_listener.create_durable_event_log,
external_id=self.workflow_run_id,
key=key,
data=data,
)

return result
14 changes: 11 additions & 3 deletions sdks/python/hatchet_sdk/contracts/v1/dispatcher_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 30 additions & 0 deletions sdks/python/hatchet_sdk/contracts/v1/dispatcher_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,33 @@ class DurableEvent(_message.Message):
signal_key: str
data: bytes
def __init__(self, task_id: _Optional[str] = ..., signal_key: _Optional[str] = ..., data: _Optional[bytes] = ...) -> None: ...

class GetDurableEventLogRequest(_message.Message):
__slots__ = ("external_id", "key")
EXTERNAL_ID_FIELD_NUMBER: _ClassVar[int]
KEY_FIELD_NUMBER: _ClassVar[int]
external_id: str
key: str
def __init__(self, external_id: _Optional[str] = ..., key: _Optional[str] = ...) -> None: ...

class GetDurableEventLogResponse(_message.Message):
__slots__ = ("found", "data")
FOUND_FIELD_NUMBER: _ClassVar[int]
DATA_FIELD_NUMBER: _ClassVar[int]
found: bool
data: bytes
def __init__(self, found: bool = ..., data: _Optional[bytes] = ...) -> None: ...

class CreateDurableEventLogRequest(_message.Message):
__slots__ = ("external_id", "key", "data")
EXTERNAL_ID_FIELD_NUMBER: _ClassVar[int]
KEY_FIELD_NUMBER: _ClassVar[int]
DATA_FIELD_NUMBER: _ClassVar[int]
external_id: str
key: str
data: bytes
def __init__(self, external_id: _Optional[str] = ..., key: _Optional[str] = ..., data: _Optional[bytes] = ...) -> None: ...

class CreateDurableEventLogResponse(_message.Message):
__slots__ = ()
def __init__(self) -> None: ...
86 changes: 86 additions & 0 deletions sdks/python/hatchet_sdk/contracts/v1/dispatcher_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ def __init__(self, channel: grpc.Channel | grpc.aio.Channel) -> None:
request_serializer=v1_dot_dispatcher__pb2.ListenForDurableEventRequest.SerializeToString,
response_deserializer=v1_dot_dispatcher__pb2.DurableEvent.FromString,
_registered_method=True)
self.GetDurableEventLog = channel.unary_unary(
'/v1.V1Dispatcher/GetDurableEventLog',
request_serializer=v1_dot_dispatcher__pb2.GetDurableEventLogRequest.SerializeToString,
response_deserializer=v1_dot_dispatcher__pb2.GetDurableEventLogResponse.FromString,
_registered_method=True)
self.CreateDurableEventLog = channel.unary_unary(
'/v1.V1Dispatcher/CreateDurableEventLog',
request_serializer=v1_dot_dispatcher__pb2.CreateDurableEventLogRequest.SerializeToString,
response_deserializer=v1_dot_dispatcher__pb2.CreateDurableEventLogResponse.FromString,
_registered_method=True)


class V1DispatcherServicer(object):
Expand All @@ -61,6 +71,18 @@ def ListenForDurableEvent(self, request_iterator, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def GetDurableEventLog(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def CreateDurableEventLog(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_V1DispatcherServicer_to_server(servicer, server):
rpc_method_handlers = {
Expand All @@ -74,6 +96,16 @@ def add_V1DispatcherServicer_to_server(servicer, server):
request_deserializer=v1_dot_dispatcher__pb2.ListenForDurableEventRequest.FromString,
response_serializer=v1_dot_dispatcher__pb2.DurableEvent.SerializeToString,
),
'GetDurableEventLog': grpc.unary_unary_rpc_method_handler(
servicer.GetDurableEventLog,
request_deserializer=v1_dot_dispatcher__pb2.GetDurableEventLogRequest.FromString,
response_serializer=v1_dot_dispatcher__pb2.GetDurableEventLogResponse.SerializeToString,
),
'CreateDurableEventLog': grpc.unary_unary_rpc_method_handler(
servicer.CreateDurableEventLog,
request_deserializer=v1_dot_dispatcher__pb2.CreateDurableEventLogRequest.FromString,
response_serializer=v1_dot_dispatcher__pb2.CreateDurableEventLogResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'v1.V1Dispatcher', rpc_method_handlers)
Expand Down Expand Up @@ -138,3 +170,57 @@ def ListenForDurableEvent(request_iterator,
timeout,
metadata,
_registered_method=True)

@staticmethod
def GetDurableEventLog(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/v1.V1Dispatcher/GetDurableEventLog',
v1_dot_dispatcher__pb2.GetDurableEventLogRequest.SerializeToString,
v1_dot_dispatcher__pb2.GetDurableEventLogResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

@staticmethod
def CreateDurableEventLog(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/v1.V1Dispatcher/CreateDurableEventLog',
v1_dot_dispatcher__pb2.CreateDurableEventLogRequest.SerializeToString,
v1_dot_dispatcher__pb2.CreateDurableEventLogResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
Loading
Loading