Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IWF-357: Add internal channel TypeStore #70

Merged
merged 15 commits into from
Jan 7, 2025
20 changes: 8 additions & 12 deletions iwf/command_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
from dataclasses import dataclass
from typing import Any, Union

from iwf.errors import WorkflowDefinitionError
from iwf.errors import WorkflowDefinitionError, NotRegisteredError
from iwf.iwf_api.models import (
ChannelRequestStatus,
CommandResults as IdlCommandResults,
TimerStatus,
)
from iwf.iwf_api.types import Unset
from iwf.object_encoder import ObjectEncoder
from iwf.type_store import TypeStore


@dataclass
Expand Down Expand Up @@ -43,7 +44,7 @@ class CommandResults:

def from_idl_command_results(
idl_results: Union[Unset, IdlCommandResults],
internal_channel_types: dict[str, typing.Optional[type]],
internal_channel_types: TypeStore,
signal_channel_types: dict[str, typing.Optional[type]],
object_encoder: ObjectEncoder,
) -> CommandResults:
Expand All @@ -58,18 +59,13 @@ def from_idl_command_results(

if not isinstance(idl_results.inter_state_channel_results, Unset):
for inter in idl_results.inter_state_channel_results:
val_type = internal_channel_types.get(inter.channel_name)
if val_type is None:
# fallback to assume it's prefix
# TODO use is_prefix to implement like Java SDK
for name, t in internal_channel_types.items():
if inter.channel_name.startswith(name):
val_type = t
break
if val_type is None:

try:
val_type = internal_channel_types.get_type(inter.channel_name)
except NotRegisteredError as exception:
raise WorkflowDefinitionError(
"internal channel is not registered: " + inter.channel_name
)
) from exception

encoded = object_encoder.decode(inter.value, val_type)

Expand Down
31 changes: 12 additions & 19 deletions iwf/communication.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Optional, Union

from iwf.errors import WorkflowDefinitionError
from iwf.errors import WorkflowDefinitionError, NotRegisteredError
from iwf.iwf_api.models import (
EncodedObject,
InterStateChannelPublishing,
Expand All @@ -9,10 +9,11 @@
)
from iwf.object_encoder import ObjectEncoder
from iwf.state_movement import StateMovement
from iwf.type_store import TypeStore


class Communication:
_internal_channel_type_store: dict[str, Optional[type]]
_internal_channel_type_store: TypeStore
_signal_channel_type_store: dict[str, Optional[type]]
_object_encoder: ObjectEncoder
_to_publish_internal_channel: dict[str, list[EncodedObject]]
Expand All @@ -22,7 +23,7 @@ class Communication:

