Skip to content

Commit b839617

Browse files
authored
fix: correct Tag class usage in pipeline creation (#5526)
Pipeline creation was failing with Pydantic validation errors when BenchmarkEvaluator attempted to create a new SageMaker Pipeline. This occurred because the code imported Tag from sagemaker.core.resources instead of sagemaker.core.shapes, which is what Pipeline.create() expects for its tags parameter. Root Cause: The SDK has two different Tag classes: - sagemaker.core.resources.Tag: Used for Tag.get_all() operations - sagemaker.core.shapes.Tag: Used for Pipeline.create() parameter Changes: - Import Tag from sagemaker.core.shapes for Pipeline.create() - Import Tag as ResourceTag from sagemaker.core.resources for Tag.get_all() - Create proper Tag objects instead of dicts - Add error handling for tag conversion - Update Tag.get_all() calls to use ResourceTag Impact: This fixes benchmark evaluation failures (MMLU_PRO, BBH, GPQA, etc.) when creating new pipelines. Testing: Verified both creating new pipeline and reusing existing pipeline.
1 parent e993405 commit b839617

File tree

1 file changed

+35
-9
lines changed

1 file changed

+35
-9
lines changed

sagemaker-train/src/sagemaker/train/evaluate/execution.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
from pydantic import BaseModel, Field
1919
from sagemaker.core.common_utils import TagsDict
2020
from sagemaker.core.helper.session_helper import Session
21-
from sagemaker.core.resources import Pipeline, PipelineExecution, Tag
21+
from sagemaker.core.resources import Pipeline, PipelineExecution
22+
from sagemaker.core.resources import Tag as ResourceTag # For Tag.get_all()
23+
from sagemaker.core.shapes import Tag # For Pipeline.create() tags parameter
2224
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
2325
from sagemaker.core.telemetry.constants import Feature
2426

@@ -68,9 +70,33 @@ def _create_evaluation_pipeline(
6870
resolved_pipeline_definition = template.render(pipeline_name=pipeline_name)
6971

7072
# Create tags for the pipeline
71-
tags.extend([
72-
{"key": _TAG_SAGEMAKER_MODEL_EVALUATION, "value": "true"}
73-
])
73+
# Note: Tags must be Tag objects, not dicts, for Pydantic validation to pass
74+
tag_objects = []
75+
76+
# Add evaluation tag
77+
tag_objects.append(Tag(key=_TAG_SAGEMAKER_MODEL_EVALUATION, value="true"))
78+
79+
# Process any additional tags passed in
80+
if tags:
81+
for i, tag_item in enumerate(tags):
82+
try:
83+
if hasattr(tag_item, '__class__') and 'Tag' in tag_item.__class__.__name__:
84+
# Already a Tag object
85+
tag_objects.append(tag_item)
86+
elif isinstance(tag_item, dict):
87+
# Convert dict to Tag object - handle both lowercase and capitalized keys
88+
key = tag_item.get("key") or tag_item.get("Key")
89+
value = tag_item.get("value") or tag_item.get("Value")
90+
if key and value:
91+
tag_objects.append(Tag(key=str(key), value=str(value)))
92+
else:
93+
logger.warning(f"Skipping invalid tag at index {i}: {tag_item}")
94+
else:
95+
logger.warning(f"Skipping unsupported tag type at index {i}: {type(tag_item)}")
96+
except Exception as e:
97+
logger.warning(f"Error processing tag at index {i}: {e}")
98+
99+
logger.info(f"Creating pipeline with {len(tag_objects)} tags")
74100

75101
pipeline = Pipeline.create(
76102
pipeline_name=pipeline_name,
@@ -79,7 +105,7 @@ def _create_evaluation_pipeline(
79105
pipeline_definition=resolved_pipeline_definition,
80106
pipeline_display_name=f"EvaluationPipeline-{eval_type.value}",
81107
pipeline_description=f"Pipeline for {eval_type.value} evaluation jobs",
82-
tags=tags,
108+
tags=tag_objects,
83109
session=session,
84110
region=region
85111
)
@@ -205,8 +231,8 @@ def _get_or_create_pipeline(
205231
for pipeline in pipelines:
206232
pipeline_arn = pipeline.pipeline_arn
207233

208-
# Get tags using Tag.get_all
209-
tags_list = Tag.get_all(resource_arn=pipeline_arn, session=session, region=region)
234+
# Get tags using ResourceTag.get_all
235+
tags_list = ResourceTag.get_all(resource_arn=pipeline_arn, session=session, region=region)
210236
tags = {tag.key: tag.value for tag in tags_list}
211237

212238
# Validate tag
@@ -647,8 +673,8 @@ def get_all(
647673
try:
648674
pipeline_arn = pipeline.pipeline_arn
649675

650-
# Get tags using Tag.get_all
651-
tags_list = Tag.get_all(resource_arn=pipeline_arn, session=session, region=region)
676+
# Get tags using ResourceTag.get_all
677+
tags_list = ResourceTag.get_all(resource_arn=pipeline_arn, session=session, region=region)
652678
tags = {tag.key: tag.value for tag in tags_list}
653679

654680
# Validate tag - only process evaluation pipelines

0 commit comments

Comments
 (0)