Skip to content

Commit e580f2f

Browse files
Add created IDs to async-upload job termination message (#2173)
* add entityIds to termination msg Signed-off-by: Adysen Rothman <85646824+adysenrothman@users.noreply.github.com> * add intent type to message Signed-off-by: Adysen Rothman <85646824+adysenrothman@users.noreply.github.com> * optional update_artifact intent ids Signed-off-by: Adysen Rothman <85646824+adysenrothman@users.noreply.github.com> * tests Signed-off-by: Adysen Rothman <85646824+adysenrothman@users.noreply.github.com> * update test Signed-off-by: Adysen Rothman <85646824+adysenrothman@users.noreply.github.com> --------- Signed-off-by: Adysen Rothman <85646824+adysenrothman@users.noreply.github.com>
1 parent 7e6e6fe commit e580f2f

5 files changed

Lines changed: 269 additions & 24 deletions

File tree

jobs/async-upload/job/entrypoint.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import json
23
import logging
34
import os
45

@@ -12,6 +13,7 @@
1213
create_version_and_artifact,
1314
validate_create_model_intent,
1415
validate_create_version_intent,
16+
CreatedEntityIds,
1517
)
1618
from .models import CreateModelIntent, CreateVersionIntent, UpdateArtifactIntent
1719
from .download import perform_download
@@ -44,6 +46,29 @@ def record_error(exc):
4446
logger.error(message)
4547

4648

49+
def write_success_result(entity_ids: CreatedEntityIds, intent_type: str) -> None:
50+
"""Write success result with entity IDs to termination message path.
51+
52+
Args:
53+
entity_ids: CreatedEntityIds object containing the IDs of created/updated entities.
54+
intent_type: The intent type string (e.g., "update_artifact", "create_model", "create_version").
55+
"""
56+
result_dict = {"intent": intent_type}
57+
58+
if entity_ids.registered_model_id:
59+
result_dict["RegisteredModel"] = {"id": entity_ids.registered_model_id}
60+
61+
if entity_ids.model_version_id:
62+
result_dict["ModelVersion"] = {"id": entity_ids.model_version_id}
63+
64+
if entity_ids.model_artifact_id:
65+
result_dict["ModelArtifact"] = {"id": entity_ids.model_artifact_id}
66+
67+
result_json = json.dumps(result_dict, indent=2)
68+
write_to_termination_message_path(result_json)
69+
logger.info(f"✅ Success result written to termination message path: {result_json}")
70+
71+
4772
async def main() -> None:
4873
"""
4974
Main entrypoint for the async upload job.
@@ -59,12 +84,21 @@ async def main() -> None:
5984
client = validate_and_get_model_registry_client(config.registry)
6085

6186
intent = config.model.intent
87+
entity_ids: CreatedEntityIds | None = None
88+
6289
if isinstance(intent, UpdateArtifactIntent):
6390
logger.info("📋 Processing update_artifact intent")
6491
await set_artifact_pending(client, intent.artifact_id)
6592
perform_download(config)
6693
uri = perform_upload(config)
67-
await update_model_artifact_uri(client, intent.artifact_id, uri)
94+
# Pass through optional model_id and version_id if provided
95+
entity_ids = await update_model_artifact_uri(
96+
client,
97+
intent.artifact_id,
98+
uri,
99+
registered_model_id=intent.model_id,
100+
model_version_id=intent.version_id,
101+
)
68102
elif isinstance(intent, CreateModelIntent):
69103
logger.info("📋 Processing create_model intent")
70104
if not config.metadata:
@@ -74,7 +108,7 @@ async def main() -> None:
74108
await validate_create_model_intent(client, config.metadata)
75109
perform_download(config)
76110
uri = perform_upload(config)
77-
await create_model_and_artifact(client, config.metadata, uri)
111+
entity_ids = await create_model_and_artifact(client, config.metadata, uri)
78112
elif isinstance(intent, CreateVersionIntent):
79113
logger.info("📋 Processing create_version intent")
80114
if not config.metadata:
@@ -84,9 +118,13 @@ async def main() -> None:
84118
await validate_create_version_intent(client, intent.model_id, config.metadata)
85119
perform_download(config)
86120
uri = perform_upload(config)
87-
await create_version_and_artifact(client, intent.model_id, config.metadata, uri)
121+
entity_ids = await create_version_and_artifact(client, intent.model_id, config.metadata, uri)
88122
else:
89123
raise ValueError(f"Unknown intent type: {type(intent)}")
124+
125+
# Write success result to termination message path
126+
if entity_ids:
127+
write_success_result(entity_ids, intent.intent_type)
90128
except BaseException as e:
91129
record_error(e)
92130
raise

jobs/async-upload/job/models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ class CreateVersionIntent(BaseModel):
123123
class UpdateArtifactIntent(BaseModel):
124124
intent_type: Literal[UploadIntent.update_artifact] = UploadIntent.update_artifact
125125
artifact_id: str = Field(..., description="Model artifact ID to update")
126+
# Optional IDs to pass through to termination message output
127+
model_id: str | None = Field(default=None, description="Optional registered model ID to include in output")
128+
version_id: str | None = Field(default=None, description="Optional model version ID to include in output")
126129

127130

128131
IntentConfig = Union[CreateModelIntent, CreateVersionIntent, UpdateArtifactIntent]

jobs/async-upload/job/mr_client.py

Lines changed: 68 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from dataclasses import dataclass
23

34
from model_registry import ModelRegistry
45
from model_registry.types import ArtifactState
@@ -14,6 +15,14 @@
1415
logger = logging.getLogger(__name__)
1516

1617

18+
@dataclass
19+
class CreatedEntityIds:
20+
"""IDs of created/updated entities in the model registry."""
21+
registered_model_id: str | None = None
22+
model_version_id: str | None = None
23+
model_artifact_id: str | None = None
24+
25+
1726
def validate_and_get_model_registry_client(config: RegistryConfig) -> ModelRegistry:
1827
"""
1928
Validates the model registry client configuration and returns a ModelRegistry client.
@@ -94,27 +103,64 @@ async def validate_create_version_intent(client: ModelRegistry, model_id: str, m
94103
logger.debug("✅ create_version intent validation passed")
95104

96105

97-
async def create_model_and_artifact(client: ModelRegistry, metadata: ConfigMapMetadata, uri: str) -> None:
98-
"""Creates a new registered model, model version, and model artifact."""
106+
async def create_model_and_artifact(client: ModelRegistry, metadata: ConfigMapMetadata, uri: str) -> CreatedEntityIds:
107+
"""Creates a new registered model, model version, and model artifact.
108+
109+
Returns:
110+
CreatedEntityIds: IDs of the created registered model, model version, and model artifact.
111+
"""
99112
logger.debug("🔍 Creating new registered model, version, and artifact")
100113
rm = await _create_registered_model(client, metadata.registered_model)
101-
await _create_version_and_artifact_for_model(client, rm, uri, metadata)
114+
mv, artifact = await _create_version_and_artifact_for_model(client, rm, uri, metadata)
115+
return CreatedEntityIds(
116+
registered_model_id=str(rm.id) if rm.id else None,
117+
model_version_id=str(mv.id) if mv.id else None,
118+
model_artifact_id=str(artifact.id) if artifact.id else None,
119+
)
102120

103121

104122
async def create_version_and_artifact(
105123
client: ModelRegistry, model_id: str, metadata: ConfigMapMetadata, uri: str
106-
) -> None:
107-
"""Creates a new model version and model artifact under an existing registered model."""
124+
) -> CreatedEntityIds:
125+
"""Creates a new model version and model artifact under an existing registered model.
126+
127+
Returns:
128+
CreatedEntityIds: IDs of the existing registered model, and created model version and model artifact.
129+
"""
108130
logger.debug("🔍 Creating new version and artifact for model ID: %s", model_id)
109131

