Skip to content

Commit e4b1d36

Browse files
authored
Fix: Spawn index incorrect on retry (#1752)
* fix: child index on retry not being reset correctly * chore: ver * chore: changelog * feat: test
1 parent 80a4757 commit e4b1d36

File tree

15 files changed

+239
-134
lines changed

15 files changed

+239
-134
lines changed

sdks/python/CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@ All notable changes to Hatchet's Python SDK will be documented in this changelog
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## [1.10.2] - 2025-05-19
9+
10+
### Changed
11+
12+
- Fixing an issue with the spawn index being set at the `workflow_run_id` level and not the `(workflow_run_id, retry_count)` level, causing children to be spawned multiple times on retry.
13+
814
## [1.10.1] - 2025-05-16
915

1016
### Added

sdks/python/hatchet_sdk/clients/admin.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from hatchet_sdk.metadata import get_metadata
2121
from hatchet_sdk.rate_limit import RateLimitDuration
2222
from hatchet_sdk.runnables.contextvars import (
23+
ctx_action_key,
2324
ctx_step_run_id,
2425
ctx_worker_id,
2526
ctx_workflow_run_id,
@@ -281,11 +282,12 @@ def _create_workflow_run_request(
281282
workflow_run_id = ctx_workflow_run_id.get()
282283
step_run_id = ctx_step_run_id.get()
283284
worker_id = ctx_worker_id.get()
284-
spawn_index = workflow_spawn_indices[workflow_run_id] if workflow_run_id else 0
285+
action_key = ctx_action_key.get()
286+
spawn_index = workflow_spawn_indices[action_key] if action_key else 0
285287

286288
## Increment the spawn_index for the parent workflow
287-
if workflow_run_id:
288-
workflow_spawn_indices[workflow_run_id] += 1
289+
if action_key:
290+
workflow_spawn_indices[action_key] += 1
289291

290292
desired_worker_id = (
291293
(options.desired_worker_id or worker_id) if options.sticky else None

sdks/python/hatchet_sdk/clients/dispatcher/action_listener.py

Lines changed: 3 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
import asyncio
22
import json
33
import time
4-
from dataclasses import field
5-
from enum import Enum
6-
from typing import TYPE_CHECKING, Any, AsyncGenerator, cast
4+
from typing import TYPE_CHECKING, AsyncGenerator, cast
75

86
import grpc
97
import grpc.aio
10-
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
8+
from pydantic import BaseModel, ConfigDict, Field, model_validator
119

1210
from hatchet_sdk.clients.event_ts import (
1311
ThreadSafeEvent,
@@ -30,8 +28,8 @@
3028
from hatchet_sdk.contracts.dispatcher_pb2_grpc import DispatcherStub
3129
from hatchet_sdk.logger import logger
3230
from hatchet_sdk.metadata import get_metadata
31+
from hatchet_sdk.runnables.action import Action, ActionPayload, ActionType
3332
from hatchet_sdk.utils.backoff import exp_backoff_sleep
34-
from hatchet_sdk.utils.opentelemetry import OTelAttribute
3533
from hatchet_sdk.utils.proto_enums import convert_proto_enum_to_python
3634
from hatchet_sdk.utils.typing import JSONSerializableMapping
3735

@@ -67,120 +65,6 @@ def validate_labels(self) -> "GetActionListenerRequest":
6765
return self
6866

6967

70-
class ActionPayload(BaseModel):
71-
model_config = ConfigDict(extra="allow")
72-
73-
input: JSONSerializableMapping = Field(default_factory=dict)
74-
parents: dict[str, JSONSerializableMapping] = Field(default_factory=dict)
75-
overrides: JSONSerializableMapping = Field(default_factory=dict)
76-
user_data: JSONSerializableMapping = Field(default_factory=dict)
77-
step_run_errors: dict[str, str] = Field(default_factory=dict)
78-
triggered_by: str | None = None
79-
triggers: JSONSerializableMapping = Field(default_factory=dict)
80-
filter_payload: JSONSerializableMapping = Field(default_factory=dict)
81-
82-
@field_validator(
83-
"input",
84-
"parents",
85-
"overrides",
86-
"user_data",
87-
"step_run_errors",
88-
"filter_payload",
89-
mode="before",
90-
)
91-
@classmethod
92-
def validate_fields(cls, v: Any) -> Any:
93-
return v or {}
94-
95-
@model_validator(mode="after")
96-
def validate_filter_payload(self) -> "ActionPayload":
97-
self.filter_payload = self.triggers.get("filter_payload", {})
98-
99-
return self
100-
101-
102-
class ActionType(str, Enum):
103-
START_STEP_RUN = "START_STEP_RUN"
104-
CANCEL_STEP_RUN = "CANCEL_STEP_RUN"
105-
START_GET_GROUP_KEY = "START_GET_GROUP_KEY"
106-
107-
108-
ActionKey = str
109-
110-
111-
class Action(BaseModel):
112-
worker_id: str
113-
tenant_id: str
114-
workflow_run_id: str
115-
workflow_id: str | None = None
116-
workflow_version_id: str | None = None
117-
get_group_key_run_id: str
118-
job_id: str
119-
job_name: str
120-
job_run_id: str
121-
step_id: str
122-
step_run_id: str
123-
action_id: str
124-
action_type: ActionType
125-
retry_count: int
126-
action_payload: ActionPayload
127-
additional_metadata: JSONSerializableMapping = field(default_factory=dict)
128-
129-
child_workflow_index: int | None = None
130-
child_workflow_key: str | None = None
131-
parent_workflow_run_id: str | None = None
132-
133-
priority: int | None = None
134-
135-
def _dump_payload_to_str(self) -> str:
136-
try:
137-
return json.dumps(self.action_payload.model_dump(), default=str)
138-
except Exception:
139-
return str(self.action_payload)
140-
141-
def get_otel_attributes(self, config: "ClientConfig") -> dict[str, str | int]:
142-
try:
143-
payload_str = json.dumps(self.action_payload.model_dump(), default=str)
144-
except Exception:
145-
payload_str = str(self.action_payload)
146-
147-
attrs: dict[OTelAttribute, str | int | None] = {
148-
OTelAttribute.TENANT_ID: self.tenant_id,
149-
OTelAttribute.WORKER_ID: self.worker_id,
150-
OTelAttribute.WORKFLOW_RUN_ID: self.workflow_run_id,
151-
OTelAttribute.STEP_ID: self.step_id,
152-
OTelAttribute.STEP_RUN_ID: self.step_run_id,
153-
OTelAttribute.RETRY_COUNT: self.retry_count,
154-
OTelAttribute.PARENT_WORKFLOW_RUN_ID: self.parent_workflow_run_id,
155-
OTelAttribute.CHILD_WORKFLOW_INDEX: self.child_workflow_index,
156-
OTelAttribute.CHILD_WORKFLOW_KEY: self.child_workflow_key,
157-
OTelAttribute.ACTION_PAYLOAD: payload_str,
158-
OTelAttribute.WORKFLOW_NAME: self.job_name,
159-
OTelAttribute.ACTION_NAME: self.action_id,
160-
OTelAttribute.GET_GROUP_KEY_RUN_ID: self.get_group_key_run_id,
161-
OTelAttribute.WORKFLOW_ID: self.workflow_id,
162-
OTelAttribute.WORKFLOW_VERSION_ID: self.workflow_version_id,
163-
}
164-
165-
return {
166-
f"hatchet.{k.value}": v
167-
for k, v in attrs.items()
168-
if v and k not in config.otel.excluded_attributes
169-
}
170-
171-
@property
172-
def key(self) -> ActionKey:
173-
"""
174-
This key is used to uniquely identify a single step run by its id + retry count.
175-
It's used when storing references to a task, a context, etc. in a dictionary so that
176-
we can look up those items in the dictionary by a unique key.
177-
"""
178-
if self.action_type == ActionType.START_GET_GROUP_KEY:
179-
return f"{self.get_group_key_run_id}/{self.retry_count}"
180-
else:
181-
return f"{self.step_run_id}/{self.retry_count}"
182-
183-
18468
def parse_additional_metadata(additional_metadata: str) -> JSONSerializableMapping:
18569
try:
18670
return cast(

sdks/python/hatchet_sdk/clients/dispatcher/dispatcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from google.protobuf.timestamp_pb2 import Timestamp
55

66
from hatchet_sdk.clients.dispatcher.action_listener import (
7-
Action,
87
ActionListener,
98
GetActionListenerRequest,
109
)
@@ -29,6 +28,7 @@
2928
)
3029
from hatchet_sdk.contracts.dispatcher_pb2_grpc import DispatcherStub
3130
from hatchet_sdk.metadata import get_metadata
31+
from hatchet_sdk.runnables.action import Action
3232

3333
DEFAULT_REGISTER_TIMEOUT = 30
3434

sdks/python/hatchet_sdk/opentelemetry/instrumentor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,13 @@
3333
TriggerWorkflowOptions,
3434
WorkflowRunTriggerConfig,
3535
)
36-
from hatchet_sdk.clients.dispatcher.action_listener import Action
3736
from hatchet_sdk.clients.events import (
3837
BulkPushEventWithMetadata,
3938
EventClient,
4039
PushEventOptions,
4140
)
4241
from hatchet_sdk.contracts.events_pb2 import Event
42+
from hatchet_sdk.runnables.action import Action
4343
from hatchet_sdk.worker.runner.runner import Runner
4444
from hatchet_sdk.workflow_run import WorkflowRunRef
4545

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import json
2+
from dataclasses import field
3+
from enum import Enum
4+
from typing import TYPE_CHECKING, Any
5+
6+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
7+
8+
from hatchet_sdk.utils.opentelemetry import OTelAttribute
9+
from hatchet_sdk.utils.typing import JSONSerializableMapping
10+
11+
if TYPE_CHECKING:
12+
from hatchet_sdk.config import ClientConfig
13+
14+
ActionKey = str
15+
16+
17+
class ActionPayload(BaseModel):
18+
model_config = ConfigDict(extra="allow")
19+
20+
input: JSONSerializableMapping = Field(default_factory=dict)
21+
parents: dict[str, JSONSerializableMapping] = Field(default_factory=dict)
22+
overrides: JSONSerializableMapping = Field(default_factory=dict)
23+
user_data: JSONSerializableMapping = Field(default_factory=dict)
24+
step_run_errors: dict[str, str] = Field(default_factory=dict)
25+
triggered_by: str | None = None
26+
triggers: JSONSerializableMapping = Field(default_factory=dict)
27+
filter_payload: JSONSerializableMapping = Field(default_factory=dict)
28+
29+
@field_validator(
30+
"input",
31+
"parents",
32+
"overrides",
33+
"user_data",
34+
"step_run_errors",
35+
"filter_payload",
36+
mode="before",
37+
)
38+
@classmethod
39+
def validate_fields(cls, v: Any) -> Any:
40+
return v or {}
41+
42+
@model_validator(mode="after")
43+
def validate_filter_payload(self) -> "ActionPayload":
44+
self.filter_payload = self.triggers.get("filter_payload", {})
45+
46+
return self
47+
48+
49+
class ActionType(str, Enum):
50+
START_STEP_RUN = "START_STEP_RUN"
51+
CANCEL_STEP_RUN = "CANCEL_STEP_RUN"
52+
START_GET_GROUP_KEY = "START_GET_GROUP_KEY"
53+
54+
55+
class Action(BaseModel):
56+
worker_id: str
57+
tenant_id: str
58+
workflow_run_id: str
59+
workflow_id: str | None = None
60+
workflow_version_id: str | None = None
61+
get_group_key_run_id: str
62+
job_id: str
63+
job_name: str
64+
job_run_id: str
65+
step_id: str
66+
step_run_id: str
67+
action_id: str
68+
action_type: ActionType
69+
retry_count: int
70+
action_payload: ActionPayload
71+
additional_metadata: JSONSerializableMapping = field(default_factory=dict)
72+
73+
child_workflow_index: int | None = None
74+
child_workflow_key: str | None = None
75+
parent_workflow_run_id: str | None = None
76+
77+
priority: int | None = None
78+
79+
def _dump_payload_to_str(self) -> str:
80+
try:
81+
return json.dumps(self.action_payload.model_dump(), default=str)
82+
except Exception:
83+
return str(self.action_payload)
84+
85+
def get_otel_attributes(self, config: "ClientConfig") -> dict[str, str | int]:
86+
try:
87+
payload_str = json.dumps(self.action_payload.model_dump(), default=str)
88+
except Exception:
89+
payload_str = str(self.action_payload)
90+
91+
attrs: dict[OTelAttribute, str | int | None] = {
92+
OTelAttribute.TENANT_ID: self.tenant_id,
93+
OTelAttribute.WORKER_ID: self.worker_id,
94+
OTelAttribute.WORKFLOW_RUN_ID: self.workflow_run_id,
95+
OTelAttribute.STEP_ID: self.step_id,
96+
OTelAttribute.STEP_RUN_ID: self.step_run_id,
97+
OTelAttribute.RETRY_COUNT: self.retry_count,
98+
OTelAttribute.PARENT_WORKFLOW_RUN_ID: self.parent_workflow_run_id,
99+
OTelAttribute.CHILD_WORKFLOW_INDEX: self.child_workflow_index,
100+
OTelAttribute.CHILD_WORKFLOW_KEY: self.child_workflow_key,
101+
OTelAttribute.ACTION_PAYLOAD: payload_str,
102+
OTelAttribute.WORKFLOW_NAME: self.job_name,
103+
OTelAttribute.ACTION_NAME: self.action_id,
104+
OTelAttribute.GET_GROUP_KEY_RUN_ID: self.get_group_key_run_id,
105+
OTelAttribute.WORKFLOW_ID: self.workflow_id,
106+
OTelAttribute.WORKFLOW_VERSION_ID: self.workflow_version_id,
107+
}
108+
109+
return {
110+
f"hatchet.{k.value}": v
111+
for k, v in attrs.items()
112+
if v and k not in config.otel.excluded_attributes
113+
}
114+
115+
@property
116+
def key(self) -> ActionKey:
117+
"""
118+
This key is used to uniquely identify a single step run by its id + retry count.
119+
It's used when storing references to a task, a context, etc. in a dictionary so that
120+
we can look up those items in the dictionary by a unique key.
121+
"""
122+
if self.action_type == ActionType.START_GET_GROUP_KEY:
123+
return f"{self.get_group_key_run_id}/{self.retry_count}"
124+
else:
125+
return f"{self.step_run_id}/{self.retry_count}"

sdks/python/hatchet_sdk/runnables/contextvars.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,16 @@
22
from collections import Counter
33
from contextvars import ContextVar
44

5+
from hatchet_sdk.runnables.action import ActionKey
6+
57
ctx_workflow_run_id: ContextVar[str | None] = ContextVar(
68
"ctx_workflow_run_id", default=None
79
)
10+
ctx_action_key: ContextVar[ActionKey | None] = ContextVar(
11+
"ctx_action_key", default=None
12+
)
813
ctx_step_run_id: ContextVar[str | None] = ContextVar("ctx_step_run_id", default=None)
914
ctx_worker_id: ContextVar[str | None] = ContextVar("ctx_worker_id", default=None)
1015

11-
workflow_spawn_indices = Counter[str]()
16+
workflow_spawn_indices = Counter[ActionKey]()
1217
spawn_index_lock = asyncio.Lock()

sdks/python/hatchet_sdk/worker/action_listener_process.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@
1010

1111
from hatchet_sdk.client import Client
1212
from hatchet_sdk.clients.dispatcher.action_listener import (
13-
Action,
1413
ActionListener,
15-
ActionType,
1614
GetActionListenerRequest,
1715
)
1816
from hatchet_sdk.clients.dispatcher.dispatcher import DispatcherClient
@@ -23,7 +21,9 @@
2321
STEP_EVENT_TYPE_STARTED,
2422
)
2523
from hatchet_sdk.logger import logger
24+
from hatchet_sdk.runnables.action import Action, ActionType
2625
from hatchet_sdk.runnables.contextvars import (
26+
ctx_action_key,
2727
ctx_step_run_id,
2828
ctx_worker_id,
2929
ctx_workflow_run_id,
@@ -230,6 +230,7 @@ async def start_action_loop(self) -> None:
230230
ctx_step_run_id.set(action.step_run_id)
231231
ctx_workflow_run_id.set(action.workflow_run_id)
232232
ctx_worker_id.set(action.worker_id)
233+
ctx_action_key.set(action.key)
233234

234235
# Process the action here
235236
match action.action_type:

sdks/python/hatchet_sdk/worker/runner/run_loop_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
from typing import Any, Literal, TypeVar
55

66
from hatchet_sdk.client import Client
7-
from hatchet_sdk.clients.dispatcher.action_listener import Action
87
from hatchet_sdk.config import ClientConfig
98
from hatchet_sdk.logger import logger
9+
from hatchet_sdk.runnables.action import Action
1010
from hatchet_sdk.runnables.task import Task
1111
from hatchet_sdk.worker.action_listener_process import ActionEvent
1212
from hatchet_sdk.worker.runner.runner import Runner

0 commit comments

Comments
 (0)