def __init__(
self,
internal_channel_type_store: dict[str, Optional[type]],
internal_channel_type_store: TypeStore,
signal_channel_type_store: dict[str, Optional[type]],
object_encoder: ObjectEncoder,
internal_channel_infos: Optional[WorkflowWorkerRpcRequestInternalChannelInfos],
Expand All @@ -47,17 +48,12 @@ def trigger_state_execution(self, state: Union[str, type], state_input: Any = No
self._state_movements.append(movement)

def publish_to_internal_channel(self, channel_name: str, value: Any = None):
registered_type = self._internal_channel_type_store.get(channel_name)

if registered_type is None:
for name, t in self._internal_channel_type_store.items():
if channel_name.startswith(name):
registered_type = t

if registered_type is None:
try:
registered_type = self._internal_channel_type_store.get_type(channel_name)
except NotRegisteredError as exception:
raise WorkflowDefinitionError(
f"InternalChannel channel_name is not defined {channel_name}"
)
) from exception

if (
value is not None
Expand All @@ -84,14 +80,11 @@ def get_to_trigger_state_movements(self) -> list[StateMovement]:
return self._state_movements

def get_internal_channel_size(self, channel_name):
registered_type = self._internal_channel_type_store.get(channel_name)

if registered_type is None:
for name, t in self._internal_channel_type_store.items():
if channel_name.startswith(name):
registered_type = t
is_type_registered = self._internal_channel_type_store.is_valid_name_or_prefix(
channel_name
)

if registered_type is None:
if is_type_registered is False:
raise WorkflowDefinitionError(
f"InternalChannel channel_name is not defined {channel_name}"
)
Expand Down
4 changes: 4 additions & 0 deletions iwf/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ class InvalidArgumentError(Exception):
pass


class NotRegisteredError(Exception):
pass


class HttpError(RuntimeError):
def __init__(self, status: int, err_resp: ErrorResponse):
super().__init__(err_resp.detail)
Expand Down
19 changes: 12 additions & 7 deletions iwf/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from iwf.errors import InvalidArgumentError, WorkflowDefinitionError
from iwf.persistence_schema import PersistenceFieldType
from iwf.rpc import RPCInfo
from iwf.type_store import TypeStore, Type
from iwf.workflow import ObjectWorkflow, get_workflow_type
from iwf.workflow_state import WorkflowState, get_state_id

Expand All @@ -12,7 +13,7 @@ class Registry:
_workflow_store: dict[str, ObjectWorkflow]
_starting_state_store: dict[str, WorkflowState]
_state_store: dict[str, dict[str, WorkflowState]]
_internal_channel_type_store: dict[str, dict[str, Optional[type]]]
_internal_channel_type_store: dict[str, TypeStore]
_signal_channel_type_store: dict[str, dict[str, Optional[type]]]
_data_attribute_types: dict[str, dict[str, Optional[type]]]
_rpc_infos: dict[str, dict[str, RPCInfo]]
Expand Down Expand Up @@ -63,7 +64,7 @@ def get_workflow_state_with_check(
def get_state_store(self, wf_type: str) -> dict[str, WorkflowState]:
return self._state_store[wf_type]

def get_internal_channel_types(self, wf_type: str) -> dict[str, Optional[type]]:
def get_internal_channel_type_store(self, wf_type: str) -> TypeStore:
return self._internal_channel_type_store[wf_type]

def get_signal_channel_types(self, wf_type: str) -> dict[str, Optional[type]]:
Expand All @@ -83,13 +84,17 @@ def _register_workflow_type(self, wf: ObjectWorkflow):

def _register_internal_channels(self, wf: ObjectWorkflow):
wf_type = get_workflow_type(wf)
types: dict[str, Optional[type]] = {}

if wf_type not in self._internal_channel_type_store:
self._internal_channel_type_store[wf_type] = TypeStore(
Type.INTERNAL_CHANNEL
)

for method in wf.get_communication_schema().communication_methods:
if method.method_type == CommunicationMethodType.InternalChannel:
types[method.name] = method.value_type
# TODO use is_prefix to implement like Java SDK
#
self._internal_channel_type_store[wf_type] = types
self._internal_channel_type_store[wf_type].add_internal_channel_def(
method
)

def _register_signal_channels(self, wf: ObjectWorkflow):
wf_type = get_workflow_type(wf)
Expand Down
10 changes: 6 additions & 4 deletions iwf/tests/test_internal_channel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
import time
import unittest

from iwf.client import Client
from iwf.command_request import CommandRequest, InternalChannelCommand
Expand Down Expand Up @@ -133,8 +134,9 @@ def get_communication_schema(self) -> CommunicationSchema:
client = Client(registry)


def test_internal_channel_workflow():
wf_id = f"{inspect.currentframe().f_code.co_name}-{time.time_ns()}"
class TestConditionalComplete(unittest.TestCase):
def test_internal_channel_workflow(self):
wf_id = f"{inspect.currentframe().f_code.co_name}-{time.time_ns()}"

client.start_workflow(InternalChannelWorkflow, wf_id, 100, None)
client.get_simple_workflow_result_with_wait(wf_id, None)
client.start_workflow(InternalChannelWorkflow, wf_id, 100, None)
client.get_simple_workflow_result_with_wait(wf_id, None)
123 changes: 123 additions & 0 deletions iwf/tests/test_internal_channel_with_no_prefix_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import inspect
import time
import unittest

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test added to make sure the currently existing issue is fixed. Issue description by @longquanzheng

The MVP solution works but not ideal – it just blend/mix the prefix and non-prefix channel names without differentiation. This could cause some confusion/unexpected behavior.

For example:

  • User can define a channel name “ABC” (not by prefix) and try to publish with name “ABCD” will also be allowed – but it should be disallowed. Because “ABC” is not by prefix.

from iwf.client import Client
from iwf.command_request import CommandRequest, InternalChannelCommand
from iwf.command_results import CommandResults
from iwf.communication import Communication
from iwf.communication_schema import CommunicationMethod, CommunicationSchema
from iwf.persistence import Persistence
from iwf.state_decision import StateDecision
from iwf.state_schema import StateSchema
from iwf.tests.worker_server import registry
from iwf.workflow import ObjectWorkflow
from iwf.workflow_context import WorkflowContext
from iwf.workflow_state import T, WorkflowState

internal_channel_name = "internal-channel-1"

test_non_prefix_channel_name = "test-channel-"
test_non_prefix_channel_name_with_suffix = test_non_prefix_channel_name + "abc"


class InitState(WorkflowState[None]):
def execute(
self,
ctx: WorkflowContext,
input: T,
command_results: CommandResults,
persistence: Persistence,
communication: Communication,
) -> StateDecision:
return StateDecision.multi_next_states(
WaitAnyWithPublishState, WaitAllThenPublishState
)


class WaitAnyWithPublishState(WorkflowState[None]):
def wait_until(
self,
ctx: WorkflowContext,
input: T,
persistence: Persistence,
communication: Communication,
) -> CommandRequest:
# Trying to publish to a non-existing channel; this would only work if test_channel_name_non_prefix was defined as a prefix channel
communication.publish_to_internal_channel(
test_non_prefix_channel_name_with_suffix, "str-value-for-prefix"
)
return CommandRequest.for_any_command_completed(
InternalChannelCommand.by_name(internal_channel_name),
)

def execute(
self,
ctx: WorkflowContext,
input: T,
command_results: CommandResults,
persistence: Persistence,
communication: Communication,
) -> StateDecision:
return StateDecision.graceful_complete_workflow()


class WaitAllThenPublishState(WorkflowState[None]):
def wait_until(
self,
ctx: WorkflowContext,
input: T,
persistence: Persistence,
communication: Communication,
) -> CommandRequest:
return CommandRequest.for_all_command_completed(
InternalChannelCommand.by_name(test_non_prefix_channel_name),
)

def execute(
self,
ctx: WorkflowContext,
input: T,
command_results: CommandResults,
persistence: Persistence,
communication: Communication,
) -> StateDecision:
communication.publish_to_internal_channel(internal_channel_name, None)
return StateDecision.dead_end


class InternalChannelWorkflowWithNoPrefixChannel(ObjectWorkflow):
def get_workflow_states(self) -> StateSchema:
return StateSchema.with_starting_state(
InitState(), WaitAnyWithPublishState(), WaitAllThenPublishState()
)

def get_communication_schema(self) -> CommunicationSchema:
return CommunicationSchema.create(
CommunicationMethod.internal_channel_def(internal_channel_name, type(None)),
# Defining a standard channel (non-prefix) to make sure messages to the channel with a suffix added will not be accepted
CommunicationMethod.internal_channel_def(test_non_prefix_channel_name, str),
)


wf = InternalChannelWorkflowWithNoPrefixChannel()
registry.add_workflow(wf)
client = Client(registry)


class TestInternalChannelWithNoPrefix(unittest.TestCase):
def test_internal_channel_workflow_with_no_prefix_channel(self):
wf_id = f"{inspect.currentframe().f_code.co_name}-{time.time_ns()}"

client.start_workflow(
InternalChannelWorkflowWithNoPrefixChannel, wf_id, 5, None
)

with self.assertRaises(Exception) as context:
client.wait_for_workflow_completion(wf_id, None)

self.assertIn("FAILED", context.exception.workflow_status)
self.assertIn(
f"WorkerExecutionError: InternalChannel channel_name is not defined {test_non_prefix_channel_name_with_suffix}",
context.exception.error_message,
)
68 changes: 68 additions & 0 deletions iwf/type_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Optional
from enum import Enum

from iwf.communication_schema import CommunicationMethod
from iwf.errors import WorkflowDefinitionError, NotRegisteredError


class Type(Enum):
INTERNAL_CHANNEL = 1
# TODO: extend to other types
# DATA_ATTRIBUTE = 2
# SIGNAL_CHANNEL = 3
Comment on lines +9 to +12
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

JavaSDK allows prefixing SignalChannels and DataAttributes. Leaving this for future use

https://github.com/indeedeng/iwf-java-sdk/pull/192/files#diff-c0b29a42e46cc3af3c19154c55b89cb3b33b9632e9b3b27bb918c600197b28d1R16

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#42



class TypeStore:
_class_type: Type
_name_to_type_store: dict[str, Optional[type]]
_prefix_to_type_store: dict[str, Optional[type]]

def __init__(self, class_type: Type):
self._class_type = class_type
self._name_to_type_store = dict()
self._prefix_to_type_store = dict()

def is_valid_name_or_prefix(self, name: str) -> bool:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this used anywhere?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed it to be used now. It was not used before, good catch

t = self._do_get_type(name)
return t is not None

def get_type(self, name: str) -> type:
t = self._do_get_type(name)

if t is None:
raise NotRegisteredError(f"{self._class_type} not registered: {name}")

return t

def add_internal_channel_def(self, obj: CommunicationMethod):
if self._class_type != Type.INTERNAL_CHANNEL:
raise ValueError(
f"Cannot add internal channel definition to {self._class_type}"
)
self._do_add_to_store(obj.is_prefix, obj.name, obj.value_type)

def _do_get_type(self, name: str) -> Optional[type]:
if name in self._name_to_type_store:
return self._name_to_type_store[name]

prefixes = self._prefix_to_type_store.keys()

first = next((prefix for prefix in prefixes if name.startswith(prefix)), None)

if first is None:
return None

return self._prefix_to_type_store.get(first, None)

def _do_add_to_store(self, is_prefix: bool, name: str, t: Optional[type]):
if is_prefix:
store = self._prefix_to_type_store
else:
store = self._name_to_type_store

if name in store:
raise WorkflowDefinitionError(
f"{self._class_type} name/prefix {name} already exists"
)

store[name] = t
Loading
Loading