Skip to content

Commit 49e53e8

Browse files
Chore: Add unit and integ tests for already upgraded search functionality in jumpstart code (#5544)
* Chore: Add unit and integ tests for already upgraded search functionality in jumpstart code
* fix existing unit test error with sagemaker-train evaluate test execution
* fix existing unit test error with sagemaker-train evaluate test execution 2
1 parent b839617 commit 49e53e8

File tree

3 files changed

+106
-4
lines changed

3 files changed

+106
-4
lines changed

sagemaker-core/tests/integ/jumpstart/test_search_integ.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,19 @@ def test_search_public_hub_models_all_args():
6666

6767
assert isinstance(results, list)
6868
assert all(isinstance(m, HubContent) for m in results)
69+
70+
71+
@pytest.mark.serial
@pytest.mark.integ
def test_search_public_hub_models_safe_from_injection():
    """Check that a code-like query is treated as a plain, non-matching filter.

    The payload below is the kind of string an ``eval()``-based filter would
    have executed; the search is expected to return an empty list instead.
    """
    payload = "__import__('os').system('echo test')"

    found = search_public_hub_models(payload)

    # The call has to complete and hand back a list ...
    assert isinstance(found, list)
    # ... and that list must be empty, since the payload is not a valid
    # filter expression and therefore matches no model.
    assert not found

sagemaker-core/tests/unit/jumpstart/test_search_unit.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,85 @@ def test_filter_match(query, keywords, expected):
4747
assert f.match(keywords) == expected
4848

4949

50+
# Query strings that an eval()-based filter would have evaluated as Python.
_CODE_INJECTION_QUERIES = (
    # Direct code execution through builtins
    "__import__('os').system('echo pwned')",
    "exec('import os; os.system(\"ls\")')",
    "eval('1+1')",
    "__builtins__.__import__('os').system('ls')",
    # Attribute-chain traversal
    "keywords.__class__.__bases__[0].__subclasses__()",
    # Immediately-invoked lambda
    "(lambda: __import__('os').system('ls'))()",
    # List comprehension
    "[x for x in ().__class__.__bases__[0].__subclasses__()]",
    # Plain function call
    "open('/etc/passwd').read()",
    # Module attribute access
    "sys.exit()",
    "os.system('ls')",
)


@pytest.mark.parametrize(
    "malicious_query,keywords",
    [(query, ["test"]) for query in _CODE_INJECTION_QUERIES],
)
def test_filter_blocks_code_injection(malicious_query, keywords):
    """Each code-like query must evaluate to a plain ``False`` match."""
    outcome = _Filter(malicious_query).match(keywords)

    # The filter is expected to report "no match" rather than run the payload.
    assert isinstance(outcome, bool)
    assert outcome is False
79+
80+
81+
@pytest.mark.parametrize(
    "injection_query",
    [
        # Various eval-based injection patterns
        "'; __import__('os').system('ls'); '",
        "\"; exec('import os'); \"",
        "') or __import__('os').system('ls') or ('",
        # Nested injection attempts
        "test AND (__import__('os').system('ls'))",
        "NOT (__import__('subprocess').call(['ls']))",
        # String escape attempts
        "test' + str(__import__('os').system('ls')) + '",
    ],
)
def test_filter_injection_variants(injection_query):
    """Test that injection-style queries are matched without running their payload.

    Args:
        injection_query: Filter string embedding a Python expression that
            would call ``os.system`` or ``subprocess.call`` if it were
            evaluated as code.
    """
    # Local import keeps this test self-contained regardless of what the
    # module already imports at the top.
    from unittest.mock import patch

    # Replace the process-spawning entry points used by the payloads so that
    # an accidental evaluation is both harmless and observable.
    with patch("os.system") as fake_system, patch("subprocess.call") as fake_call:
        result = _Filter(injection_query).match(["test", "keyword"])

    # Either outcome is acceptable here; it only has to be a real boolean.
    assert isinstance(result, bool)
    # The payload must never have been executed.
    fake_system.assert_not_called()
    fake_call.assert_not_called()
102+
103+
104+
def test_filter_no_eval_execution():
    """A ``sys.exit`` payload must yield ``False`` from ``match``.

    If the expression were passed to ``eval()``, the interpreter would be
    asked to exit instead of returning a match result.
    """
    exit_payload = "__import__('sys').exit(1)"

    matched = _Filter(exit_payload).match(["test"])

    # Reaching this line at all means the exit call was not executed.
    assert matched is False
113+
114+
115+
def test_filter_safe_ast_parsing():
    """The filter builds its expression tree lazily, on the first ``match`` call."""
    from sagemaker.core.jumpstart.search import _ExpressionNode

    expr_filter = _Filter("test AND keyword")

    # Nothing has been parsed before the first match.
    assert expr_filter._ast is None

    expr_filter.match(["test", "keyword"])

    # After matching, the tree exists and is an expression node rather than
    # a string that could be handed to eval().
    parsed = expr_filter._ast
    assert parsed is not None
    assert isinstance(parsed, _ExpressionNode)
127+
128+
50129
def test_search_public_hub_models():
51130
mock_models = [
52131
HubContent(

sagemaker-train/tests/unit/train/evaluate/test_execution.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,17 @@ class MockUnassigned:
5757

5858
@pytest.fixture
def mock_session():
    """Provide a MagicMock session that is accepted by ``isinstance(obj, Session)``."""
    from sagemaker.core.helper.session_helper import Session

    fake_session = MagicMock()
    fake_session.boto_region_name = DEFAULT_REGION
    fake_session.client.return_value = MagicMock()

    # Point the mock's __class__ at a throwaway Session subclass so that
    # isinstance checks against Session succeed.
    session_subclass = type('MockSession', (Session,), {})
    fake_session.__class__ = session_subclass

    return fake_session
6572

6673

@@ -247,7 +254,7 @@ def test_extract_with_exception(self):
247254
class TestGetOrCreatePipeline:
248255
"""Tests for _get_or_create_pipeline function."""
249256

250-
@patch("sagemaker.train.evaluate.execution.Tag")
257+
@patch("sagemaker.train.evaluate.execution.ResourceTag")
251258
@patch("sagemaker.train.evaluate.execution.Pipeline")
252259
def test_get_existing_pipeline_and_update(self, mock_pipeline_class, mock_tag_class, mock_session):
253260
"""Test getting and updating existing pipeline via Pipeline.get_all with prefix."""
@@ -757,7 +764,7 @@ def test_get_execution_generic_exception(self, mock_pe_class, mock_session):
757764
class TestEvaluationPipelineExecutionGetAll:
758765
"""Tests for EvaluationPipelineExecution.get_all() method."""
759766

760-
@patch("sagemaker.train.evaluate.execution.Tag")
767+
@patch("sagemaker.train.evaluate.execution.ResourceTag")
761768
@patch("sagemaker.train.evaluate.execution.Pipeline")
762769
@patch("sagemaker.train.evaluate.execution.PipelineExecution")
763770
def test_get_all_executions(self, mock_pe_class, mock_pipeline_class, mock_tag_class, mock_session):
@@ -800,7 +807,7 @@ def test_get_all_executions(self, mock_pe_class, mock_pipeline_class, mock_tag_c
800807
# Verify PipelineExecution.get_all was called with the pipeline name
801808
mock_pe_class.get_all.assert_called_once()
802809

803-
@patch("sagemaker.train.evaluate.execution.Tag")
810+
@patch("sagemaker.train.evaluate.execution.ResourceTag")
804811
@patch("sagemaker.train.evaluate.execution.Pipeline")
805812
@patch("sagemaker.train.evaluate.execution.PipelineExecution")
806813
def test_get_all_multiple_eval_types(self, mock_pe_class, mock_pipeline_class, mock_tag_class, mock_session):

0 commit comments

Comments
 (0)