Draft pull request (changes from 2 commits)
2 changes: 2 additions & 0 deletions Taskfile.yaml
@@ -210,6 +210,8 @@ tasks:
cmds:
- cd frontend/app/ && pnpm run lint:fix && pnpm run prettier:check
- cd frontend/docs/ && pnpm run lint:fix && pnpm run prettier:check
- cd sdks/python/ && poetry install --all-extras
- cd sdks/python/ && bash ./lint.sh
- pre-commit run --all-files || pre-commit run --all-files
docs:
cmds:
@@ -12,7 +12,10 @@
from hatchet_sdk.config import ClientConfig
from hatchet_sdk.connection import new_conn
from hatchet_sdk.contracts.v1.dispatcher_pb2 import (
CreateDurableEventLogRequest,
DurableEvent,
GetDurableEventLogRequest,
GetDurableEventLogResponse,
ListenForDurableEventRequest,
)
from hatchet_sdk.contracts.v1.dispatcher_pb2 import (
@@ -121,6 +124,43 @@ def register_durable_event(

return True

def get_durable_event_log(
self, external_id: str, key: str
) -> GetDurableEventLogResponse:
conn = new_conn(self.config, True)
client = V1DispatcherStub(conn)

get_durable_event_log = tenacity_retry(
client.GetDurableEventLog, self.config.tenacity
)

return get_durable_event_log(
GetDurableEventLogRequest(
external_id=external_id,
key=key,
),
timeout=5,
metadata=get_metadata(self.token),
)

def create_durable_event_log(self, external_id: str, key: str, data: bytes) -> None:
conn = new_conn(self.config, True)
client = V1DispatcherStub(conn)

create_durable_event_log = tenacity_retry(
client.CreateDurableEventLog, self.config.tenacity
)

create_durable_event_log(
CreateDurableEventLogRequest(
external_id=external_id,
key=key,
data=data,
),
timeout=5,
metadata=get_metadata(self.token),
)

async def result(self, task_id: str, signal_key: str) -> dict[str, Any]:
key = self._generate_key(task_id, signal_key)

69 changes: 68 additions & 1 deletion sdks/python/hatchet_sdk/context/context.py
@@ -1,7 +1,9 @@
import asyncio
import hashlib
import json
from collections.abc import Awaitable, Callable
from datetime import timedelta
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, TypeVar, cast
from warnings import warn

from hatchet_sdk.clients.admin import AdminClient
@@ -28,11 +30,20 @@
from hatchet_sdk.utils.typing import JSONSerializableMapping, LogLevel
from hatchet_sdk.worker.runner.utils.capture_logs import AsyncLogSender, LogRecord

TMemo = TypeVar("TMemo")

if TYPE_CHECKING:
from hatchet_sdk.runnables.task import Task
from hatchet_sdk.runnables.types import R, TWorkflowInput


def _compute_memo_key(step_name: str, deps: list[Any]) -> str:
h = hashlib.sha256()
h.update(step_name.encode())
h.update(json.dumps(deps, default=str).encode())
return h.hexdigest()
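
Editorial note (not part of the diff): _compute_memo_key makes the memo key a pure function of the step name and the JSON-encoded deps, so a replay with identical deps resolves to the same durable event log entry, and default=str keeps values json can't encode natively (datetimes, UUIDs, and so on) from raising. One caveat worth flagging: json.dumps does not sort dict keys by default, so two dicts with the same entries in different insertion order hash to different keys. A minimal sketch of the behavior, using the hypothetical helper name demo_memo_key to stand in for the function above:

import hashlib
import json
from datetime import datetime, timezone

def demo_memo_key(step_name: str, deps: list) -> str:
    # Mirrors _compute_memo_key: hash the step name plus the JSON-encoded deps.
    h = hashlib.sha256()
    h.update(step_name.encode())
    h.update(json.dumps(deps, default=str).encode())
    return h.hexdigest()

# Identical inputs always produce the same key.
assert demo_memo_key("fetch_user", [42, "eu"]) == demo_memo_key("fetch_user", [42, "eu"])
# default=str stringifies values json can't encode, so this does not raise.
demo_memo_key("fetch_user", [datetime(2024, 1, 1, tzinfo=timezone.utc)])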


class Context:
def __init__(
self,
@@ -504,3 +515,59 @@ async def aio_sleep_for(self, duration: Duration) -> dict[str, Any]:
f"sleep:{timedelta_to_expr(duration)}-{wait_index}",
SleepCondition(duration=duration),
)

def memo(self, fn: Callable[[], TMemo], deps: list[Any]) -> TMemo:
if self.durable_event_listener is None:
raise ValueError("Durable event listener is not available")

key = _compute_memo_key(self.action.action_id, deps)

resp = self.durable_event_listener.get_durable_event_log(
Contributor comment: this needs to be wrapped in asyncio.to_thread so it doesn't block

external_id=self.workflow_run_id,
key=key,
)

if resp.found:
return json.loads(resp.data) # type: ignore[no-any-return]
Contributor comment: it's a small thing, but I have a slight preference for this (we do it elsewhere) because it's clearer when reading the code what's going on:

return cast(TMemo, json.loads(resp.data))


result = fn()

data = json.dumps(result).encode()

self.durable_event_listener.create_durable_event_log(
external_id=self.workflow_run_id,
key=key,
data=data,
)
Contributor comment on lines +551 to +557: and to_thread for these ones :)


return result
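
Editorial note (not part of the diff): a hypothetical usage sketch of the new memo API may help show the intended DX. The decorator, its keyword arguments, the input model, and fetch_user below are illustrative assumptions based on the SDK's usual typed durable-task setup, not code from this PR:

from pydantic import BaseModel
from hatchet_sdk import DurableContext, Hatchet

hatchet = Hatchet()

class UserInput(BaseModel):
    user_id: int

def fetch_user(user_id: int) -> dict:
    # Stand-in for a non-deterministic side effect, e.g. an external API call.
    return {"id": user_id, "name": "example"}

@hatchet.durable_task(name="sync-user", input_validator=UserInput)
def sync_user(input: UserInput, ctx: DurableContext) -> dict:
    # First run: the lambda executes and its JSON-serializable result is stored
    # under a key derived from the action id and deps=[input.user_id]. A replay
    # with the same deps returns the stored value instead of re-running it.
    user = ctx.memo(lambda: fetch_user(input.user_id), deps=[input.user_id])
    return {"user": user}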

async def aio_memo(
self, fn: Callable[[], Awaitable[TMemo]], deps: list[Any]
) -> TMemo:
if self.durable_event_listener is None:
raise ValueError("Durable event listener is not available")

key = _compute_memo_key(self.action.action_id, deps)

resp = await asyncio.to_thread(
self.durable_event_listener.get_durable_event_log,
external_id=self.workflow_run_id,
key=key,
)

if resp.found:
return await asyncio.to_thread(json.loads, resp.data) # type: ignore[no-any-return]

result = await fn()

data = json.dumps(result).encode()
Contributor comment: to_thread for the json.dumps too :P

One thing I'm curious about here too is what happens if the function e.g. returns a Pydantic model or a dataclass or something we can't call json.dumps on. I think we probably should support dataclasses, Pydantic, etc.; it makes the (de)serialization and typing trickier here, but it also makes the DX much better.


await asyncio.to_thread(
self.durable_event_listener.create_durable_event_log,
external_id=self.workflow_run_id,
key=key,
data=data,
)

return result
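
Editorial note (not part of the diff), picking up the reviewer's question above about return values json.dumps can't handle: one possible direction, sketched purely as an assumption and not as what this PR implements, is to special-case Pydantic models and dataclasses on the write path and rebuild them on the read path when the caller supplies the expected type. The helper names _dump_memo and _load_memo are hypothetical, and the sketch assumes Pydantic v2:

import dataclasses
import json
from typing import Any, TypeVar

from pydantic import BaseModel

T = TypeVar("T")

def _dump_memo(value: Any) -> bytes:
    # Write path: handle richer return types before falling back to plain
    # json.dumps, which is all the current memo/aio_memo code path supports.
    if isinstance(value, BaseModel):
        return value.model_dump_json().encode()
    if dataclasses.is_dataclass(value) and not isinstance(value, type):
        return json.dumps(dataclasses.asdict(value)).encode()
    return json.dumps(value).encode()

def _load_memo(data: bytes, expected: type[T] | None = None) -> Any:
    # Read path: rebuild the expected type when one is given, otherwise return
    # the decoded JSON as-is (matching the current behavior).
    raw = json.loads(data)
    if expected is not None and issubclass(expected, BaseModel):
        return expected.model_validate(raw)
    if expected is not None and dataclasses.is_dataclass(expected):
        return expected(**raw)
    return raw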
14 changes: 11 additions & 3 deletions sdks/python/hatchet_sdk/contracts/v1/dispatcher_pb2.py

Some generated files are not rendered by default.

32 changes: 31 additions & 1 deletion sdks/python/hatchet_sdk/contracts/v1/dispatcher_pb2.pyi
@@ -1,4 +1,4 @@
from hatchet_sdk.contracts.v1.shared import condition_pb2 as _condition_pb2
from v1.shared import condition_pb2 as _condition_pb2
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from collections.abc import Mapping as _Mapping
@@ -37,3 +37,33 @@ class DurableEvent(_message.Message):
signal_key: str
data: bytes
def __init__(self, task_id: _Optional[str] = ..., signal_key: _Optional[str] = ..., data: _Optional[bytes] = ...) -> None: ...

class GetDurableEventLogRequest(_message.Message):
__slots__ = ("external_id", "key")
EXTERNAL_ID_FIELD_NUMBER: _ClassVar[int]
KEY_FIELD_NUMBER: _ClassVar[int]
external_id: str
key: str
def __init__(self, external_id: _Optional[str] = ..., key: _Optional[str] = ...) -> None: ...

class GetDurableEventLogResponse(_message.Message):
__slots__ = ("found", "data")
FOUND_FIELD_NUMBER: _ClassVar[int]
DATA_FIELD_NUMBER: _ClassVar[int]
found: bool
data: bytes
def __init__(self, found: bool = ..., data: _Optional[bytes] = ...) -> None: ...

class CreateDurableEventLogRequest(_message.Message):
__slots__ = ("external_id", "key", "data")
EXTERNAL_ID_FIELD_NUMBER: _ClassVar[int]
KEY_FIELD_NUMBER: _ClassVar[int]
DATA_FIELD_NUMBER: _ClassVar[int]
external_id: str
key: str
data: bytes
def __init__(self, external_id: _Optional[str] = ..., key: _Optional[str] = ..., data: _Optional[bytes] = ...) -> None: ...

class CreateDurableEventLogResponse(_message.Message):
__slots__ = ()
def __init__(self) -> None: ...
88 changes: 87 additions & 1 deletion sdks/python/hatchet_sdk/contracts/v1/dispatcher_pb2_grpc.py
@@ -28,7 +28,7 @@
class V1DispatcherStub(object):
"""Missing associated documentation comment in .proto file."""

def __init__(self, channel: grpc.Channel | grpc.aio.Channel) -> None:
def __init__(self, channel):
"""Constructor.

Args:
@@ -44,6 +44,16 @@ def __init__(self, channel: grpc.Channel | grpc.aio.Channel) -> None:
request_serializer=v1_dot_dispatcher__pb2.ListenForDurableEventRequest.SerializeToString,
response_deserializer=v1_dot_dispatcher__pb2.DurableEvent.FromString,
_registered_method=True)
self.GetDurableEventLog = channel.unary_unary(
'/v1.V1Dispatcher/GetDurableEventLog',
request_serializer=v1_dot_dispatcher__pb2.GetDurableEventLogRequest.SerializeToString,
response_deserializer=v1_dot_dispatcher__pb2.GetDurableEventLogResponse.FromString,
_registered_method=True)
self.CreateDurableEventLog = channel.unary_unary(
'/v1.V1Dispatcher/CreateDurableEventLog',
request_serializer=v1_dot_dispatcher__pb2.CreateDurableEventLogRequest.SerializeToString,
response_deserializer=v1_dot_dispatcher__pb2.CreateDurableEventLogResponse.FromString,
_registered_method=True)


class V1DispatcherServicer(object):
Expand All @@ -61,6 +71,18 @@ def ListenForDurableEvent(self, request_iterator, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def GetDurableEventLog(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def CreateDurableEventLog(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_V1DispatcherServicer_to_server(servicer, server):
rpc_method_handlers = {
Expand All @@ -74,6 +96,16 @@ def add_V1DispatcherServicer_to_server(servicer, server):
request_deserializer=v1_dot_dispatcher__pb2.ListenForDurableEventRequest.FromString,
response_serializer=v1_dot_dispatcher__pb2.DurableEvent.SerializeToString,
),
'GetDurableEventLog': grpc.unary_unary_rpc_method_handler(
servicer.GetDurableEventLog,
request_deserializer=v1_dot_dispatcher__pb2.GetDurableEventLogRequest.FromString,
response_serializer=v1_dot_dispatcher__pb2.GetDurableEventLogResponse.SerializeToString,
),
'CreateDurableEventLog': grpc.unary_unary_rpc_method_handler(
servicer.CreateDurableEventLog,
request_deserializer=v1_dot_dispatcher__pb2.CreateDurableEventLogRequest.FromString,
response_serializer=v1_dot_dispatcher__pb2.CreateDurableEventLogResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'v1.V1Dispatcher', rpc_method_handlers)
@@ -138,3 +170,57 @@ def ListenForDurableEvent(request_iterator,
timeout,
metadata,
_registered_method=True)

@staticmethod
def GetDurableEventLog(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/v1.V1Dispatcher/GetDurableEventLog',
v1_dot_dispatcher__pb2.GetDurableEventLogRequest.SerializeToString,
v1_dot_dispatcher__pb2.GetDurableEventLogResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

@staticmethod
def CreateDurableEventLog(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/v1.V1Dispatcher/CreateDurableEventLog',
v1_dot_dispatcher__pb2.CreateDurableEventLogRequest.SerializeToString,
v1_dot_dispatcher__pb2.CreateDurableEventLogResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@@ -3,7 +3,12 @@ import { Channel, ClientFactory } from 'nice-grpc';

import { ClientConfig } from '@clients/hatchet-client/client-config';
import { Logger } from '@hatchet/util/logger';
import { V1DispatcherClient, V1DispatcherDefinition } from '@hatchet/protoc/v1/dispatcher';
import {
GetDurableEventLogResponse,
CreateDurableEventLogResponse,
V1DispatcherClient,
V1DispatcherDefinition,
} from '@hatchet/protoc/v1/dispatcher';
import { SleepMatchCondition, UserEventMatchCondition } from '@hatchet/protoc/v1/shared/condition';
import { Api } from '../../rest';
import { DurableEventGrpcPooledListener } from './pooled-durable-listener-client';
@@ -47,4 +52,19 @@ export class DurableListenerClient {

return this.pooledListener.registerDurableEvent(request);
}

async getDurableEventLog(request: {
externalId: string;
key: string;
}): Promise<GetDurableEventLogResponse> {
return this.client.getDurableEventLog(request);
}

async createDurableEventLog(request: {
externalId: string;
key: string;
data: Uint8Array;
}): Promise<CreateDurableEventLogResponse> {
return this.client.createDurableEventLog(request);
}
}