|
1 | 1 | import asyncio |
2 | 2 | import json |
3 | 3 | import time |
4 | | -from dataclasses import field |
5 | | -from enum import Enum |
6 | | -from typing import TYPE_CHECKING, Any, AsyncGenerator, cast |
| 4 | +from typing import TYPE_CHECKING, AsyncGenerator, cast |
7 | 5 |
|
8 | 6 | import grpc |
9 | 7 | import grpc.aio |
10 | | -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator |
| 8 | +from pydantic import BaseModel, ConfigDict, Field, model_validator |
11 | 9 |
|
12 | 10 | from hatchet_sdk.clients.event_ts import ( |
13 | 11 | ThreadSafeEvent, |
|
30 | 28 | from hatchet_sdk.contracts.dispatcher_pb2_grpc import DispatcherStub |
31 | 29 | from hatchet_sdk.logger import logger |
32 | 30 | from hatchet_sdk.metadata import get_metadata |
| 31 | +from hatchet_sdk.runnables.action import Action, ActionPayload, ActionType |
33 | 32 | from hatchet_sdk.utils.backoff import exp_backoff_sleep |
34 | | -from hatchet_sdk.utils.opentelemetry import OTelAttribute |
35 | 33 | from hatchet_sdk.utils.proto_enums import convert_proto_enum_to_python |
36 | 34 | from hatchet_sdk.utils.typing import JSONSerializableMapping |
37 | 35 |
|
@@ -67,120 +65,6 @@ def validate_labels(self) -> "GetActionListenerRequest": |
67 | 65 | return self |
68 | 66 |
|
69 | 67 |
|
70 | | -class ActionPayload(BaseModel): |
71 | | - model_config = ConfigDict(extra="allow") |
72 | | - |
73 | | - input: JSONSerializableMapping = Field(default_factory=dict) |
74 | | - parents: dict[str, JSONSerializableMapping] = Field(default_factory=dict) |
75 | | - overrides: JSONSerializableMapping = Field(default_factory=dict) |
76 | | - user_data: JSONSerializableMapping = Field(default_factory=dict) |
77 | | - step_run_errors: dict[str, str] = Field(default_factory=dict) |
78 | | - triggered_by: str | None = None |
79 | | - triggers: JSONSerializableMapping = Field(default_factory=dict) |
80 | | - filter_payload: JSONSerializableMapping = Field(default_factory=dict) |
81 | | - |
82 | | - @field_validator( |
83 | | - "input", |
84 | | - "parents", |
85 | | - "overrides", |
86 | | - "user_data", |
87 | | - "step_run_errors", |
88 | | - "filter_payload", |
89 | | - mode="before", |
90 | | - ) |
91 | | - @classmethod |
92 | | - def validate_fields(cls, v: Any) -> Any: |
93 | | - return v or {} |
94 | | - |
95 | | - @model_validator(mode="after") |
96 | | - def validate_filter_payload(self) -> "ActionPayload": |
97 | | - self.filter_payload = self.triggers.get("filter_payload", {}) |
98 | | - |
99 | | - return self |
100 | | - |
101 | | - |
102 | | -class ActionType(str, Enum): |
103 | | - START_STEP_RUN = "START_STEP_RUN" |
104 | | - CANCEL_STEP_RUN = "CANCEL_STEP_RUN" |
105 | | - START_GET_GROUP_KEY = "START_GET_GROUP_KEY" |
106 | | - |
107 | | - |
108 | | -ActionKey = str |
109 | | - |
110 | | - |
111 | | -class Action(BaseModel): |
112 | | - worker_id: str |
113 | | - tenant_id: str |
114 | | - workflow_run_id: str |
115 | | - workflow_id: str | None = None |
116 | | - workflow_version_id: str | None = None |
117 | | - get_group_key_run_id: str |
118 | | - job_id: str |
119 | | - job_name: str |
120 | | - job_run_id: str |
121 | | - step_id: str |
122 | | - step_run_id: str |
123 | | - action_id: str |
124 | | - action_type: ActionType |
125 | | - retry_count: int |
126 | | - action_payload: ActionPayload |
127 | | - additional_metadata: JSONSerializableMapping = field(default_factory=dict) |
128 | | - |
129 | | - child_workflow_index: int | None = None |
130 | | - child_workflow_key: str | None = None |
131 | | - parent_workflow_run_id: str | None = None |
132 | | - |
133 | | - priority: int | None = None |
134 | | - |
135 | | - def _dump_payload_to_str(self) -> str: |
136 | | - try: |
137 | | - return json.dumps(self.action_payload.model_dump(), default=str) |
138 | | - except Exception: |
139 | | - return str(self.action_payload) |
140 | | - |
141 | | - def get_otel_attributes(self, config: "ClientConfig") -> dict[str, str | int]: |
142 | | - try: |
143 | | - payload_str = json.dumps(self.action_payload.model_dump(), default=str) |
144 | | - except Exception: |
145 | | - payload_str = str(self.action_payload) |
146 | | - |
147 | | - attrs: dict[OTelAttribute, str | int | None] = { |
148 | | - OTelAttribute.TENANT_ID: self.tenant_id, |
149 | | - OTelAttribute.WORKER_ID: self.worker_id, |
150 | | - OTelAttribute.WORKFLOW_RUN_ID: self.workflow_run_id, |
151 | | - OTelAttribute.STEP_ID: self.step_id, |
152 | | - OTelAttribute.STEP_RUN_ID: self.step_run_id, |
153 | | - OTelAttribute.RETRY_COUNT: self.retry_count, |
154 | | - OTelAttribute.PARENT_WORKFLOW_RUN_ID: self.parent_workflow_run_id, |
155 | | - OTelAttribute.CHILD_WORKFLOW_INDEX: self.child_workflow_index, |
156 | | - OTelAttribute.CHILD_WORKFLOW_KEY: self.child_workflow_key, |
157 | | - OTelAttribute.ACTION_PAYLOAD: payload_str, |
158 | | - OTelAttribute.WORKFLOW_NAME: self.job_name, |
159 | | - OTelAttribute.ACTION_NAME: self.action_id, |
160 | | - OTelAttribute.GET_GROUP_KEY_RUN_ID: self.get_group_key_run_id, |
161 | | - OTelAttribute.WORKFLOW_ID: self.workflow_id, |
162 | | - OTelAttribute.WORKFLOW_VERSION_ID: self.workflow_version_id, |
163 | | - } |
164 | | - |
165 | | - return { |
166 | | - f"hatchet.{k.value}": v |
167 | | - for k, v in attrs.items() |
168 | | - if v and k not in config.otel.excluded_attributes |
169 | | - } |
170 | | - |
171 | | - @property |
172 | | - def key(self) -> ActionKey: |
173 | | - """ |
174 | | - This key is used to uniquely identify a single step run by its id + retry count. |
175 | | - It's used when storing references to a task, a context, etc. in a dictionary so that |
176 | | - we can look up those items in the dictionary by a unique key. |
177 | | - """ |
178 | | - if self.action_type == ActionType.START_GET_GROUP_KEY: |
179 | | - return f"{self.get_group_key_run_id}/{self.retry_count}" |
180 | | - else: |
181 | | - return f"{self.step_run_id}/{self.retry_count}" |
182 | | - |
183 | | - |
184 | 68 | def parse_additional_metadata(additional_metadata: str) -> JSONSerializableMapping: |
185 | 69 | try: |
186 | 70 | return cast( |
|
0 commit comments