Skip to content

Commit cc0975a

Browse files
authored
Release v3.0.1
2 parents 0e824eb + ba4fd4f commit cc0975a

30 files changed

Lines changed: 334 additions & 89 deletions

File tree

.github/workflows/code.release.branch.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ jobs:
2626
RELEASE_TAG=${{ github.event.inputs.release_tag }}
2727
git checkout -b release/${{ github.event.inputs.release_tag }}
2828
echo "$( jq --arg version ${RELEASE_TAG:1} '.version = $version' package.json )" > package.json
29-
sed -E -i "" -e "s/version = \"[0-9\.].+\"/version = \"${RELEASE_TAG:1}\"/g" lisa-sdk/pyproject.toml
29+
sed -E -i -e "s/version = \"[0-9\.].+\"/version = \"${RELEASE_TAG:1}\"/g" lisa-sdk/pyproject.toml
3030
echo ${RELEASE_TAG:1} > VERSION
3131
git commit -a -m "Updating version for release ${{ github.event.inputs.release_tag }}"
3232
git push origin release/${{ github.event.inputs.release_tag }}

CHANGELOG.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,25 @@
1+
# v3.0.1
2+
## Bug fixes
3+
- Updated our Lambda admin validation to work for no-auth if the user has the admin secret token. This applies to the model management APIs.
4+
- The create model state machine was not reporting a failed status
5+
- Delete state machine could not delete models that weren't stored in LiteLLM DB
6+
7+
## Enhancements
8+
- Added units to the create model wizard to help with clarity
9+
- Increased default timeouts to 10 minutes to enable large documentation processing without errors
10+
- Updated ALB and Target group names to be lower cased by default to prevent networking issues
11+
12+
## Coming Soon
13+
- 3.1.0 will expand support for model management. Administrators will be able to modify, activate, and deactivate models through the UI or APIs. In the following release, we will continue to ease deployment steps for customers through a new deployment wizard and updated documentation.
14+
15+
## Acknowledgements
16+
* @petermuller
17+
* @estohlmann
18+
* @dustins
19+
20+
**Full Changelog**: https://github.com/awslabs/LISA/compare/v3.0.0...v3.0.1
21+
22+
123
# v3.0.0
224
## Key Features
325
### Model Management Administration

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,10 @@ curl -s -H "Authorization: Bearer <admin_token>" -X GET https://<apigw_endpoint>
659659

660660
LISA provides the `/models` endpoint for creating both ECS and LiteLLM-hosted models. Depending on the request payload, infrastructure will be created or bypassed (e.g., for LiteLLM-only models).
661661

662+
This API accepts the same model definition parameters that were accepted in the V2 model definitions within the config.yaml file with one notable difference: the `containerConfig.baseImage.path` field is
663+
now a path relative to the `lib/serve/ecs-model` directory, instead of a path relative to the repository root. This means that if the path used to be `lib/serve/ecs-model/textgen/tgi`, then
664+
it will now be `textgen/tgi` for the CreateModel API. For vLLM models, the `path` will be `vllm`, and for TEI, it will be `embedding/tei`.
665+
662666
#### Request Example:
663667

