|
24 | 24 | from botocore.config import Config |
25 | 25 | from models.clients.litellm_client import LiteLLMClient |
26 | 26 | from models.domain_objects import CreateModelRequest, ModelStatus |
| 27 | +from models.exception import ( |
| 28 | + MaxPollsExceededException, |
| 29 | + StackFailedToCreateException, |
| 30 | + UnexpectedCloudFormationStateException, |
| 31 | +) |
27 | 32 | from utilities.common_functions import get_cert_path, get_rest_api_container_endpoint, retry_config |
28 | 33 |
|
29 | 34 | lambdaConfig = Config(connect_timeout=60, read_timeout=600, retries={"max_attempts": 1}) |
@@ -113,8 +118,13 @@ def handle_poll_docker_image_available(event: Dict[str, Any], context: Any) -> D |
113 | 118 | output_dict["image_info"]["remaining_polls"] -= 1 |
114 | 119 | if output_dict["image_info"]["remaining_polls"] <= 0: |
115 | 120 | ec2Client.terminate_instances(InstanceIds=[event["image_info"]["instance_id"]]) |
116 | | - raise Exception( |
117 | | - "Maximum number of ECR poll attempts reached. Something went wrong building the docker image." |
| 121 | + raise MaxPollsExceededException( |
| 122 | + json.dumps( |
| 123 | + { |
| 124 | + "error": "Max number of ECR polls reached. Docker Image was not replicated successfully.", |
| 125 | + "event": event, |
| 126 | + } |
| 127 | + ) |
118 | 128 | ) |
119 | 129 | return output_dict |
120 | 130 |
|
@@ -153,7 +163,17 @@ def camelize_object(o): # type: ignore[no-untyped-def] |
153 | 163 |
|
154 | 164 | payload = response["Payload"].read() |
155 | 165 | payload = json.loads(payload) |
156 | | - stack_name = payload.get("stackName") |
| 166 | + stack_name = payload.get("stackName", None) |
| 167 | + |
| 168 | + if not stack_name: |
| 169 | + raise StackFailedToCreateException( |
| 170 | + json.dumps( |
| 171 | + { |
| 172 | + "error": "Failed to create Model CloudFormation Stack. Please validate model parameters are valid.", |
| 173 | + "event": event, |
| 174 | + } |
| 175 | + ) |
| 176 | + ) |
157 | 177 |
|
158 | 178 | response = cfnClient.describe_stacks(StackName=stack_name) |
159 | 179 | stack_arn = response["Stacks"][0]["StackId"] |
@@ -192,10 +212,24 @@ def handle_poll_create_stack(event: Dict[str, Any], context: Any) -> Dict[str, A |
192 | 212 | output_dict["continue_polling_stack"] = True |
193 | 213 | output_dict["remaining_polls_stack"] -= 1 |
194 | 214 | if output_dict["remaining_polls_stack"] <= 0: |
195 | | - raise Exception("Maximum number of CloudFormation polls reached") |
| 215 | + raise MaxPollsExceededException( |
| 216 | + json.dumps( |
| 217 | + { |
| 218 | + "error": "Max number of CloudFormation polls reached.", |
| 219 | + "event": event, |
| 220 | + } |
| 221 | + ) |
| 222 | + ) |
196 | 223 | return output_dict |
197 | 224 | else: |
198 | | - raise Exception(f"Stack in unexpected state: {stackStatus}") |
| 225 | + raise UnexpectedCloudFormationStateException( |
| 226 | + json.dumps( |
| 227 | + { |
| 228 | + "error": f"Stack entered unexpected state: {stackStatus}", |
| 229 | + "event": event, |
| 230 | + } |
| 231 | + ) |
| 232 | + ) |
199 | 233 |
|
200 | 234 |
|
201 | 235 | def handle_add_model_to_litellm(event: Dict[str, Any], context: Any) -> Dict[str, Any]: |
@@ -234,3 +268,38 @@ def handle_add_model_to_litellm(event: Dict[str, Any], context: Any) -> Dict[str |
234 | 268 | ) |
235 | 269 |
|
236 | 270 | return output_dict |
| 271 | + |
| 272 | + |
def handle_failure(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
    """
    Handle failures from the state machine.

    Possible causes of failure are:
      1. Docker image failed to replicate into ECR in the expected amount of time.
      2. CloudFormation stack creation failed from parameter validation.
      3. CloudFormation stack creation failed from taking too long to stand up.

    This function terminates the EC2 instance recorded in the original event (if one is
    recorded) and sets the model status to Failed, storing the failure reason alongside it.
    Cleaning up the CloudFormation stack, if it still exists, happens in the DeleteModel API.

    Args:
        event: Failure payload. ``event["Cause"]`` must be a JSON string whose
            ``errorMessage`` field is itself a JSON document of the form
            ``{"error": <reason>, "event": <original event>}``.
        context: Lambda context object; not used.

    Returns:
        The incoming ``event``, unchanged.

    Raises:
        KeyError: If ``Cause``, ``errorMessage``, ``error``, ``event`` or ``modelId`` is missing.
        json.JSONDecodeError: If ``Cause`` or its ``errorMessage`` is not valid JSON.
    """
    # Local import keeps this block self-contained; the module already exposes the
    # ``datetime`` class under the name ``datetime``.
    from datetime import timezone

    # The error from SFN is a JSON payload wrapped around the JSON payload we attach
    # to the exception message, hence the two decoding steps.
    cause = json.loads(event["Cause"])
    error_dict = json.loads(cause["errorMessage"])
    error_reason = error_dict["error"]
    original_event = error_dict["event"]

    # Terminate the EC2 instance if we have one recorded. A null image_info or a
    # null/empty instance id is treated the same as "nothing recorded" so that the
    # status update below still runs.
    instance_id = (original_event.get("image_info") or {}).get("instance_id")
    if instance_id:
        ec2Client.terminate_instances(InstanceIds=[instance_id])

    # Use an aware UTC datetime: calling .timestamp() on the naive value returned by
    # utcnow() interprets it as local time and yields a skewed epoch outside UTC.
    last_modified = int(datetime.now(timezone.utc).timestamp())

    # Set the model as Failed in DDB so it shows as such in the UI; record the reason too.
    model_table.update_item(
        Key={"model_id": original_event["modelId"]},
        UpdateExpression="SET model_status = :ms, last_modified_date = :lm, failure_reason = :fr",
        ExpressionAttributeValues={
            ":ms": ModelStatus.FAILED,
            ":lm": last_modified,
            ":fr": error_reason,
        },
    )
    return event
0 commit comments