Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions st2common/st2common/models/db/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from st2common.util import date as date_utils


__all__ = ["WorkflowExecutionDB", "TaskExecutionDB"]
__all__ = ["WorkflowExecutionDB", "TaskExecutionDB", "TaskItemStateDB"]


LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -85,4 +85,31 @@ class TaskExecutionDB(stormbase.StormFoundationDB, stormbase.ChangeRevisionField
}


MODELS = [WorkflowExecutionDB, TaskExecutionDB]
class TaskItemStateDB(stormbase.StormFoundationDB, stormbase.ChangeRevisionFieldMixin):
"""
Model for storing individual item states for tasks with items (itemized tasks).
This allows efficient storage and retrieval of individual item states without
serializing/deserializing the entire task context for each item.
"""

RESOURCE_TYPE = types.ResourceType.EXECUTION

task_execution = me.StringField(required=True)
item_id = me.IntField(required=True)
status = me.StringField(required=True)
result = JSONDictEscapedFieldCompatibilityField()
context = JSONDictEscapedFieldCompatibilityField()
start_timestamp = db_field_types.ComplexDateTimeField(
default=date_utils.get_datetime_utc_now
)
end_timestamp = db_field_types.ComplexDateTimeField()

meta = {
"indexes": [
{"fields": ["task_execution"]},
{"fields": ["task_execution", "item_id"], "unique": True},
]
}


MODELS = [WorkflowExecutionDB, TaskExecutionDB, TaskItemStateDB]
42 changes: 41 additions & 1 deletion st2common/st2common/persistence/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from st2common.persistence import base as persistence


__all__ = ["WorkflowExecution", "TaskExecution"]
__all__ = ["WorkflowExecution", "TaskExecution", "TaskItemState"]


class WorkflowExecution(persistence.StatusBasedResource):
Expand Down Expand Up @@ -55,3 +55,43 @@ def _get_impl(cls):
@classmethod
def delete_by_query(cls, *args, **query):
return cls._get_impl().delete_by_query(*args, **query)


class TaskItemState(persistence.StatusBasedResource):
impl = db.ChangeRevisionMongoDBAccess(wf_db_models.TaskItemStateDB)
publisher = None

@classmethod
def _get_impl(cls):
return cls.impl

@classmethod
def get_by_task_and_item(cls, task_execution_id, item_id):
"""
Retrieve the state record for a specific item in a task execution.

Args:
task_execution_id: ID of the task execution
item_id: ID of the specific item

Returns:
TaskItemStateDB: The state record for the specified item
"""
return cls._get_impl().get(task_execution=task_execution_id, item_id=item_id)

@classmethod
def query_by_task_execution(cls, task_execution_id):
"""
Retrieve all item state records for a task execution.

Args:
task_execution_id: ID of the task execution

Returns:
list: List of TaskItemStateDB objects for all items in the task
"""
return cls.query(task_execution=task_execution_id)

@classmethod
def delete_by_query(cls, *args, **query):
return cls._get_impl().delete_by_query(*args, **query)
61 changes: 49 additions & 12 deletions st2common/st2common/services/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,16 +611,28 @@ def request_task_execution(wf_ex_db, st2_ctx, task_ex_req):
status=statuses.REQUESTED,
)

# Prepare the result format for itemized task execution.
if task_ex_db.itemized:
task_ex_db.result = {"items": [None] * task_ex_db.items_count}

# Insert new record into the database.
task_ex_db = wf_db_access.TaskExecution.insert(task_ex_db, publish=False)
task_ex_id = str(task_ex_db.id)
msg = 'Task execution "%s" created for task "%s", route "%s".'
update_progress(wf_ex_db, msg % (task_ex_id, task_id, str(task_route)))

# Prepare state storage for itemized task execution.
if task_ex_db.itemized and task_ex_db.items_count > 0:
# Create a minimal result structure in task_ex_db
task_ex_db.result = {"items_count": task_ex_db.items_count}
wf_db_access.TaskExecution.update(task_ex_db, publish=False)

# Create separate state records for each item
for i in range(task_ex_db.items_count):
item_state_db = wf_db_models.TaskItemStateDB(
task_execution=str(task_ex_db.id),
item_id=i,
status=statuses.REQUESTED,
context={}, # Will be populated when processing this specific item
)
wf_db_access.TaskItemState.insert(item_state_db, publish=False)

try:
# Return here if no action is specified in task spec.
if task_spec.action is None:
Expand Down Expand Up @@ -723,6 +735,12 @@ def request_action_execution(wf_ex_db, task_ex_db, st2_ctx, ac_ex_req, delay=Non
msg = "Unable to request action execution. Identifier for the item is not provided."
raise Exception(msg)

# For itemized tasks, fetch item context from the item state
if task_ex_db.itemized and item_id is not None:
item_state_db = wf_db_access.TaskItemState.get_by_task_and_item(
str(task_ex_db.id), item_id
)

# Identify the action to execute.
action_db = action_utils.get_action_by_ref(ref=action_ref)

Expand Down Expand Up @@ -759,6 +777,13 @@ def request_action_execution(wf_ex_db, task_ex_db, st2_ctx, ac_ex_req, delay=Non
if item_id is not None:
ac_ex_ctx["orquesta"]["item_id"] = item_id

# Update the item state context
item_state_db = wf_db_access.TaskItemState.get_by_task_and_item(
str(task_ex_db.id), item_id
)
item_state_db.context = ac_ex_ctx
wf_db_access.TaskItemState.update(item_state_db, publish=False)

# Render action execution parameters and setup action execution object.
ac_ex_params = param_utils.render_live_params(
runner_type_db.runner_parameters or {},
Expand Down Expand Up @@ -1256,27 +1281,39 @@ def update_task_execution(task_ex_id, ac_ex_status, ac_ex_result=None, ac_ex_ctx
msg = msg % (task_ex_db.task_id, str(task_ex_db.task_route), item_id)
update_progress(wf_ex_db, msg, severity="debug")

task_ex_db.result["items"][item_id] = {
"status": ac_ex_status,
"result": ac_ex_result,
}
# Update the specific item state
item_state_db = wf_db_access.TaskItemState.get_by_task_and_item(
task_ex_id, item_id
)
item_state_db.status = ac_ex_status
item_state_db.result = ac_ex_result
wf_db_access.TaskItemState.update(item_state_db, publish=False)

item_statuses = [
item.get("status", statuses.UNSET) if item else statuses.UNSET
for item in task_ex_db.result["items"]
]
# Check if all items are complete
item_state_dbs = wf_db_access.TaskItemState.query_by_task_execution(task_ex_id)
item_statuses = [item_state_db.status for item_state_db in item_state_dbs]

task_completed = all(
[status in statuses.COMPLETED_STATUSES for status in item_statuses]
)

if task_completed:
# If all items are complete, update the task status
new_task_status = (
statuses.SUCCEEDED
if all([status == statuses.SUCCEEDED for status in item_statuses])
else statuses.FAILED
)

# Also collect all item results for the main task result
results = []
for item_state_db in item_state_dbs:
results.append(
{"status": item_state_db.status, "result": item_state_db.result}
)

task_ex_db.result = {"items": results}

msg = 'Updating task execution from status "%s" to "%s".'
update_progress(
wf_ex_db, msg % (task_ex_db.status, new_task_status), severity="debug"
Expand Down
Loading