Skip to content

Commit 4ed9d91

Browse files
authored
Test update to serve tests to fix flakiness (#5480)
1 parent a7d470c commit 4ed9d91

File tree

2 files changed: 34 additions (+34) and 16 deletions (-16)

2 files changed: 34 additions (+34) and 16 deletions (-16)

sagemaker-serve/tests/integ/test_huggingface_integration.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,7 @@ def build_and_deploy():
9595
)
9696

9797
# Build and deploy your model. Returns SageMaker Core Model and Endpoint objects
98-
core_model = model_builder.build(
99-
model_name=f"{MODEL_NAME_PREFIX}-{unique_id}",
100-
region="us-east-1"
101-
)
98+
core_model = model_builder.build(model_name=f"{MODEL_NAME_PREFIX}-{unique_id}")
10299
logger.info(f"Model Successfully Created: {core_model.model_name}")
103100

104101
core_endpoint = model_builder.deploy(endpoint_name=f"{ENDPOINT_NAME_PREFIX}-{unique_id}")

sagemaker-serve/tests/integ/test_train_inference_e2e_integration.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,6 @@
3333
ENDPOINT_NAME_PREFIX = "train-inf-v3-test-endpoint"
3434
TRAINING_JOB_PREFIX = "e2e-v3-pytorch"
3535

36-
# Configuration
37-
AWS_REGION = "us-west-2"
38-
PYTORCH_TRAINING_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.13.1-cpu-py39"
39-
4036

4137
@pytest.mark.slow_test
4238
def test_train_inference_e2e_build_deploy_invoke_cleanup():
@@ -143,19 +139,28 @@ def create_schema_builder():
143139

144140
def train_model():
145141
"""Train model using ModelTrainer."""
142+
from sagemaker.core import image_uris
146143
from sagemaker.core.helper.session_helper import Session
147-
import boto3
148144

149-
# Create SageMaker session with AWS region
150-
boto_session = boto3.Session(region_name=AWS_REGION)
151-
sagemaker_session = Session(boto_session=boto_session)
145+
# Get the current region from a session
146+
session = Session()
147+
region = session.boto_region_name
152148

153149
training_code_dir = create_pytorch_training_code()
154150
unique_id = str(uuid.uuid4())[:8]
155151

152+
# Get training image for the current region
153+
training_image = image_uris.retrieve(
154+
framework="pytorch",
155+
region=region,
156+
version="1.13.1",
157+
py_version="py39",
158+
instance_type="ml.m5.xlarge",
159+
image_scope="training"
160+
)
161+
156162
model_trainer = ModelTrainer(
157-
sagemaker_session=sagemaker_session,
158-
training_image=PYTORCH_TRAINING_IMAGE,
163+
training_image=training_image,
159164
source_code=SourceCode(
160165
source_dir=training_code_dir,
161166
entry_script="train.py",
@@ -173,6 +178,8 @@ def train_model():
173178
def build_and_deploy(model_trainer, unique_id):
174179
"""Build and deploy model using ModelBuilder."""
175180
from sagemaker.serve.spec.inference_spec import InferenceSpec
181+
from sagemaker.core import image_uris
182+
from sagemaker.core.helper.session_helper import Session
176183

177184
class SimpleInferenceSpec(InferenceSpec):
178185
def load(self, model_dir):
@@ -185,16 +192,30 @@ def invoke(self, input_object, model):
185192

186193
schema_builder = create_schema_builder()
187194

195+
# Get the current region from a session
196+
session = Session()
197+
region = session.boto_region_name
198+
199+
# Get inference image for the current region
200+
inference_image = image_uris.retrieve(
201+
framework="pytorch",
202+
region=region,
203+
version="1.13.1",
204+
py_version="py39",
205+
instance_type="ml.m5.xlarge",
206+
image_scope="inference"
207+
)
208+
188209
model_builder = ModelBuilder(
189210
model=model_trainer,
190211
schema_builder=schema_builder,
191212
model_server=ModelServer.TORCHSERVE,
192213
inference_spec=SimpleInferenceSpec(),
193-
image_uri=PYTORCH_TRAINING_IMAGE.replace("training", "inference"),
214+
image_uri=inference_image,
194215
dependencies={"auto": False},
195216
)
196217

197-
core_model = model_builder.build(model_name=f"{MODEL_NAME_PREFIX}-{unique_id}", region="us-west-2")
218+
core_model = model_builder.build(model_name=f"{MODEL_NAME_PREFIX}-{unique_id}")
198219
logger.info(f"Model Successfully Created: {core_model.model_name}")
199220

200221
core_endpoint = model_builder.deploy(

0 commit comments

Comments (0)