11import logging
2+ from dataclasses import dataclass
23
34from model_registry import ModelRegistry
45from model_registry .types import ArtifactState
1415logger = logging .getLogger (__name__ )
1516
1617
18+ @dataclass
19+ class CreatedEntityIds :
20+ """IDs of created/updated entities in the model registry."""
21+ registered_model_id : str | None = None
22+ model_version_id : str | None = None
23+ model_artifact_id : str | None = None
24+
25+
1726def validate_and_get_model_registry_client (config : RegistryConfig ) -> ModelRegistry :
1827 """
1928 Validates the model registry client configuration and returns a ModelRegistry client.
@@ -94,27 +103,64 @@ async def validate_create_version_intent(client: ModelRegistry, model_id: str, m
94103 logger .debug ("✅ create_version intent validation passed" )
95104
96105
97- async def create_model_and_artifact (client : ModelRegistry , metadata : ConfigMapMetadata , uri : str ) -> None :
98- """Creates a new registered model, model version, and model artifact."""
106+ async def create_model_and_artifact (client : ModelRegistry , metadata : ConfigMapMetadata , uri : str ) -> CreatedEntityIds :
107+ """Creates a new registered model, model version, and model artifact.
108+
109+ Returns:
110+ CreatedEntityIds: IDs of the created registered model, model version, and model artifact.
111+ """
99112 logger .debug ("🔍 Creating new registered model, version, and artifact" )
100113 rm = await _create_registered_model (client , metadata .registered_model )
101- await _create_version_and_artifact_for_model (client , rm , uri , metadata )
114+ mv , artifact = await _create_version_and_artifact_for_model (client , rm , uri , metadata )
115+ return CreatedEntityIds (
116+ registered_model_id = str (rm .id ) if rm .id else None ,
117+ model_version_id = str (mv .id ) if mv .id else None ,
118+ model_artifact_id = str (artifact .id ) if artifact .id else None ,
119+ )
102120
103121
104122async def create_version_and_artifact (
105123 client : ModelRegistry , model_id : str , metadata : ConfigMapMetadata , uri : str
106- ) -> None :
107- """Creates a new model version and model artifact under an existing registered model."""
124+ ) -> CreatedEntityIds :
125+ """Creates a new model version and model artifact under an existing registered model.
126+
127+ Returns:
128+ CreatedEntityIds: IDs of the existing registered model, and created model version and model artifact.
129+ """
108130 logger .debug ("🔍 Creating new version and artifact for model ID: %s" , model_id )
109131
110132 rm = await client ._api .get_registered_model_by_id (model_id )
111133 if not rm :
112134 raise ValueError (f"RegisteredModel with ID '{ model_id } ' not found" )
113135
114- await _create_version_and_artifact_for_model (client , rm , uri , metadata )
136+ mv , artifact = await _create_version_and_artifact_for_model (client , rm , uri , metadata )
137+ return CreatedEntityIds (
138+ registered_model_id = str (rm .id ) if rm .id else None ,
139+ model_version_id = str (mv .id ) if mv .id else None ,
140+ model_artifact_id = str (artifact .id ) if artifact .id else None ,
141+ )
115142
116143
117- async def update_model_artifact_uri (client : ModelRegistry , artifact_id : str , uri : str ) -> None :
144+ async def update_model_artifact_uri (
145+ client : ModelRegistry ,
146+ artifact_id : str ,
147+ uri : str ,
148+ registered_model_id : str | None = None ,
149+ model_version_id : str | None = None ,
150+ ) -> CreatedEntityIds :
151+ """Updates the model artifact URI and sets state to LIVE.
152+
153+ Args:
154+ client: Model registry client.
155+ artifact_id: ID of the artifact to update.
156+ uri: New URI for the artifact.
157+ registered_model_id: Optional registered model ID to pass through to output.
158+ model_version_id: Optional model version ID to pass through to output.
159+
160+ Returns:
161+ CreatedEntityIds: IDs passed through to output. For update_artifact intent,
162+ model and version IDs will be in the output if and only if they are passed in.
163+ """
118164 logger .debug ("🔍 Updating model artifact URI: %s" , uri )
119165 artifact = await client ._api .get_model_artifact_by_id (artifact_id )
120166
@@ -126,6 +172,13 @@ async def update_model_artifact_uri(client: ModelRegistry, artifact_id: str, uri
126172 artifact .uri = uri
127173 await client ._api .upsert_model_artifact (artifact )
128174 logger .debug ("✅ Model artifact URI updated: %s" , uri )
175+
176+ # Pass through the IDs that were provided
177+ return CreatedEntityIds (
178+ registered_model_id = registered_model_id ,
179+ model_version_id = model_version_id ,
180+ model_artifact_id = artifact_id ,
181+ )
129182
130183
131184async def _create_registered_model (client : ModelRegistry , rm_metadata ):
@@ -142,8 +195,12 @@ async def _create_registered_model(client: ModelRegistry, rm_metadata):
142195
143196async def _create_version_and_artifact_for_model (
144197 client : ModelRegistry , rm , uri : str , metadata : ConfigMapMetadata
145- ) -> None :
146- """Creates a model version and artifact under the given registered model."""
198+ ) -> tuple :
199+ """Creates a model version and artifact under the given registered model.
200+
201+ Returns:
202+ tuple: (ModelVersion, ModelArtifact) - The created model version and artifact objects.
203+ """
147204 mv_metadata = metadata .model_version
148205 version_name = mv_metadata .name or "1.0.0"
149206
@@ -181,3 +238,5 @@ async def _create_version_and_artifact_for_model(
181238 artifact .state = ArtifactState .LIVE
182239 await client ._api .upsert_model_artifact (artifact )
183240 logger .debug ("✅ Updated ModelArtifact state to LIVE: %s" , artifact .id )
241+
242+ return (mv , artifact )
0 commit comments