Skip to content

Commit 00124a9

Browse files
authored
Add sagemaker dependency for remote function by default (#5485)
* Add sagemaker dependency for remote function by default * Revise sagemaker compatibility check * Fixing unit and integration tests * Fix codestyle issues * More codestyle fixes * Fixing one more codestyle issue * Fixing flake errors * More codestyle fixes * More flake test fixes * Fixing one more flake error
1 parent a140cfc commit 00124a9

File tree

4 files changed

+748
-22
lines changed

4 files changed

+748
-22
lines changed

src/sagemaker/remote_function/job.py

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,6 +1235,11 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
12351235

12361236
local_dependencies_path = RuntimeEnvironmentManager().snapshot(job_settings.dependencies)
12371237

1238+
# Ensure sagemaker dependency is included to prevent version mismatch issues
1239+
# Resolves issue where computing hash for integrity check changed in 2.256.0
1240+
local_dependencies_path = _ensure_sagemaker_dependency(local_dependencies_path)
1241+
job_settings.dependencies = local_dependencies_path
1242+
12381243
if step_compilation_context:
12391244
with _tmpdir() as tmp_dir:
12401245
script_and_dependencies_s3uri = _prepare_dependencies_and_pre_execution_scripts(
@@ -1291,6 +1296,225 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
12911296
return input_data_config
12921297

12931298

1299+
def _decrement_version(version_str: str) -> str:
1300+
"""Decrement a version string by one minor or patch version.
1301+
1302+
Rules:
1303+
- If patch version is 0 (e.g., 2.256.0), decrement minor: 2.256.0 -> 2.255.0
1304+
- If patch version is not 0 (e.g., 2.254.2), decrement patch: 2.254.2 -> 2.254.1
1305+
1306+
Args:
1307+
version_str: Version string (e.g., "2.256.0")
1308+
1309+
Returns:
1310+
Decremented version string
1311+
"""
1312+
from packaging import version as pkg_version
1313+
1314+
try:
1315+
parsed = pkg_version.parse(version_str)
1316+
major = parsed.major
1317+
minor = parsed.minor
1318+
patch = parsed.micro
1319+
1320+
if patch == 0:
1321+
# Decrement minor version
1322+
minor = max(0, minor - 1)
1323+
else:
1324+
# Decrement patch version
1325+
patch = max(0, patch - 1)
1326+
1327+
return f"{major}.{minor}.{patch}"
1328+
except Exception:
1329+
return version_str
1330+
1331+
1332+
def _resolve_version_from_specifier(specifier_str: str) -> str:
1333+
"""Resolve the version to check based on upper bounds.
1334+
1335+
Upper bounds take priority. If upper bound is <3.0.0, it's safe (V2 only).
1336+
If no upper bound exists, it's safe (unbounded).
1337+
If the decremented upper bound is less than a lower bound, use the lower bound.
1338+
1339+
Args:
1340+
specifier_str: Version specifier string (e.g., ">=2.256.0", "<2.256.0", "==2.255.0")
1341+
1342+
Returns:
1343+
The resolved version string to check, or None if safe
1344+
"""
1345+
import re
1346+
from packaging import version as pkg_version
1347+
1348+
# Handle exact version pinning (==)
1349+
match = re.search(r"==\s*([\d.]+)", specifier_str)
1350+
if match:
1351+
return match.group(1)
1352+
1353+
# Extract lower bounds for comparison
1354+
lower_bounds = []
1355+
for match in re.finditer(r">=\s*([\d.]+)", specifier_str):
1356+
lower_bounds.append(match.group(1))
1357+
1358+
# Handle upper bounds - find the most restrictive one
1359+
upper_bounds = []
1360+
1361+
# Find all <= bounds
1362+
for match in re.finditer(r"<=\s*([\d.]+)", specifier_str):
1363+
upper_bounds.append(("<=", match.group(1)))
1364+
1365+
# Find all < bounds
1366+
for match in re.finditer(r"<\s*([\d.]+)", specifier_str):
1367+
upper_bounds.append(("<", match.group(1)))
1368+
1369+
if upper_bounds:
1370+
# Sort by version to find the most restrictive (lowest) upper bound
1371+
upper_bounds.sort(key=lambda x: pkg_version.parse(x[1]))
1372+
operator, version = upper_bounds[0]
1373+
1374+
# Special case: if upper bound is <3.0.0, it's safe (V2 only)
1375+
try:
1376+
parsed_upper = pkg_version.parse(version)
1377+
if (
1378+
operator == "<"
1379+
and parsed_upper.major == 3
1380+
and parsed_upper.minor == 0
1381+
and parsed_upper.micro == 0
1382+
):
1383+
# <3.0.0 means V2 only, which is safe
1384+
return None
1385+
except Exception:
1386+
pass
1387+
1388+
resolved_version = version
1389+
if operator == "<":
1390+
resolved_version = _decrement_version(version)
1391+
1392+
# If we have a lower bound and the resolved version is less than it, use the lower bound
1393+
if lower_bounds:
1394+
try:
1395+
resolved_parsed = pkg_version.parse(resolved_version)
1396+
for lower_bound_str in lower_bounds:
1397+
lower_parsed = pkg_version.parse(lower_bound_str)
1398+
if resolved_parsed < lower_parsed:
1399+
resolved_version = lower_bound_str
1400+
except Exception:
1401+
pass
1402+
1403+
return resolved_version
1404+
1405+
# For lower bounds only (>=, >), we don't check
1406+
return None
1407+
1408+
1409+
def _check_sagemaker_version_compatibility(sagemaker_requirement: str) -> None:
1410+
"""Check if the sagemaker version requirement uses incompatible hashing.
1411+
1412+
Raises ValueError if the requirement would install a version that uses HMAC hashing
1413+
(which is incompatible with the current SHA256-based integrity checks).
1414+
1415+
Args:
1416+
sagemaker_requirement: The sagemaker requirement string (e.g., "sagemaker>=2.200.0")
1417+
1418+
Raises:
1419+
ValueError: If the requirement would install a version using HMAC hashing
1420+
"""
1421+
import re
1422+
from packaging import version as pkg_version
1423+
1424+
match = re.search(r"sagemaker\s*(.+)$", sagemaker_requirement.strip(), re.IGNORECASE)
1425+
if not match:
1426+
return
1427+
1428+
specifier_str = match.group(1).strip()
1429+
1430+
# Resolve the version that would be installed
1431+
resolved_version_str = _resolve_version_from_specifier(specifier_str)
1432+
if not resolved_version_str:
1433+
# No upper bound or exact version, so we can't determine if it's bad
1434+
return
1435+
1436+
try:
1437+
resolved_version = pkg_version.parse(resolved_version_str)
1438+
except Exception:
1439+
return
1440+
1441+
# Define HMAC thresholds for each major version
1442+
v2_hmac_threshold = pkg_version.parse("2.256.0")
1443+
v3_hmac_threshold = pkg_version.parse("3.2.0")
1444+
1445+
# Check if the resolved version uses HMAC hashing
1446+
uses_hmac = False
1447+
if resolved_version.major == 2 and resolved_version < v2_hmac_threshold:
1448+
uses_hmac = True
1449+
elif resolved_version.major == 3 and resolved_version < v3_hmac_threshold:
1450+
uses_hmac = True
1451+
1452+
if uses_hmac:
1453+
raise ValueError(
1454+
f"The sagemaker version specified in requirements.txt ({sagemaker_requirement}) "
1455+
f"could install a version using HMAC-based integrity checks which are incompatible "
1456+
f"with the current SHA256-based integrity checks. Please update to "
1457+
f"sagemaker>=2.256.0,<3.0.0 (for V2) or sagemaker>=3.2.0,<4.0.0 (for V3)."
1458+
)
1459+
1460+
1461+
def _ensure_sagemaker_dependency(local_dependencies_path: str) -> str:
1462+
"""Ensure sagemaker>=2.256.0 is in the dependencies.
1463+
1464+
This function ensures that the remote environment has a compatible version of sagemaker
1465+
that includes the fix for the HMAC key security issue. Versions < 2.256.0 use HMAC-based
1466+
integrity checks which require the REMOTE_FUNCTION_SECRET_KEY environment variable.
1467+
Versions >= 2.256.0 use SHA256-based integrity checks which are secure and don't require
1468+
the secret key.
1469+
1470+
If no dependencies are provided, creates a temporary requirements.txt with sagemaker.
1471+
If dependencies are provided, appends sagemaker if not already present.
1472+
1473+
Args:
1474+
local_dependencies_path: Path to user's dependencies file or None
1475+
1476+
Returns:
1477+
Path to the dependencies file (created or modified)
1478+
1479+
Raises:
1480+
ValueError: If user has pinned sagemaker to a version using HMAC hashing
1481+
"""
1482+
import tempfile
1483+
1484+
SAGEMAKER_MIN_VERSION = "sagemaker>=2.256.0,<3.0.0"
1485+
1486+
if local_dependencies_path is None:
1487+
# Create a temporary requirements.txt in the system temp directory
1488+
fd, req_file = tempfile.mkstemp(suffix=".txt", prefix="sagemaker_requirements_")
1489+
os.close(fd)
1490+
1491+
with open(req_file, "w") as f:
1492+
f.write(f"{SAGEMAKER_MIN_VERSION}\n")
1493+
logger.info(
1494+
"Created temporary requirements.txt at %s with %s", req_file, SAGEMAKER_MIN_VERSION
1495+
)
1496+
return req_file
1497+
1498+
# If dependencies provided, ensure sagemaker is included
1499+
if local_dependencies_path.endswith(".txt"):
1500+
with open(local_dependencies_path, "r") as f:
1501+
content = f.read()
1502+
1503+
# Check if sagemaker is already specified
1504+
if "sagemaker" in content.lower():
1505+
# Extract the sagemaker requirement line for compatibility check
1506+
for line in content.split("\n"):
1507+
if "sagemaker" in line.lower():
1508+
_check_sagemaker_version_compatibility(line.strip())
1509+
break
1510+
else:
1511+
with open(local_dependencies_path, "a") as f:
1512+
f.write(f"\n{SAGEMAKER_MIN_VERSION}\n")
1513+
logger.info("Appended %s to requirements.txt", SAGEMAKER_MIN_VERSION)
1514+
1515+
return local_dependencies_path
1516+
1517+
12941518
def _prepare_dependencies_and_pre_execution_scripts(
12951519
local_dependencies_path: str,
12961520
pre_execution_commands: List[str],
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
"""Integration tests for sagemaker dependency injection in remote functions.
2+
3+
These tests verify that the sagemaker>=2.256.0 dependency is properly injected
4+
into remote function jobs, preventing version mismatch issues.
5+
"""
6+
7+
from __future__ import absolute_import
8+
9+
import os
10+
import sys
11+
import tempfile
12+
13+
import pytest
14+
15+
# Add src to path before importing sagemaker
16+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../../src"))
17+
18+
from sagemaker.remote_function import remote # noqa: E402
19+
20+
# Skip decorator for AWS configuration
21+
# Skips a test when the AWS_DEFAULT_REGION environment variable is unset or empty.
# The reason names the variable that is actually checked, so a skipped run points
# at the right thing to configure.
skip_if_no_aws_region = pytest.mark.skipif(
    not os.environ.get("AWS_DEFAULT_REGION"), reason="AWS_DEFAULT_REGION is not set"
)
24+
25+
26+
class TestRemoteFunctionDependencyInjection:
    """End-to-end checks that remote functions run with the injected sagemaker dependency."""

    @pytest.mark.integ
    @skip_if_no_aws_region
    def test_remote_function_without_dependencies(self):
        """Run a remote function that declares no dependencies at all.

        The decorated function is invoked once and its return value compared
        with the locally known answer.
        """

        # No dependencies specified - sagemaker should be injected automatically
        @remote(instance_type="ml.m5.large")
        def simple_add(x, y):
            """Simple function that adds two numbers."""
            return x + y

        total = simple_add(5, 3)

        assert total == 8, f"Expected 8, got {total}"
        print("✓ Remote function without dependencies executed successfully")

    @pytest.mark.integ
    @skip_if_no_aws_region
    def test_remote_function_with_user_dependencies_no_sagemaker(self):
        """Run a remote function whose requirements file does not list sagemaker.

        A temporary requirements file with numpy and pandas is handed to
        ``remote``; the file is removed again whether or not the run succeeds.
        """
        # Requirements file that deliberately omits sagemaker
        handle = tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False)
        with handle:
            handle.write("numpy>=1.20.0\npandas>=1.3.0\n")
        requirements_path = handle.name

        try:

            @remote(
                instance_type="ml.m5.large",
                dependencies=requirements_path,
            )
            def compute_with_numpy(x):
                """Function that uses numpy."""
                import numpy as np

                return np.array([x, x * 2, x * 3]).sum()

            total = compute_with_numpy(5)

            # 5 + 10 + 15 = 30
            assert total == 30, f"Expected 30, got {total}"
            print("✓ Remote function with user dependencies executed successfully")
        finally:
            os.remove(requirements_path)
87+
88+
89+
class TestRemoteFunctionVersionCompatibility:
    """Checks that results round-trip between the local and the remote environment."""

    @pytest.mark.integ
    @skip_if_no_aws_region
    def test_deserialization_with_injected_sagemaker(self):
        """Send a list argument to a remote function and verify the returned value."""

        @remote(instance_type="ml.m5.large")
        def complex_computation(data):
            """Function that performs complex computation."""
            result = sum(data) * len(data)
            return result

        values = [1, 2, 3, 4, 5]
        outcome = complex_computation(values)

        # sum=15, len=5, 15*5=75
        assert outcome == 75, f"Expected 75, got {outcome}"
        print("✓ Deserialization with injected sagemaker works correctly")

    @pytest.mark.integ
    @skip_if_no_aws_region
    def test_multiple_remote_functions_with_dependencies(self):
        """Define two remote functions in one test and run them one after the other."""

        @remote(instance_type="ml.m5.large")
        def func1(x):
            return x + 1

        @remote(instance_type="ml.m5.large")
        def func2(x):
            return x * 2

        incremented = func1(5)
        doubled = func2(5)

        assert incremented == 6, f"func1: Expected 6, got {incremented}"
        assert doubled == 10, f"func2: Expected 10, got {doubled}"
        print("✓ Multiple remote functions with dependencies executed successfully")
141+
142+
143+
# When executed as a script, run pytest on this file only, verbosely, restricted
# to the tests carrying the "integ" marker.
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-m", "integ"])

0 commit comments

Comments
 (0)