diff --git a/Taskfile.yaml b/Taskfile.yaml index b8aa147de2..2edebaf3c7 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -210,6 +210,7 @@ tasks: cmds: - cd frontend/app/ && pnpm run lint:fix && pnpm run prettier:check - cd frontend/docs/ && pnpm run lint:fix && pnpm run prettier:check + - cd sdks/python/ && bash ./generate.sh - pre-commit run --all-files || pre-commit run --all-files docs: cmds: diff --git a/sdks/python/hatchet_sdk/clients/listeners/durable_event_listener.py b/sdks/python/hatchet_sdk/clients/listeners/durable_event_listener.py index 47957023da..0fb2b0aca8 100644 --- a/sdks/python/hatchet_sdk/clients/listeners/durable_event_listener.py +++ b/sdks/python/hatchet_sdk/clients/listeners/durable_event_listener.py @@ -12,7 +12,10 @@ from hatchet_sdk.config import ClientConfig from hatchet_sdk.connection import new_conn from hatchet_sdk.contracts.v1.dispatcher_pb2 import ( + CreateDurableEventLogRequest, DurableEvent, + GetDurableEventLogRequest, + GetDurableEventLogResponse, ListenForDurableEventRequest, ) from hatchet_sdk.contracts.v1.dispatcher_pb2 import ( @@ -121,6 +124,45 @@ def register_durable_event( return True + def get_durable_event_log( + self, external_id: str, key: str + ) -> GetDurableEventLogResponse: + conn = new_conn(self.config, True) + client = V1DispatcherStub(conn) + + get_durable_event_log = tenacity_retry( + client.GetDurableEventLog, self.config.tenacity + ) + + resp: GetDurableEventLogResponse = get_durable_event_log( + GetDurableEventLogRequest( + external_id=external_id, + key=key, + ), + timeout=5, + metadata=get_metadata(self.token), + ) + + return resp + + def create_durable_event_log(self, external_id: str, key: str, data: bytes) -> None: + conn = new_conn(self.config, True) + client = V1DispatcherStub(conn) + + create_durable_event_log = tenacity_retry( + client.CreateDurableEventLog, self.config.tenacity + ) + + create_durable_event_log( + CreateDurableEventLogRequest( + external_id=external_id, + key=key, + data=data, + ), + timeout=5, + metadata=get_metadata(self.token), + ) + async def result(self, task_id: str, signal_key: str) -> dict[str, Any]: key = self._generate_key(task_id, signal_key) diff --git a/sdks/python/hatchet_sdk/context/context.py b/sdks/python/hatchet_sdk/context/context.py index 549d7dc7f0..b5a8cba1d9 100644 --- a/sdks/python/hatchet_sdk/context/context.py +++ b/sdks/python/hatchet_sdk/context/context.py @@ -1,7 +1,9 @@ import asyncio +import hashlib import json +from collections.abc import Awaitable, Callable from datetime import timedelta -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, TypeVar, cast from warnings import warn from hatchet_sdk.clients.admin import AdminClient @@ -28,11 +30,20 @@ from hatchet_sdk.utils.typing import JSONSerializableMapping, LogLevel from hatchet_sdk.worker.runner.utils.capture_logs import AsyncLogSender, LogRecord +TMemo = TypeVar("TMemo") + if TYPE_CHECKING: from hatchet_sdk.runnables.task import Task from hatchet_sdk.runnables.types import R, TWorkflowInput +def _compute_memo_key(step_name: str, deps: list[Any]) -> str: + h = hashlib.sha256() + h.update(step_name.encode()) + h.update(json.dumps(deps, default=str).encode()) + return h.hexdigest() + + class Context: def __init__( self, @@ -520,3 +531,59 @@ async def aio_sleep_for(self, duration: Duration) -> dict[str, Any]: f"sleep:{timedelta_to_expr(duration)}-{wait_index}", SleepCondition(duration=duration), ) + + def memo(self, fn: Callable[[], TMemo], deps: list[Any]) -> TMemo: + if 
self.durable_event_listener is None: + raise ValueError("Durable event listener is not available") + + key = _compute_memo_key(self.action.action_id, deps) + + resp = self.durable_event_listener.get_durable_event_log( + external_id=self.workflow_run_id, + key=key, + ) + + if resp.found: + return cast(TMemo, json.loads(resp.data)) + + result = fn() + + data = json.dumps(result).encode() + + self.durable_event_listener.create_durable_event_log( + external_id=self.workflow_run_id, + key=key, + data=data, + ) + + return result + + async def aio_memo( + self, fn: Callable[[], Awaitable[TMemo]], deps: list[Any] + ) -> TMemo: + if self.durable_event_listener is None: + raise ValueError("Durable event listener is not available") + + key = _compute_memo_key(self.action.action_id, deps) + + resp = await asyncio.to_thread( + self.durable_event_listener.get_durable_event_log, + external_id=self.workflow_run_id, + key=key, + ) + + if resp.found: + return await asyncio.to_thread(json.loads, resp.data) + + result = await fn() + + data = (await asyncio.to_thread(json.dumps, result)).encode() + + await asyncio.to_thread( + self.durable_event_listener.create_durable_event_log, + external_id=self.workflow_run_id, + key=key, + data=data, + ) + + return result diff --git a/sdks/python/hatchet_sdk/contracts/v1/dispatcher_pb2.py b/sdks/python/hatchet_sdk/contracts/v1/dispatcher_pb2.py index 988661d822..3d2e97fe8a 100644 --- a/sdks/python/hatchet_sdk/contracts/v1/dispatcher_pb2.py +++ b/sdks/python/hatchet_sdk/contracts/v1/dispatcher_pb2.py @@ -25,7 +25,7 @@ from hatchet_sdk.contracts.v1.shared import condition_pb2 as v1_dot_shared_dot_condition__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13v1/dispatcher.proto\x12\x02v1\x1a\x19v1/shared/condition.proto\"z\n\x1bRegisterDurableEventRequest\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x12\n\nsignal_key\x18\x02 \x01(\t\x12\x36\n\nconditions\x18\x03 \x01(\x0b\x32\".v1.DurableEventListenerConditions\"\x1e\n\x1cRegisterDurableEventResponse\"C\n\x1cListenForDurableEventRequest\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x12\n\nsignal_key\x18\x02 \x01(\t\"A\n\x0c\x44urableEvent\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x12\n\nsignal_key\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x32\xbe\x01\n\x0cV1Dispatcher\x12[\n\x14RegisterDurableEvent\x12\x1f.v1.RegisterDurableEventRequest\x1a .v1.RegisterDurableEventResponse\"\x00\x12Q\n\x15ListenForDurableEvent\x12 .v1.ListenForDurableEventRequest\x1a\x10.v1.DurableEvent\"\x00(\x01\x30\x01\x42\x42Z@github.com/hatchet-dev/hatchet/internal/services/shared/proto/v1b\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13v1/dispatcher.proto\x12\x02v1\x1a\x19v1/shared/condition.proto\"z\n\x1bRegisterDurableEventRequest\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x12\n\nsignal_key\x18\x02 \x01(\t\x12\x36\n\nconditions\x18\x03 \x01(\x0b\x32\".v1.DurableEventListenerConditions\"\x1e\n\x1cRegisterDurableEventResponse\"C\n\x1cListenForDurableEventRequest\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x12\n\nsignal_key\x18\x02 \x01(\t\"A\n\x0c\x44urableEvent\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x12\n\nsignal_key\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"=\n\x19GetDurableEventLogRequest\x12\x13\n\x0b\x65xternal_id\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"9\n\x1aGetDurableEventLogResponse\x12\r\n\x05\x66ound\x18\x01 \x01(\x08\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"N\n\x1c\x43reateDurableEventLogRequest\x12\x13\n\x0b\x65xternal_id\x18\x01 
\x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\x1f\n\x1d\x43reateDurableEventLogResponse2\xf5\x02\n\x0cV1Dispatcher\x12[\n\x14RegisterDurableEvent\x12\x1f.v1.RegisterDurableEventRequest\x1a .v1.RegisterDurableEventResponse\"\x00\x12Q\n\x15ListenForDurableEvent\x12 .v1.ListenForDurableEventRequest\x1a\x10.v1.DurableEvent\"\x00(\x01\x30\x01\x12U\n\x12GetDurableEventLog\x12\x1d.v1.GetDurableEventLogRequest\x1a\x1e.v1.GetDurableEventLogResponse\"\x00\x12^\n\x15\x43reateDurableEventLog\x12 .v1.CreateDurableEventLogRequest\x1a!.v1.CreateDurableEventLogResponse\"\x00\x42\x42Z@github.com/hatchet-dev/hatchet/internal/services/shared/proto/v1b\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -41,6 +41,14 @@ _globals['_LISTENFORDURABLEEVENTREQUEST']._serialized_end=277 _globals['_DURABLEEVENT']._serialized_start=279 _globals['_DURABLEEVENT']._serialized_end=344 - _globals['_V1DISPATCHER']._serialized_start=347 - _globals['_V1DISPATCHER']._serialized_end=537 + _globals['_GETDURABLEEVENTLOGREQUEST']._serialized_start=346 + _globals['_GETDURABLEEVENTLOGREQUEST']._serialized_end=407 + _globals['_GETDURABLEEVENTLOGRESPONSE']._serialized_start=409 + _globals['_GETDURABLEEVENTLOGRESPONSE']._serialized_end=466 + _globals['_CREATEDURABLEEVENTLOGREQUEST']._serialized_start=468 + _globals['_CREATEDURABLEEVENTLOGREQUEST']._serialized_end=546 + _globals['_CREATEDURABLEEVENTLOGRESPONSE']._serialized_start=548 + _globals['_CREATEDURABLEEVENTLOGRESPONSE']._serialized_end=579 + _globals['_V1DISPATCHER']._serialized_start=582 + _globals['_V1DISPATCHER']._serialized_end=955 # @@protoc_insertion_point(module_scope) diff --git a/sdks/python/hatchet_sdk/contracts/v1/dispatcher_pb2.pyi b/sdks/python/hatchet_sdk/contracts/v1/dispatcher_pb2.pyi index c8b3ddc79a..d5aaf14126 100644 --- a/sdks/python/hatchet_sdk/contracts/v1/dispatcher_pb2.pyi +++ b/sdks/python/hatchet_sdk/contracts/v1/dispatcher_pb2.pyi @@ -37,3 +37,33 @@ class DurableEvent(_message.Message): signal_key: str data: bytes def __init__(self, task_id: _Optional[str] = ..., signal_key: _Optional[str] = ..., data: _Optional[bytes] = ...) -> None: ... + +class GetDurableEventLogRequest(_message.Message): + __slots__ = ("external_id", "key") + EXTERNAL_ID_FIELD_NUMBER: _ClassVar[int] + KEY_FIELD_NUMBER: _ClassVar[int] + external_id: str + key: str + def __init__(self, external_id: _Optional[str] = ..., key: _Optional[str] = ...) -> None: ... + +class GetDurableEventLogResponse(_message.Message): + __slots__ = ("found", "data") + FOUND_FIELD_NUMBER: _ClassVar[int] + DATA_FIELD_NUMBER: _ClassVar[int] + found: bool + data: bytes + def __init__(self, found: bool = ..., data: _Optional[bytes] = ...) -> None: ... + +class CreateDurableEventLogRequest(_message.Message): + __slots__ = ("external_id", "key", "data") + EXTERNAL_ID_FIELD_NUMBER: _ClassVar[int] + KEY_FIELD_NUMBER: _ClassVar[int] + DATA_FIELD_NUMBER: _ClassVar[int] + external_id: str + key: str + data: bytes + def __init__(self, external_id: _Optional[str] = ..., key: _Optional[str] = ..., data: _Optional[bytes] = ...) -> None: ... + +class CreateDurableEventLogResponse(_message.Message): + __slots__ = () + def __init__(self) -> None: ... 
diff --git a/sdks/python/hatchet_sdk/contracts/v1/dispatcher_pb2_grpc.py b/sdks/python/hatchet_sdk/contracts/v1/dispatcher_pb2_grpc.py index 74d39ceec3..fea3b61324 100644 --- a/sdks/python/hatchet_sdk/contracts/v1/dispatcher_pb2_grpc.py +++ b/sdks/python/hatchet_sdk/contracts/v1/dispatcher_pb2_grpc.py @@ -44,6 +44,16 @@ def __init__(self, channel: grpc.Channel | grpc.aio.Channel) -> None: request_serializer=v1_dot_dispatcher__pb2.ListenForDurableEventRequest.SerializeToString, response_deserializer=v1_dot_dispatcher__pb2.DurableEvent.FromString, _registered_method=True) + self.GetDurableEventLog = channel.unary_unary( + '/v1.V1Dispatcher/GetDurableEventLog', + request_serializer=v1_dot_dispatcher__pb2.GetDurableEventLogRequest.SerializeToString, + response_deserializer=v1_dot_dispatcher__pb2.GetDurableEventLogResponse.FromString, + _registered_method=True) + self.CreateDurableEventLog = channel.unary_unary( + '/v1.V1Dispatcher/CreateDurableEventLog', + request_serializer=v1_dot_dispatcher__pb2.CreateDurableEventLogRequest.SerializeToString, + response_deserializer=v1_dot_dispatcher__pb2.CreateDurableEventLogResponse.FromString, + _registered_method=True) class V1DispatcherServicer(object): @@ -61,6 +71,18 @@ def ListenForDurableEvent(self, request_iterator, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def GetDurableEventLog(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def CreateDurableEventLog(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_V1DispatcherServicer_to_server(servicer, server): rpc_method_handlers = { @@ -74,6 +96,16 @@ def add_V1DispatcherServicer_to_server(servicer, server): request_deserializer=v1_dot_dispatcher__pb2.ListenForDurableEventRequest.FromString, response_serializer=v1_dot_dispatcher__pb2.DurableEvent.SerializeToString, ), + 'GetDurableEventLog': grpc.unary_unary_rpc_method_handler( + servicer.GetDurableEventLog, + request_deserializer=v1_dot_dispatcher__pb2.GetDurableEventLogRequest.FromString, + response_serializer=v1_dot_dispatcher__pb2.GetDurableEventLogResponse.SerializeToString, + ), + 'CreateDurableEventLog': grpc.unary_unary_rpc_method_handler( + servicer.CreateDurableEventLog, + request_deserializer=v1_dot_dispatcher__pb2.CreateDurableEventLogRequest.FromString, + response_serializer=v1_dot_dispatcher__pb2.CreateDurableEventLogResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( 'v1.V1Dispatcher', rpc_method_handlers) @@ -138,3 +170,57 @@ def ListenForDurableEvent(request_iterator, timeout, metadata, _registered_method=True) + + @staticmethod + def GetDurableEventLog(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/v1.V1Dispatcher/GetDurableEventLog', + v1_dot_dispatcher__pb2.GetDurableEventLogRequest.SerializeToString, + v1_dot_dispatcher__pb2.GetDurableEventLogResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + 
compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def CreateDurableEventLog(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/v1.V1Dispatcher/CreateDurableEventLog', + v1_dot_dispatcher__pb2.CreateDurableEventLogRequest.SerializeToString, + v1_dot_dispatcher__pb2.CreateDurableEventLogResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/sdks/python/poetry.lock b/sdks/python/poetry.lock index c0f202c034..38198b6e4b 100644 --- a/sdks/python/poetry.lock +++ b/sdks/python/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.0.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -153,7 +153,7 @@ propcache = ">=0.2.0" yarl = ">=1.17.0,<2.0" [package.extras] -speedups = ["Brotli (>=1.2)", "aiodns (>=3.3.0)", "backports.zstd", "brotlicffi (>=1.2)"] +speedups = ["Brotli (>=1.2) ; platform_python_implementation == \"CPython\"", "aiodns (>=3.3.0)", "backports.zstd ; platform_python_implementation == \"CPython\" and python_version < \"3.14\"", "brotlicffi (>=1.2) ; platform_python_implementation != \"CPython\""] [[package]] name = "aiosignal" @@ -201,7 +201,7 @@ idna = ">=2.8" typing_extensions = {version = ">=4.5", markers = "python_version < \"3.13\""} [package.extras] -trio = ["trio (>=0.31.0)", "trio (>=0.32.0)"] +trio = ["trio (>=0.31.0) ; python_version < \"3.10\"", "trio (>=0.32.0) ; python_version >= \"3.10\""] [[package]] name = "async-timeout" @@ -210,7 +210,7 @@ description = "Timeout context manager for asyncio programs" optional = false python-versions = ">=3.8" groups = ["main"] -markers = "python_version < \"3.11\"" +markers = "python_version == \"3.10\"" files = [ {file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"}, {file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"}, @@ -503,7 +503,7 @@ files = [ ] [package.extras] -dev = ["docstring-parser[docs]", "docstring-parser[test]", "pre-commit (>=2.16.0)"] +dev = ["docstring-parser[docs]", "docstring-parser[test]", "pre-commit (>=2.16.0) ; python_version >= \"3.9\""] docs = ["pydoctor (>=25.4.0)"] test = ["pytest"] @@ -514,7 +514,7 @@ description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" groups = ["docs", "test"] -markers = "python_version < \"3.11\"" +markers = "python_version == \"3.10\"" files = [ {file = "exceptiongroup-1.3.1-py3-none-any.whl", hash = "sha256:a7a39a3bd276781e98394987d3a5701d0c4edffb633bb7a5144577f82c773598"}, {file = "exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219"}, @@ -978,7 +978,7 @@ httpcore = "==1.*" idna = "*" [package.extras] -brotli = ["brotli", "brotlicffi"] +brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""] cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] @@ -1016,12 +1016,12 @@ files = [ zipp = ">=3.20" 
[package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] enabler = ["pytest-enabler (>=2.2)"] perf = ["ipython"] -test = ["flufl.flake8", "importlib_resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-perf (>=0.9.2)"] +test = ["flufl.flake8", "importlib_resources (>=1.3) ; python_version < \"3.9\"", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-perf (>=0.9.2)"] type = ["pytest-mypy"] [[package]] @@ -1440,7 +1440,7 @@ watchdog = ">=2.0" [package.extras] i18n = ["babel (>=2.9.0)"] -min-versions = ["babel (==2.9.0)", "click (==7.0)", "colorama (==0.4)", "ghp-import (==1.0)", "importlib-metadata (==4.4)", "jinja2 (==2.11.1)", "markdown (==3.3.6)", "markupsafe (==2.0.1)", "mergedeep (==1.3.4)", "mkdocs-get-deps (==0.2.0)", "packaging (==20.5)", "pathspec (==0.11.1)", "pyyaml (==5.1)", "pyyaml-env-tag (==0.1)", "watchdog (==2.0)"] +min-versions = ["babel (==2.9.0)", "click (==7.0)", "colorama (==0.4) ; platform_system == \"Windows\"", "ghp-import (==1.0)", "importlib-metadata (==4.4) ; python_version < \"3.10\"", "jinja2 (==2.11.1)", "markdown (==3.3.6)", "markupsafe (==2.0.1)", "mergedeep (==1.3.4)", "mkdocs-get-deps (==0.2.0)", "packaging (==20.5)", "pathspec (==0.11.1)", "pyyaml (==5.1)", "pyyaml-env-tag (==0.1)", "watchdog (==2.0)"] [[package]] name = "mkdocs-autorefs" @@ -2249,12 +2249,12 @@ typing-extensions = {version = ">=4.6", markers = "python_version < \"3.13\""} tzdata = {version = "*", markers = "sys_platform == \"win32\""} [package.extras] -binary = ["psycopg-binary (==3.3.2)"] -c = ["psycopg-c (==3.3.2)"] +binary = ["psycopg-binary (==3.3.2) ; implementation_name != \"pypy\""] +c = ["psycopg-c (==3.3.2) ; implementation_name != \"pypy\""] dev = ["ast-comments (>=1.1.2)", "black (>=24.1.0)", "codespell (>=2.2)", "cython-lint (>=0.16)", "dnspython (>=2.1)", "flake8 (>=4.0)", "isort-psycopg", "isort[colors] (>=6.0)", "mypy (>=1.19.0)", "pre-commit (>=4.0.1)", "types-setuptools (>=57.4)", "types-shapely (>=2.0)", "wheel (>=0.37)"] docs = ["Sphinx (>=5.0)", "furo (==2022.6.21)", "sphinx-autobuild (>=2021.3.14)", "sphinx-autodoc-typehints (>=1.12)"] pool = ["psycopg-pool"] -test = ["anyio (>=4.0)", "mypy (>=1.19.0)", "pproxy (>=2.7)", "pytest (>=6.2.5)", "pytest-cov (>=3.0)", "pytest-randomly (>=3.5)"] +test = ["anyio (>=4.0)", "mypy (>=1.19.0) ; implementation_name != \"pypy\"", "pproxy (>=2.7)", "pytest (>=6.2.5)", "pytest-cov (>=3.0)", "pytest-randomly (>=3.5)"] [[package]] name = "psycopg-pool" @@ -2294,7 +2294,7 @@ typing-inspection = ">=0.4.2" [package.extras] email = ["email-validator (>=2.0.0)"] -timezone = ["tzdata"] +timezone = ["tzdata ; python_version >= \"3.9\" and platform_system == \"Windows\""] [[package]] name = "pydantic-core" @@ -2801,13 +2801,13 @@ files = [ ] [package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.8.0)"] -core = ["importlib_metadata (>=6)", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""] +core = 
["importlib_metadata (>=6) ; python_version < \"3.10\"", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1) ; python_version < \"3.11\"", "wheel (>=0.43.0)"] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] enabler = ["pytest-enabler (>=2.2)"] -test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] -type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.14.*)", "pytest-mypy"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21) ; python_version >= \"3.9\" and sys_platform != \"cygwin\"", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf ; sys_platform != \"cygwin\"", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] +type = ["importlib_metadata (>=7.0.2) ; python_version < \"3.10\"", "jaraco.develop (>=7.21) ; sys_platform != \"cygwin\"", "mypy (==1.14.*)", "pytest-mypy"] [[package]] name = "six" @@ -2886,7 +2886,7 @@ description = "A lil' TOML parser" optional = false python-versions = ">=3.8" groups = ["docs", "lint", "test"] -markers = "python_version < \"3.11\"" +markers = "python_version == \"3.10\"" files = [ {file = "tomli-2.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:88bd15eb972f3664f5ed4b57c1634a97153b4bac4479dcb6a495f41921eb7f45"}, {file = "tomli-2.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:883b1c0d6398a6a9d29b508c331fa56adbcdff647f6ace4dfca0f50e90dfd0ba"}, @@ -3088,10 +3088,10 @@ files = [ ] [package.extras] -brotli = ["brotli (>=1.2.0)", "brotlicffi (>=1.2.0.0)"] +brotli = ["brotli (>=1.2.0) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=1.2.0.0) ; platform_python_implementation != \"CPython\""] h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] -zstd = ["backports-zstd (>=1.0.0)"] +zstd = ["backports-zstd (>=1.0.0) ; python_version < \"3.14\""] [[package]] name = "uvicorn" @@ -3111,7 +3111,7 @@ h11 = ">=0.8" typing-extensions = {version = ">=4.0", markers = "python_version < \"3.11\""} [package.extras] -standard = ["colorama (>=0.4)", "httptools (>=0.6.3)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"] +standard = ["colorama (>=0.4) ; sys_platform == \"win32\"", "httptools (>=0.6.3)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.15.1) ; sys_platform != \"win32\" and sys_platform != \"cygwin\" and platform_python_implementation != \"PyPy\"", "watchfiles (>=0.13)", "websockets (>=10.4)"] [[package]] name = "watchdog" @@ -3407,7 +3407,7 @@ files = [ ] [package.extras] 
-check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] enabler = ["pytest-enabler (>=2.2)"] diff --git a/sdks/typescript/src/clients/listeners/durable-listener/durable-listener-client.ts b/sdks/typescript/src/clients/listeners/durable-listener/durable-listener-client.ts index 93f019443e..b4c6aee8ca 100644 --- a/sdks/typescript/src/clients/listeners/durable-listener/durable-listener-client.ts +++ b/sdks/typescript/src/clients/listeners/durable-listener/durable-listener-client.ts @@ -3,7 +3,12 @@ import { Channel, ClientFactory } from 'nice-grpc'; import { ClientConfig } from '@clients/hatchet-client/client-config'; import { Logger } from '@hatchet/util/logger'; -import { V1DispatcherClient, V1DispatcherDefinition } from '@hatchet/protoc/v1/dispatcher'; +import { + GetDurableEventLogResponse, + CreateDurableEventLogResponse, + V1DispatcherClient, + V1DispatcherDefinition, +} from '@hatchet/protoc/v1/dispatcher'; import { SleepMatchCondition, UserEventMatchCondition } from '@hatchet/protoc/v1/shared/condition'; import { Api } from '../../rest'; import { DurableEventGrpcPooledListener } from './pooled-durable-listener-client'; @@ -47,4 +52,19 @@ export class DurableListenerClient { return this.pooledListener.registerDurableEvent(request); } + + async getDurableEventLog(request: { + externalId: string; + key: string; + }): Promise { + return this.client.getDurableEventLog(request); + } + + async createDurableEventLog(request: { + externalId: string; + key: string; + data: Uint8Array; + }): Promise { + return this.client.createDurableEventLog(request); + } } diff --git a/sdks/typescript/src/protoc/v1/dispatcher.ts b/sdks/typescript/src/protoc/v1/dispatcher.ts index 8df1d1e8e9..7eb6a57176 100644 --- a/sdks/typescript/src/protoc/v1/dispatcher.ts +++ b/sdks/typescript/src/protoc/v1/dispatcher.ts @@ -36,6 +36,24 @@ export interface DurableEvent { data: Uint8Array; } +export interface GetDurableEventLogRequest { + externalId: string; + key: string; +} + +export interface GetDurableEventLogResponse { + found: boolean; + data: Uint8Array; +} + +export interface CreateDurableEventLogRequest { + externalId: string; + key: string; + data: Uint8Array; +} + +export interface CreateDurableEventLogResponse {} + function createBaseRegisterDurableEventRequest(): RegisterDurableEventRequest { return { taskId: '', signalKey: '', conditions: undefined }; } @@ -350,6 +368,305 @@ export const DurableEvent: MessageFns = { }, }; +function createBaseGetDurableEventLogRequest(): GetDurableEventLogRequest { + return { externalId: '', key: '' }; +} + +export const GetDurableEventLogRequest: MessageFns = { + encode( + message: GetDurableEventLogRequest, + writer: BinaryWriter = new BinaryWriter() + ): BinaryWriter { + if (message.externalId !== '') { + writer.uint32(10).string(message.externalId); + } + if (message.key !== '') { + writer.uint32(18).string(message.key); + } + return writer; + }, + + decode(input: BinaryReader | Uint8Array, length?: number): GetDurableEventLogRequest { + const reader = input instanceof BinaryReader ? input : new BinaryReader(input); + const end = length === undefined ? 
reader.len : reader.pos + length; + const message = createBaseGetDurableEventLogRequest(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (tag !== 10) { + break; + } + + message.externalId = reader.string(); + continue; + } + case 2: { + if (tag !== 18) { + break; + } + + message.key = reader.string(); + continue; + } + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skip(tag & 7); + } + return message; + }, + + fromJSON(object: any): GetDurableEventLogRequest { + return { + externalId: isSet(object.externalId) ? globalThis.String(object.externalId) : '', + key: isSet(object.key) ? globalThis.String(object.key) : '', + }; + }, + + toJSON(message: GetDurableEventLogRequest): unknown { + const obj: any = {}; + if (message.externalId !== '') { + obj.externalId = message.externalId; + } + if (message.key !== '') { + obj.key = message.key; + } + return obj; + }, + + create(base?: DeepPartial): GetDurableEventLogRequest { + return GetDurableEventLogRequest.fromPartial(base ?? {}); + }, + fromPartial(object: DeepPartial): GetDurableEventLogRequest { + const message = createBaseGetDurableEventLogRequest(); + message.externalId = object.externalId ?? ''; + message.key = object.key ?? ''; + return message; + }, +}; + +function createBaseGetDurableEventLogResponse(): GetDurableEventLogResponse { + return { found: false, data: new Uint8Array(0) }; +} + +export const GetDurableEventLogResponse: MessageFns = { + encode( + message: GetDurableEventLogResponse, + writer: BinaryWriter = new BinaryWriter() + ): BinaryWriter { + if (message.found !== false) { + writer.uint32(8).bool(message.found); + } + if (message.data.length !== 0) { + writer.uint32(18).bytes(message.data); + } + return writer; + }, + + decode(input: BinaryReader | Uint8Array, length?: number): GetDurableEventLogResponse { + const reader = input instanceof BinaryReader ? input : new BinaryReader(input); + const end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseGetDurableEventLogResponse(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (tag !== 8) { + break; + } + + message.found = reader.bool(); + continue; + } + case 2: { + if (tag !== 18) { + break; + } + + message.data = reader.bytes(); + continue; + } + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skip(tag & 7); + } + return message; + }, + + fromJSON(object: any): GetDurableEventLogResponse { + return { + found: isSet(object.found) ? globalThis.Boolean(object.found) : false, + data: isSet(object.data) ? bytesFromBase64(object.data) : new Uint8Array(0), + }; + }, + + toJSON(message: GetDurableEventLogResponse): unknown { + const obj: any = {}; + if (message.found !== false) { + obj.found = message.found; + } + if (message.data.length !== 0) { + obj.data = base64FromBytes(message.data); + } + return obj; + }, + + create(base?: DeepPartial): GetDurableEventLogResponse { + return GetDurableEventLogResponse.fromPartial(base ?? {}); + }, + fromPartial(object: DeepPartial): GetDurableEventLogResponse { + const message = createBaseGetDurableEventLogResponse(); + message.found = object.found ?? false; + message.data = object.data ?? 
new Uint8Array(0); + return message; + }, +}; + +function createBaseCreateDurableEventLogRequest(): CreateDurableEventLogRequest { + return { externalId: '', key: '', data: new Uint8Array(0) }; +} + +export const CreateDurableEventLogRequest: MessageFns = { + encode( + message: CreateDurableEventLogRequest, + writer: BinaryWriter = new BinaryWriter() + ): BinaryWriter { + if (message.externalId !== '') { + writer.uint32(10).string(message.externalId); + } + if (message.key !== '') { + writer.uint32(18).string(message.key); + } + if (message.data.length !== 0) { + writer.uint32(26).bytes(message.data); + } + return writer; + }, + + decode(input: BinaryReader | Uint8Array, length?: number): CreateDurableEventLogRequest { + const reader = input instanceof BinaryReader ? input : new BinaryReader(input); + const end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseCreateDurableEventLogRequest(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (tag !== 10) { + break; + } + + message.externalId = reader.string(); + continue; + } + case 2: { + if (tag !== 18) { + break; + } + + message.key = reader.string(); + continue; + } + case 3: { + if (tag !== 26) { + break; + } + + message.data = reader.bytes(); + continue; + } + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skip(tag & 7); + } + return message; + }, + + fromJSON(object: any): CreateDurableEventLogRequest { + return { + externalId: isSet(object.externalId) ? globalThis.String(object.externalId) : '', + key: isSet(object.key) ? globalThis.String(object.key) : '', + data: isSet(object.data) ? bytesFromBase64(object.data) : new Uint8Array(0), + }; + }, + + toJSON(message: CreateDurableEventLogRequest): unknown { + const obj: any = {}; + if (message.externalId !== '') { + obj.externalId = message.externalId; + } + if (message.key !== '') { + obj.key = message.key; + } + if (message.data.length !== 0) { + obj.data = base64FromBytes(message.data); + } + return obj; + }, + + create(base?: DeepPartial): CreateDurableEventLogRequest { + return CreateDurableEventLogRequest.fromPartial(base ?? {}); + }, + fromPartial(object: DeepPartial): CreateDurableEventLogRequest { + const message = createBaseCreateDurableEventLogRequest(); + message.externalId = object.externalId ?? ''; + message.key = object.key ?? ''; + message.data = object.data ?? new Uint8Array(0); + return message; + }, +}; + +function createBaseCreateDurableEventLogResponse(): CreateDurableEventLogResponse { + return {}; +} + +export const CreateDurableEventLogResponse: MessageFns = { + encode( + _: CreateDurableEventLogResponse, + writer: BinaryWriter = new BinaryWriter() + ): BinaryWriter { + return writer; + }, + + decode(input: BinaryReader | Uint8Array, length?: number): CreateDurableEventLogResponse { + const reader = input instanceof BinaryReader ? input : new BinaryReader(input); + const end = length === undefined ? 
reader.len : reader.pos + length; + const message = createBaseCreateDurableEventLogResponse(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skip(tag & 7); + } + return message; + }, + + fromJSON(_: any): CreateDurableEventLogResponse { + return {}; + }, + + toJSON(_: CreateDurableEventLogResponse): unknown { + const obj: any = {}; + return obj; + }, + + create(base?: DeepPartial): CreateDurableEventLogResponse { + return CreateDurableEventLogResponse.fromPartial(base ?? {}); + }, + fromPartial(_: DeepPartial): CreateDurableEventLogResponse { + const message = createBaseCreateDurableEventLogResponse(); + return message; + }, +}; + export type V1DispatcherDefinition = typeof V1DispatcherDefinition; export const V1DispatcherDefinition = { name: 'V1Dispatcher', @@ -371,6 +688,22 @@ export const V1DispatcherDefinition = { responseStream: true, options: {}, }, + getDurableEventLog: { + name: 'GetDurableEventLog', + requestType: GetDurableEventLogRequest, + requestStream: false, + responseType: GetDurableEventLogResponse, + responseStream: false, + options: {}, + }, + createDurableEventLog: { + name: 'CreateDurableEventLog', + requestType: CreateDurableEventLogRequest, + requestStream: false, + responseType: CreateDurableEventLogResponse, + responseStream: false, + options: {}, + }, }, } as const; @@ -383,6 +716,14 @@ export interface V1DispatcherServiceImplementation { request: AsyncIterable, context: CallContext & CallContextExt ): ServerStreamingMethodResult>; + getDurableEventLog( + request: GetDurableEventLogRequest, + context: CallContext & CallContextExt + ): Promise>; + createDurableEventLog( + request: CreateDurableEventLogRequest, + context: CallContext & CallContextExt + ): Promise>; } export interface V1DispatcherClient { @@ -394,6 +735,14 @@ export interface V1DispatcherClient { request: AsyncIterable>, options?: CallOptions & CallOptionsExt ): AsyncIterable; + getDurableEventLog( + request: DeepPartial, + options?: CallOptions & CallOptionsExt + ): Promise; + createDurableEventLog( + request: DeepPartial, + options?: CallOptions & CallOptionsExt + ): Promise; } function bytesFromBase64(b64: string): Uint8Array { diff --git a/sdks/typescript/src/v1/client/worker/context.ts b/sdks/typescript/src/v1/client/worker/context.ts index adc95e9b42..6c357784e5 100644 --- a/sdks/typescript/src/v1/client/worker/context.ts +++ b/sdks/typescript/src/v1/client/worker/context.ts @@ -1,5 +1,6 @@ /* eslint-disable no-underscore-dangle */ /* eslint-disable max-classes-per-file */ +import { createHash } from 'crypto'; import { Priority, RunOpts, @@ -752,4 +753,38 @@ export class DurableContext extends Context { const res = JSON.parse(eventData) as Record>; return res.CREATE; } + + async memo(fn: () => Promise, deps: any[]): Promise { + const key = computeMemoKey(this.action.stepName, deps); + + const resp = await this.v1._v0.durableListener.getDurableEventLog({ + externalId: this.action.workflowRunId, + key, + }); + + if (resp.found) { + const data = + resp.data instanceof Uint8Array ? 
new TextDecoder().decode(resp.data) : resp.data; + return JSON.parse(data) as R; + } + + const result = await fn(); + + const data = new TextEncoder().encode(JSON.stringify(result)); + + await this.v1._v0.durableListener.createDurableEventLog({ + externalId: this.action.workflowRunId, + key, + data, + }); + + return result; + } +} + +function computeMemoKey(stepName: string, deps: any[]): string { + const h = createHash('sha256'); + h.update(stepName); + h.update(JSON.stringify(deps)); + return h.digest('hex'); } diff --git a/sdks/typescript/src/version.ts b/sdks/typescript/src/version.ts index 8435f06e99..c408357ada 100644 --- a/sdks/typescript/src/version.ts +++ b/sdks/typescript/src/version.ts @@ -1 +1 @@ -export const HATCHET_VERSION = '1.9.5'; +export const HATCHET_VERSION = '1.10.8';
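
For reference, the memo cache key derived by both SDKs in this diff is a sha256 over the step name followed by the JSON-encoded deps list (Python's _compute_memo_key above; computeMemoKey in context.ts does the equivalent with createHash('sha256')). A standalone Python sketch of that derivation, mainly to show that deps ordering is part of the key, so callers should pass deps in a stable order:

    import hashlib
    import json


    def compute_memo_key(step_name: str, deps: list) -> str:
        # Mirrors _compute_memo_key from context.py: hash the step name,
        # then the JSON encoding of the deps list.
        h = hashlib.sha256()
        h.update(step_name.encode())
        h.update(json.dumps(deps, default=str).encode())
        return h.hexdigest()


    # Reordering deps changes the key, so a memoized step with deps
    # ["a", "b"] will not hit the cache entry written for ["b", "a"].
    assert compute_memo_key("step", ["a", "b"]) != compute_memo_key("step", ["b", "a"])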