Skip to content

Commit 7c94d6b

Browse files
committed
IWF-357: Add internal channel TypeStore
1 parent b810adb commit 7c94d6b

File tree

4 files changed

+84
-24
lines changed

4 files changed

+84
-24
lines changed

iwf/command_results.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
)
1111
from iwf.iwf_api.types import Unset
1212
from iwf.object_encoder import ObjectEncoder
13+
from iwf.type_store import TypeStore
1314

1415

1516
@dataclass
@@ -43,7 +44,7 @@ class CommandResults:
4344

4445
def from_idl_command_results(
4546
idl_results: Union[Unset, IdlCommandResults],
46-
internal_channel_types: dict[str, typing.Optional[type]],
47+
internal_channel_types: TypeStore,
4748
signal_channel_types: dict[str, typing.Optional[type]],
4849
object_encoder: ObjectEncoder,
4950
) -> CommandResults:
@@ -58,14 +59,8 @@ def from_idl_command_results(
5859

5960
if not isinstance(idl_results.inter_state_channel_results, Unset):
6061
for inter in idl_results.inter_state_channel_results:
61-
val_type = internal_channel_types.get(inter.channel_name)
62-
if val_type is None:
63-
# fallback to assume it's prefix
64-
# TODO use is_prefix to implement like Java SDK
65-
for name, t in internal_channel_types.items():
66-
if inter.channel_name.startswith(name):
67-
val_type = t
68-
break
62+
val_type = internal_channel_types.get_type(inter.channel_name)
63+
6964
if val_type is None:
7065
raise WorkflowDefinitionError(
7166
"internal channel is not registered: " + inter.channel_name

iwf/communication.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99
)
1010
from iwf.object_encoder import ObjectEncoder
1111
from iwf.state_movement import StateMovement
12+
from iwf.type_store import TypeStore
1213

1314

1415
class Communication:
15-
_internal_channel_type_store: dict[str, Optional[type]]
16+
_internal_channel_type_store: TypeStore
1617
_signal_channel_type_store: dict[str, Optional[type]]
1718
_object_encoder: ObjectEncoder
1819
_to_publish_internal_channel: dict[str, list[EncodedObject]]
@@ -22,7 +23,7 @@ class Communication:
2223

2324
def __init__(
2425
self,
25-
internal_channel_type_store: dict[str, Optional[type]],
26+
internal_channel_type_store: TypeStore,
2627
signal_channel_type_store: dict[str, Optional[type]],
2728
object_encoder: ObjectEncoder,
2829
internal_channel_infos: Optional[WorkflowWorkerRpcRequestInternalChannelInfos],
@@ -47,12 +48,7 @@ def trigger_state_execution(self, state: Union[str, type], state_input: Any = No
4748
self._state_movements.append(movement)
4849

4950
def publish_to_internal_channel(self, channel_name: str, value: Any = None):
50-
registered_type = self._internal_channel_type_store.get(channel_name)
51-
52-
if registered_type is None:
53-
for name, t in self._internal_channel_type_store.items():
54-
if channel_name.startswith(name):
55-
registered_type = t
51+
registered_type = self._internal_channel_type_store.get_type(channel_name)
5652

5753
if registered_type is None:
5854
raise WorkflowDefinitionError(

iwf/registry.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from iwf.errors import InvalidArgumentError, WorkflowDefinitionError
55
from iwf.persistence_schema import PersistenceFieldType
66
from iwf.rpc import RPCInfo
7+
from iwf.type_store import TypeStore, Type
78
from iwf.workflow import ObjectWorkflow, get_workflow_type
89
from iwf.workflow_state import WorkflowState, get_state_id
910

@@ -12,7 +13,7 @@ class Registry:
1213
_workflow_store: dict[str, ObjectWorkflow]
1314
_starting_state_store: dict[str, WorkflowState]
1415
_state_store: dict[str, dict[str, WorkflowState]]
15-
_internal_channel_type_store: dict[str, dict[str, Optional[type]]]
16+
_internal_channel_type_store: dict[str, TypeStore]
1617
_signal_channel_type_store: dict[str, dict[str, Optional[type]]]
1718
_data_attribute_types: dict[str, dict[str, Optional[type]]]
1819
_rpc_infos: dict[str, dict[str, RPCInfo]]
@@ -63,7 +64,7 @@ def get_workflow_state_with_check(
6364
def get_state_store(self, wf_type: str) -> dict[str, WorkflowState]:
6465
return self._state_store[wf_type]
6566

66-
def get_internal_channel_types(self, wf_type: str) -> dict[str, Optional[type]]:
67+
def get_internal_channel_type_store(self, wf_type: str) -> TypeStore:
6768
return self._internal_channel_type_store[wf_type]
6869

6970
def get_signal_channel_types(self, wf_type: str) -> dict[str, Optional[type]]:
@@ -83,13 +84,13 @@ def _register_workflow_type(self, wf: ObjectWorkflow):
8384

8485
def _register_internal_channels(self, wf: ObjectWorkflow):
8586
wf_type = get_workflow_type(wf)
86-
types: dict[str, Optional[type]] = {}
87+
88+
if self._internal_channel_type_store[wf_type] is None:
89+
self._internal_channel_type_store[wf_type] = TypeStore(Type.INTERNAL_CHANNEL)
90+
8791
for method in wf.get_communication_schema().communication_methods:
8892
if method.method_type == CommunicationMethodType.InternalChannel:
89-
types[method.name] = method.value_type
90-
# TODO use is_prefix to implement like Java SDK
91-
#
92-
self._internal_channel_type_store[wf_type] = types
93+
self._internal_channel_type_store[wf_type].add_internal_channel_def(method)
9394

9495
def _register_signal_channels(self, wf: ObjectWorkflow):
9596
wf_type = get_workflow_type(wf)

iwf/type_store.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from typing import Optional
2+
from enum import Enum
3+
4+
from iwf.communication_schema import CommunicationMethod
5+
from iwf.errors import WorkflowDefinitionError
6+
7+
8+
class Type(Enum):
9+
INTERNAL_CHANNEL = 1
10+
# TODO: extend to other types
11+
# DATA_ATTRIBUTE = 2
12+
# SIGNAL_CHANNEL = 3
13+
14+
class TypeStore:
15+
_class_type: Type
16+
_name_to_type_store: dict[str, Optional[type]]
17+
_prefix_to_type_store: dict[str, Optional[type]]
18+
19+
def __init__(self, class_type: Type):
20+
self._class_type = class_type
21+
self._name_to_type_store = dict()
22+
self._prefix_to_type_store = dict()
23+
24+
def is_valid_name_or_prefix(self, name: str) -> bool:
25+
t = self._do_get_type(name)
26+
return t is not None
27+
28+
def get_type(self, name: str) -> Optional[type]:
29+
t = self._do_get_type(name)
30+
31+
if t is None:
32+
raise ValueError(f"{self._class_type} not registered: {name}")
33+
34+
return type
35+
36+
def add_internal_channel_def(self, obj: CommunicationMethod):
37+
if self._class_type != Type.INTERNAL_CHANNEL:
38+
raise WorkflowDefinitionError(
39+
f"Cannot add internal channel definition to {self._class_type}"
40+
)
41+
self._do_add_to_store(obj.is_prefix, obj.name, obj.value_type)
42+
43+
44+
def _do_get_type(self, name: str) -> Optional[type]:
45+
if name in self._name_to_type_store:
46+
return self._name_to_type_store[name]
47+
48+
prefixes = self._prefix_to_type_store.keys()
49+
50+
first = next((prefix for prefix in prefixes if prefix.startswith(name)), None)
51+
52+
if first is None:
53+
return None
54+
55+
return self._prefix_to_type_store.get(first, None)
56+
57+
def _do_add_to_store(self, is_prefix: bool, name: str, t: Optional[type]):
58+
if is_prefix:
59+
store = self._prefix_to_type_store
60+
else:
61+
store = self._name_to_type_store
62+
63+
if name in store:
64+
raise WorkflowDefinitionError(
65+
f"{self._class_type} name/prefix {name} already exists")
66+
67+
store[name] = t
68+

0 commit comments

Comments
 (0)