110132
rm = await client._api.get_registered_model_by_id(model_id)
111133
if not rm:
112134
raise ValueError(f"RegisteredModel with ID '{model_id}' not found")
113135

114-
await _create_version_and_artifact_for_model(client, rm, uri, metadata)
136+
mv, artifact = await _create_version_and_artifact_for_model(client, rm, uri, metadata)
137+
return CreatedEntityIds(
138+
registered_model_id=str(rm.id) if rm.id else None,
139+
model_version_id=str(mv.id) if mv.id else None,
140+
model_artifact_id=str(artifact.id) if artifact.id else None,
141+
)
115142

116143

117-
async def update_model_artifact_uri(client: ModelRegistry, artifact_id: str, uri: str) -> None:
144+
async def update_model_artifact_uri(
145+
client: ModelRegistry,
146+
artifact_id: str,
147+
uri: str,
148+
registered_model_id: str | None = None,
149+
model_version_id: str | None = None,
150+
) -> CreatedEntityIds:
151+
"""Updates the model artifact URI and sets state to LIVE.
152+
153+
Args:
154+
client: Model registry client.
155+
artifact_id: ID of the artifact to update.
156+
uri: New URI for the artifact.
157+
registered_model_id: Optional registered model ID to pass through to output.
158+
model_version_id: Optional model version ID to pass through to output.
159+
160+
Returns:
161+
CreatedEntityIds: IDs passed through to output. For update_artifact intent,
162+
model and version IDs will be in the output if and only if they are passed in.
163+
"""
118164
logger.debug("🔍 Updating model artifact URI: %s", uri)
119165
artifact = await client._api.get_model_artifact_by_id(artifact_id)
120166

@@ -126,6 +172,13 @@ async def update_model_artifact_uri(client: ModelRegistry, artifact_id: str, uri
126172
artifact.uri = uri
127173
await client._api.upsert_model_artifact(artifact)
128174
logger.debug("✅ Model artifact URI updated: %s", uri)
175+
176+
# Pass through the IDs that were provided
177+
return CreatedEntityIds(
178+
registered_model_id=registered_model_id,
179+
model_version_id=model_version_id,
180+
model_artifact_id=artifact_id,
181+
)
129182

130183

131184
async def _create_registered_model(client: ModelRegistry, rm_metadata):
@@ -142,8 +195,12 @@ async def _create_registered_model(client: ModelRegistry, rm_metadata):
142195

143196
async def _create_version_and_artifact_for_model(
144197
client: ModelRegistry, rm, uri: str, metadata: ConfigMapMetadata
145-
) -> None:
146-
"""Creates a model version and artifact under the given registered model."""
198+
) -> tuple:
199+
"""Creates a model version and artifact under the given registered model.
200+
201+
Returns:
202+
tuple: (ModelVersion, ModelArtifact) - The created model version and artifact objects.
203+
"""
147204
mv_metadata = metadata.model_version
148205
version_name = mv_metadata.name or "1.0.0"
149206

@@ -181,3 +238,5 @@ async def _create_version_and_artifact_for_model(
181238
artifact.state = ArtifactState.LIVE
182239
await client._api.upsert_model_artifact(artifact)
183240
logger.debug("✅ Updated ModelArtifact state to LIVE: %s", artifact.id)
241+
242+
return (mv, artifact)

0 commit comments

Comments
 (0)