664668
```

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
3.0.0
1+
3.0.1

ecs_model_deployer/src/lib/ecsCluster.ts

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,11 @@ export class ECSCluster extends Construct {
109109
],
110110
});
111111

112+
new CfnOutput(this, 'autoScalingGroup', {
113+
key: 'autoScalingGroup',
114+
value: autoScalingGroup.autoScalingGroupName,
115+
});
116+
112117
const environment = ecsConfig.environment;
113118
const volumes: Volume[] = [];
114119
const mountPoints: MountPoint[] = [];
@@ -272,10 +277,11 @@ export class ECSCluster extends Construct {
272277
const loadBalancer = new ApplicationLoadBalancer(this, createCdkId([ecsConfig.identifier, 'ALB']), {
273278
deletionProtection: config.removalPolicy !== RemovalPolicy.DESTROY,
274279
internetFacing: false,
275-
loadBalancerName: createCdkId([config.deploymentName, ecsConfig.identifier], 32, 2),
280+
loadBalancerName: createCdkId([config.deploymentName, ecsConfig.identifier], 32, 2).toLowerCase(),
276281
dropInvalidHeaderFields: true,
277282
securityGroup,
278283
vpc,
284+
idleTimeout: Duration.seconds(600)
279285
});
280286

281287
// Add listener
@@ -294,7 +300,7 @@ export class ECSCluster extends Construct {
294300
// Add targets
295301
const loadBalancerHealthCheckConfig = ecsConfig.loadBalancerConfig.healthCheckConfig;
296302
const targetGroup = listener.addTargets(createCdkId([ecsConfig.identifier, 'TgtGrp']), {
297-
targetGroupName: createCdkId([config.deploymentName, ecsConfig.identifier], 32, 2),
303+
targetGroupName: createCdkId([config.deploymentName, ecsConfig.identifier], 32, 2).toLowerCase(),
298304
healthCheck: {
299305
path: loadBalancerHealthCheckConfig.path,
300306
interval: Duration.seconds(loadBalancerHealthCheckConfig.interval),

lambda/authorizer/lambda_functions.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,20 @@
1616
import logging
1717
import os
1818
import ssl
19+
from functools import cache
1920
from typing import Any, Dict
2021

22+
import boto3
2123
import create_env_variables # noqa: F401
2224
import jwt
2325
import requests
24-
from utilities.common_functions import authorization_wrapper, get_id_token
26+
from botocore.exceptions import ClientError
27+
from utilities.common_functions import authorization_wrapper, get_id_token, retry_config
2528

2629
logger = logging.getLogger(__name__)
2730

31+
secrets_manager = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"], config=retry_config)
32+
2833

2934
@authorization_wrapper
3035
def lambda_handler(event: Dict[str, Any], context) -> Dict[str, Any]: # type: ignore [no-untyped-def]
@@ -48,6 +53,11 @@ def lambda_handler(event: Dict[str, Any], context) -> Dict[str, Any]: # type: i
4853

4954
deny_policy = generate_policy(effect="Deny", resource=event["methodArn"])
5055

56+
if id_token in get_management_tokens():
57+
allow_policy = generate_policy(effect="Allow", resource=event["methodArn"], username="lisa-management-token")
58+
logger.debug(f"Generated policy: {allow_policy}")
59+
return allow_policy
60+
5161
if jwt_data := id_token_is_valid(id_token=id_token, client_id=client_id, authority=authority):
5262
is_admin_user = is_admin(jwt_data, admin_group, jwt_groups_property)
5363
allow_policy = generate_policy(effect="Allow", resource=event["methodArn"], username=jwt_data["sub"])
@@ -134,3 +144,22 @@ def is_admin(jwt_data: dict[str, Any], admin_group: str, jwt_groups_property: st
134144
else:
135145
return False
136146
return admin_group in current_node
147+
148+
149+
@cache
150+
def get_management_tokens() -> list[str]:
151+
"""Return secret management tokens if they exist."""
152+
secret_tokens: list[str] = []
153+
secret_id = os.environ.get("MANAGEMENT_KEY_NAME")
154+
155+
try:
156+
secret_tokens.append(
157+
secrets_manager.get_secret_value(SecretId=secret_id, VersionStage="AWSCURRENT")["SecretString"]
158+
)
159+
secret_tokens.append(
160+
secrets_manager.get_secret_value(SecretId=secret_id, VersionStage="AWSPREVIOUS")["SecretString"]
161+
)
162+
except ClientError as e:
163+
logger.warn(f"Unable to fetch {secret_id}. {e.response['Error']['Code']}: {e.response['Error']['Message']}")
164+
165+
return secret_tokens

lambda/models/domain_objects.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from typing import Annotated, List, Optional, Union
1919

2020
from pydantic import BaseModel
21-
from pydantic.functional_validators import AfterValidator
21+
from pydantic.functional_validators import AfterValidator, field_validator
2222
from utilities.validators import validate_instance_type
2323

2424

@@ -87,7 +87,7 @@ class LoadBalancerConfig(BaseModel):
8787

8888

8989
class AutoScalingConfig(BaseModel):
90-
"""Autoscaling configuration."""
90+
"""Autoscaling configuration upon model creation."""
9191

9292
minCapacity: int
9393
maxCapacity: int
@@ -96,6 +96,14 @@ class AutoScalingConfig(BaseModel):
9696
metricConfig: MetricConfig
9797

9898

99+
class AutoScalingInstanceConfig(BaseModel):
100+
"""Autoscaling instance count configuration upon model update."""
101+
102+
minCapacity: Optional[int] = None
103+
maxCapacity: Optional[int] = None
104+
desiredCapacity: Optional[int] = None
105+
106+
99107
class ContainerHealthCheckConfig(BaseModel):
100108
"""Health check configuration for a container."""
101109

@@ -180,16 +188,24 @@ class GetModelResponse(ApiResponseBase):
180188
class UpdateModelRequest(BaseModel):
181189
"""Request object when updating a model."""
182190

183-
autoScalingConfig: Optional[AutoScalingConfig] = None
184-
containerConfig: Optional[ContainerConfig] = None
185-
inferenceContainer: Optional[InferenceContainer] = None
186-
instanceType: Optional[Annotated[str, AfterValidator(validate_instance_type)]] = None
187-
loadBalancerConfig: Optional[LoadBalancerConfig] = None
188-
modelId: str
189-
modelName: Optional[str] = None
191+
autoScalingInstanceConfig: Optional[AutoScalingInstanceConfig] = None
192+
enabled: Optional[bool] = None
190193
modelType: Optional[ModelType] = None
191194
streaming: Optional[bool] = None
192195

196+
@field_validator("autoScalingInstanceConfig") # type: ignore
197+
@classmethod
198+
def validate_autoscaling_instance_config(cls, config: AutoScalingInstanceConfig) -> AutoScalingInstanceConfig:
199+
"""Validate that the AutoScaling instance config defines at least one field and that no defined field is negative."""
200+
if not config:
201+
raise ValueError("The autoScalingInstanceConfig must not be null if defined in request payload.")
202+
config_fields = (config.minCapacity, config.maxCapacity, config.desiredCapacity)
203+
if all((field is None for field in config_fields)):
204+
raise ValueError("At least one option of autoScalingInstanceConfig must be defined.")
205+
if any((isinstance(field, int) and field < 0 for field in config_fields)):
206+
raise ValueError("All autoScalingInstanceConfig fields must be >= 0.")
207+
return config
208+
193209

194210
class UpdateModelResponse(ApiResponseBase):
195211
"""Response object when updating a model."""

lambda/models/exception/__init__.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
"""Exception definitions for model management APIs."""
1616

1717

18+
# LiteLLM errors
19+
20+
1821
class ModelNotFoundError(LookupError):
1922
"""Error to raise when a specified model cannot be found in the database."""
2023

@@ -25,3 +28,24 @@ class ModelAlreadyExistsError(LookupError):
2528
"""Error to raise when a specified model already exists in the database."""
2629

2730
pass
31+
32+
33+
# State machine exceptions
34+
35+
36+
class MaxPollsExceededException(Exception):
37+
"""Exception to indicate that polling for a state timed out."""
38+
39+
pass
40+
41+
42+
class StackFailedToCreateException(Exception):
43+
"""Exception to indicate that the CDK for creating a model stack failed."""
44+
45+
pass
46+
47+
48+
class UnexpectedCloudFormationStateException(Exception):
49+
"""Exception to indicate that the CloudFormation stack has transitioned to a non-healthy state."""
50+
51+
pass

lambda/models/lambda_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def _create_dummy_model(model_name: str, model_type: ModelType, model_status: Mo
123123
),
124124
sharedMemorySize=2048,
125125
healthCheckConfig=ContainerHealthCheckConfig(
126-
command=["CMD-SHELL", "exit 0"], Interval=10, StartPeriod=30, Timeout=5, Retries=5
126+
command=["CMD-SHELL", "exit 0"], interval=10, startPeriod=30, timeout=5, retries=5
127127
),
128128
environment={
129129
"MAX_CONCURRENT_REQUESTS": "128",
@@ -177,7 +177,7 @@ async def update_model(
177177
) -> UpdateModelResponse:
178178
"""Endpoint to update a model."""
179179
# TODO add service to update model
180-
model = _create_dummy_model("model_name", ModelType.TEXTGEN, ModelStatus.UPDATING)
180+
model = _create_dummy_model(model_id, ModelType.TEXTGEN, ModelStatus.UPDATING)
181181
return UpdateModelResponse(model=model)
182182

183183

lambda/models/state_machine/create_model.py

Lines changed: 74 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@
2424
from botocore.config import Config
2525
from models.clients.litellm_client import LiteLLMClient
2626
from models.domain_objects import CreateModelRequest, ModelStatus
27+
from models.exception import (
28+
MaxPollsExceededException,
29+
StackFailedToCreateException,
30+
UnexpectedCloudFormationStateException,
31+
)
2732
from utilities.common_functions import get_cert_path, get_rest_api_container_endpoint, retry_config
2833

2934
lambdaConfig = Config(connect_timeout=60, read_timeout=600, retries={"max_attempts": 1})
@@ -113,8 +118,13 @@ def handle_poll_docker_image_available(event: Dict[str, Any], context: Any) -> D
113118
output_dict["image_info"]["remaining_polls"] -= 1
114119
if output_dict["image_info"]["remaining_polls"] <= 0:
115120
ec2Client.terminate_instances(InstanceIds=[event["image_info"]["instance_id"]])
116-
raise Exception(
117-
"Maximum number of ECR poll attempts reached. Something went wrong building the docker image."
121+
raise MaxPollsExceededException(
122+
json.dumps(
123+
{
124+
"error": "Max number of ECR polls reached. Docker Image was not replicated successfully.",
125+
"event": event,
126+
}
127+
)
118128
)
119129
return output_dict
120130

@@ -153,7 +163,17 @@ def camelize_object(o): # type: ignore[no-untyped-def]
153163

154164
payload = response["Payload"].read()
155165
payload = json.loads(payload)
156-
stack_name = payload.get("stackName")
166+
stack_name = payload.get("stackName", None)
167+
168+
if not stack_name:
169+
raise StackFailedToCreateException(
170+
json.dumps(
171+
{
172+
"error": "Failed to create Model CloudFormation Stack. Please validate model parameters are valid.",
173+
"event": event,
174+
}
175+
)
176+
)
157177

158178
response = cfnClient.describe_stacks(StackName=stack_name)
159179
stack_arn = response["Stacks"][0]["StackId"]
@@ -192,10 +212,24 @@ def handle_poll_create_stack(event: Dict[str, Any], context: Any) -> Dict[str, A
192212
output_dict["continue_polling_stack"] = True
193213
output_dict["remaining_polls_stack"] -= 1
194214
if output_dict["remaining_polls_stack"] <= 0:
195-
raise Exception("Maximum number of CloudFormation polls reached")
215+
raise MaxPollsExceededException(
216+
json.dumps(
217+
{
218+
"error": "Max number of CloudFormation polls reached.",
219+
"event": event,
220+
}
221+
)
222+
)
196223
return output_dict
197224
else:
198-
raise Exception(f"Stack in unexpected state: {stackStatus}")
225+
raise UnexpectedCloudFormationStateException(
226+
json.dumps(
227+
{
228+
"error": f"Stack entered unexpected state: {stackStatus}",
229+
"event": event,
230+
}
231+
)
232+
)
199233

200234

201235
def handle_add_model_to_litellm(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
@@ -234,3 +268,38 @@ def handle_add_model_to_litellm(event: Dict[str, Any], context: Any) -> Dict[str
234268
)
235269

236270
return output_dict
271+
272+
273+
def handle_failure(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
274+
"""
275+
Handle failures from state machine.
276+
277+
Possible causes of failures would be:
278+
1. Docker Image failed to replicate into ECR in expected amount of time
279+
2. CloudFormation Stack creation failed from parameter validation.
280+
3. CloudFormation Stack creation failed from taking too long to stand up.
281+
282+
Expectation of this function is to terminate the EC2 instance if it is still running, and to set the model status
283+
to Failed. Cleaning up the CloudFormation stack, if it still exists, will happen in the DeleteModel API.
284+
"""
285+
error_dict = json.loads( # error from SFN is json payload on top of json payload we add to the exception
286+
json.loads(event["Cause"])["errorMessage"]
287+
)
288+
error_reason = error_dict["error"]
289+
original_event = error_dict["event"]
290+
291+
# terminate EC2 instance if we have one recorded
292+
if "image_info" in original_event and "instance_id" in original_event["image_info"]:
293+
ec2Client.terminate_instances(InstanceIds=[original_event["image_info"]["instance_id"]])
294+
295+
# set model as Failed in DDB, so it shows as such in the UI. adds error reason as well.
296+
model_table.update_item(
297+
Key={"model_id": original_event["modelId"]},
298+
UpdateExpression="SET model_status = :ms, last_modified_date = :lm, failure_reason = :fr",
299+
ExpressionAttributeValues={
300+
":ms": ModelStatus.FAILED,
301+
":lm": int(datetime.utcnow().timestamp()),
302+
":fr": error_reason,
303+
},
304+
)
305+
return event

0 commit comments

Comments
 (0)