Skip to content

Commit

Permalink
IWF-357: Add internal channel TypeStore
Browse files Browse the repository at this point in the history
  • Loading branch information
lwolczynski committed Jan 3, 2025
1 parent b810adb commit 7c94d6b
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 24 deletions.
13 changes: 4 additions & 9 deletions iwf/command_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)
from iwf.iwf_api.types import Unset
from iwf.object_encoder import ObjectEncoder
from iwf.type_store import TypeStore


@dataclass
Expand Down Expand Up @@ -43,7 +44,7 @@ class CommandResults:

def from_idl_command_results(
idl_results: Union[Unset, IdlCommandResults],
internal_channel_types: dict[str, typing.Optional[type]],
internal_channel_types: TypeStore,
signal_channel_types: dict[str, typing.Optional[type]],
object_encoder: ObjectEncoder,
) -> CommandResults:
Expand All @@ -58,14 +59,8 @@ def from_idl_command_results(

if not isinstance(idl_results.inter_state_channel_results, Unset):
for inter in idl_results.inter_state_channel_results:
val_type = internal_channel_types.get(inter.channel_name)
if val_type is None:
# fallback to assume it's prefix
# TODO use is_prefix to implement like Java SDK
for name, t in internal_channel_types.items():
if inter.channel_name.startswith(name):
val_type = t
break
val_type = internal_channel_types.get_type(inter.channel_name)

if val_type is None:
raise WorkflowDefinitionError(
"internal channel is not registered: " + inter.channel_name
Expand Down
12 changes: 4 additions & 8 deletions iwf/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
)
from iwf.object_encoder import ObjectEncoder
from iwf.state_movement import StateMovement
from iwf.type_store import TypeStore


class Communication:
_internal_channel_type_store: dict[str, Optional[type]]
_internal_channel_type_store: TypeStore
_signal_channel_type_store: dict[str, Optional[type]]
_object_encoder: ObjectEncoder
_to_publish_internal_channel: dict[str, list[EncodedObject]]
Expand All @@ -22,7 +23,7 @@ class Communication:

def __init__(
self,
internal_channel_type_store: dict[str, Optional[type]],
internal_channel_type_store: TypeStore,
signal_channel_type_store: dict[str, Optional[type]],
object_encoder: ObjectEncoder,
internal_channel_infos: Optional[WorkflowWorkerRpcRequestInternalChannelInfos],
Expand All @@ -47,12 +48,7 @@ def trigger_state_execution(self, state: Union[str, type], state_input: Any = No
self._state_movements.append(movement)

def publish_to_internal_channel(self, channel_name: str, value: Any = None):
registered_type = self._internal_channel_type_store.get(channel_name)

if registered_type is None:
for name, t in self._internal_channel_type_store.items():
if channel_name.startswith(name):
registered_type = t
registered_type = self._internal_channel_type_store.get_type(channel_name)

if registered_type is None:
raise WorkflowDefinitionError(
Expand Down
15 changes: 8 additions & 7 deletions iwf/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from iwf.errors import InvalidArgumentError, WorkflowDefinitionError
from iwf.persistence_schema import PersistenceFieldType
from iwf.rpc import RPCInfo
from iwf.type_store import TypeStore, Type
from iwf.workflow import ObjectWorkflow, get_workflow_type
from iwf.workflow_state import WorkflowState, get_state_id

Expand All @@ -12,7 +13,7 @@ class Registry:
_workflow_store: dict[str, ObjectWorkflow]
_starting_state_store: dict[str, WorkflowState]
_state_store: dict[str, dict[str, WorkflowState]]
_internal_channel_type_store: dict[str, dict[str, Optional[type]]]
_internal_channel_type_store: dict[str, TypeStore]
_signal_channel_type_store: dict[str, dict[str, Optional[type]]]
_data_attribute_types: dict[str, dict[str, Optional[type]]]
_rpc_infos: dict[str, dict[str, RPCInfo]]
Expand Down Expand Up @@ -63,7 +64,7 @@ def get_workflow_state_with_check(
def get_state_store(self, wf_type: str) -> dict[str, WorkflowState]:
return self._state_store[wf_type]

def get_internal_channel_types(self, wf_type: str) -> dict[str, Optional[type]]:
def get_internal_channel_type_store(self, wf_type: str) -> TypeStore:
return self._internal_channel_type_store[wf_type]

def get_signal_channel_types(self, wf_type: str) -> dict[str, Optional[type]]:
Expand All @@ -83,13 +84,13 @@ def _register_workflow_type(self, wf: ObjectWorkflow):

def _register_internal_channels(self, wf: ObjectWorkflow):
wf_type = get_workflow_type(wf)
types: dict[str, Optional[type]] = {}

if self._internal_channel_type_store[wf_type] is None:
self._internal_channel_type_store[wf_type] = TypeStore(Type.INTERNAL_CHANNEL)

for method in wf.get_communication_schema().communication_methods:
if method.method_type == CommunicationMethodType.InternalChannel:
types[method.name] = method.value_type
# TODO use is_prefix to implement like Java SDK
#
self._internal_channel_type_store[wf_type] = types
self._internal_channel_type_store[wf_type].add_internal_channel_def(method)

def _register_signal_channels(self, wf: ObjectWorkflow):
wf_type = get_workflow_type(wf)
Expand Down
68 changes: 68 additions & 0 deletions iwf/type_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Optional
from enum import Enum

from iwf.communication_schema import CommunicationMethod
from iwf.errors import WorkflowDefinitionError


class Type(Enum):
INTERNAL_CHANNEL = 1
# TODO: extend to other types
# DATA_ATTRIBUTE = 2
# SIGNAL_CHANNEL = 3

class TypeStore:
_class_type: Type
_name_to_type_store: dict[str, Optional[type]]
_prefix_to_type_store: dict[str, Optional[type]]

def __init__(self, class_type: Type):
self._class_type = class_type
self._name_to_type_store = dict()
self._prefix_to_type_store = dict()

def is_valid_name_or_prefix(self, name: str) -> bool:
t = self._do_get_type(name)
return t is not None

def get_type(self, name: str) -> Optional[type]:
t = self._do_get_type(name)

if t is None:
raise ValueError(f"{self._class_type} not registered: {name}")

return type

def add_internal_channel_def(self, obj: CommunicationMethod):
if self._class_type != Type.INTERNAL_CHANNEL:
raise WorkflowDefinitionError(
f"Cannot add internal channel definition to {self._class_type}"
)
self._do_add_to_store(obj.is_prefix, obj.name, obj.value_type)


def _do_get_type(self, name: str) -> Optional[type]:
if name in self._name_to_type_store:
return self._name_to_type_store[name]

prefixes = self._prefix_to_type_store.keys()

first = next((prefix for prefix in prefixes if prefix.startswith(name)), None)

if first is None:
return None

return self._prefix_to_type_store.get(first, None)

def _do_add_to_store(self, is_prefix: bool, name: str, t: Optional[type]):
if is_prefix:
store = self._prefix_to_type_store
else:
store = self._name_to_type_store

if name in store:
raise WorkflowDefinitionError(
f"{self._class_type} name/prefix {name} already exists")

store[name] = t

0 comments on commit 7c94d6b

Please sign in to comment.