Skip to content

Commit 58d80cd

Browse files
authored
fix(integ): remove module-level env var from MTRL tests (#5929)
The 4 MTRL test files added in #5919 set SAGEMAKER_REGION=us-west-2 at module level via os.environ.setdefault(). When pytest collects tests, it imports all modules in the directory—including these files—even when filtering by mark (e.g. -m "gpu_intensive and us_east_1"). This poisons the SageMakerClient singleton's region for the entire process, causing the Nova us-east-1 integ tests (SFT, RLVR, BenchmarkEvaluator) to fail with "ModelPackageGroup does not exist" because API calls land in us-west-2 instead of us-east-1. Changes: - Remove os.environ.setdefault() calls from module level in all 4 files - Move _ACCOUNT_ID resolution from module-level boto3 calls into lazy fixtures/functions (avoids import-time side effects) - Use boto3.Session().client() pattern instead of boto3.client() to follow SDK conventions (wrapper over boto, not direct boto usage) - Align with existing test patterns (DPO, SFT, RLVR) that rely on conftest fixtures for session management Affected files: - test_mtrl_evaluator.py - test_mtrl_evaluator_3p_agent.py - test_mtrl_trainer_integration.py - test_multi_turn_rl_trainer_integration.py
1 parent b82b3aa commit 58d80cd

4 files changed

Lines changed: 196 additions & 176 deletions

File tree

sagemaker-train/tests/integ/train/test_mtrl_evaluator.py

Lines changed: 82 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,6 @@
2222
import pytest
2323
import logging
2424

25-
os.environ.setdefault("AWS_DEFAULT_REGION", "us-west-2")
26-
os.environ.setdefault("SAGEMAKER_REGION", "us-west-2")
27-
2825
import boto3
2926
from sagemaker.core.helper.session_helper import Session
3027
from sagemaker.train.evaluate import MultiTurnRLEvaluator
@@ -36,22 +33,30 @@
3633
# Timeout for evaluation pipeline execution (4 hours)
3734
EVALUATION_TIMEOUT_SECONDS = 14400
3835

39-
# Resolve current account ID for account-agnostic paths
4036
_REGION = "us-west-2"
41-
_ACCOUNT_ID = boto3.client("sts", region_name=_REGION).get_caller_identity()["Account"]
42-
43-
# Test configuration — uses current account for all resource paths
44-
TEST_CONFIG = {
45-
#"base_model": "huggingface-vlm-qwen3-6-27b",
46-
"base_model": "openai-reasoning-gpt-oss-20b",
47-
"agent_arn": f"arn:aws:bedrock-agentcore:{_REGION}:{_ACCOUNT_ID}:runtime/sagemaker_rft_prod_gsm8k_streaming-Yk6O377mUS",
48-
"dataset": f"s3://sagemaker-rft-{_ACCOUNT_ID}/prompts/gsm8k_small/prompts.parquet",
49-
"s3_output_path": f"s3://sagemaker-{_REGION}-{_ACCOUNT_ID}/model-evaluation/output-artifacts/",
50-
"mlflow_resource_arn": f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:mlflow-app/app-TTAUWUNMUHH6",
51-
"model_package_group": f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:model-package-group/openai-reasoning-gpt-oss-20b-mtrl-mpg",
52-
"role": f"arn:aws:iam::{_ACCOUNT_ID}:role/Admin",
53-
"region": _REGION,
54-
}
37+
38+
39+
def _get_test_config():
40+
"""Build test configuration lazily (only when tests actually run)."""
41+
boto_session = boto3.Session(region_name=_REGION)
42+
account_id = boto_session.client("sts").get_caller_identity()["Account"]
43+
return {
44+
"base_model": "openai-reasoning-gpt-oss-20b",
45+
"agent_arn": f"arn:aws:bedrock-agentcore:{_REGION}:{account_id}:runtime/sagemaker_rft_prod_gsm8k_streaming-Yk6O377mUS",
46+
"dataset": f"s3://sagemaker-rft-{account_id}/prompts/gsm8k_small/prompts.parquet",
47+
"s3_output_path": f"s3://sagemaker-{_REGION}-{account_id}/model-evaluation/output-artifacts/",
48+
"mlflow_resource_arn": f"arn:aws:sagemaker:{_REGION}:{account_id}:mlflow-app/app-TTAUWUNMUHH6",
49+
"model_package_group": f"arn:aws:sagemaker:{_REGION}:{account_id}:model-package-group/openai-reasoning-gpt-oss-20b-mtrl-mpg",
50+
"role": f"arn:aws:iam::{account_id}:role/Admin",
51+
"region": _REGION,
52+
"account_id": account_id,
53+
}
54+
55+
56+
@pytest.fixture(scope="module")
57+
def test_config():
58+
"""Lazily resolve test configuration (avoids module-level API calls)."""
59+
return _get_test_config()
5560

5661

5762
def _ensure_model_package_group_exists(sm_client, group_name):
@@ -84,35 +89,35 @@ def _ensure_model_package_exists(sm_client, group_name, base_model_name):
8489

8590

8691
@pytest.fixture(scope="module")
87-
def sagemaker_session():
92+
def sagemaker_session_mtrl():
8893
"""Create a SageMaker session with explicit region for CI environments."""
89-
boto_session = boto3.Session(region_name=TEST_CONFIG["region"])
94+
boto_session = boto3.Session(region_name=_REGION)
9095
return Session(boto_session=boto_session)
9196

9297

9398
@pytest.fixture(scope="module")
94-
def mtrl_trainer(sagemaker_session):
99+
def mtrl_trainer(sagemaker_session_mtrl, test_config):
95100
"""Create a lightweight MultiTurnRLTrainer-like object for evaluator tests.
96101
97102
Instead of going through the full constructor (which validates remote
98103
resources), we build a minimal object with the attributes the evaluator
99104
needs. This makes the test account-agnostic — it creates the required
100105
resources (model package group + model package) on the fly.
101106
"""
102-
sm_client = sagemaker_session.boto_session.client("sagemaker")
107+
sm_client = sagemaker_session_mtrl.boto_session.client("sagemaker")
103108
group_name = "mtrl-integ-test-evaluator"
104109
_ensure_model_package_group_exists(sm_client, group_name)
105110
model_package_arn = _ensure_model_package_exists(
106-
sm_client, group_name, TEST_CONFIG["base_model"]
111+
sm_client, group_name, test_config["base_model"]
107112
)
108113

109114
trainer = object.__new__(MultiTurnRLTrainer)
110-
trainer._model_name = TEST_CONFIG["base_model"]
111-
trainer._model_arn = f"arn:aws:sagemaker:{_REGION}:aws:hub-content/SageMakerPublicHub/Model/{TEST_CONFIG['base_model']}/1.0.0"
112-
trainer.agent_env = TEST_CONFIG["agent_arn"]
115+
trainer._model_name = test_config["base_model"]
116+
trainer._model_arn = f"arn:aws:sagemaker:{_REGION}:aws:hub-content/SageMakerPublicHub/Model/{test_config['base_model']}/1.0.0"
117+
trainer.agent_env = test_config["agent_arn"]
113118
trainer.bedrock_agentcore_qualifier = "DEFAULT"
114-
trainer.output_model_package_group = TEST_CONFIG["model_package_group"]
115-
trainer.sagemaker_session = sagemaker_session
119+
trainer.output_model_package_group = test_config["model_package_group"]
120+
trainer.sagemaker_session = sagemaker_session_mtrl
116121

117122
# Use the real model package ARN from the account
118123
class _FakeJob:
@@ -131,27 +136,27 @@ class _FakeJob:
131136
class TestMTRLEvaluatorJobConfigDocument:
132137
"""Tests validating the JobConfigDocument field naming for GA API contract."""
133138

134-
def test_bedrock_agent_config_fields(self, mtrl_trainer):
139+
def test_bedrock_agent_config_fields(self, mtrl_trainer, test_config):
135140
"""Verify BedrockAgentCoreConfig uses AgentRuntimeArn and Qualifier."""
136141
evaluator = MultiTurnRLEvaluator(
137142
model=mtrl_trainer,
138-
dataset=TEST_CONFIG["dataset"],
139-
s3_output_path=f'{TEST_CONFIG["s3_output_path"]}integ-fields-bedrock/',
140-
mlflow_resource_arn=TEST_CONFIG["mlflow_resource_arn"],
141-
role=TEST_CONFIG["role"],
142-
region=TEST_CONFIG["region"],
143-
agent_config=TEST_CONFIG["agent_arn"],
143+
dataset=test_config["dataset"],
144+
s3_output_path=f'{test_config["s3_output_path"]}integ-fields-bedrock/',
145+
mlflow_resource_arn=test_config["mlflow_resource_arn"],
146+
role=test_config["role"],
147+
region=test_config["region"],
148+
agent_config=test_config["agent_arn"],
144149
agent_qualifier="PROD",
145150
)
146151

147152
evaluator._resolve_trainer_defaults()
148153
evaluator._resolve_agent_arn()
149154

150155
ctx = evaluator._build_template_context(
151-
aws_context={"region": TEST_CONFIG["region"], "account_id": _ACCOUNT_ID,
152-
"role_arn": TEST_CONFIG["role"]},
156+
aws_context={"region": test_config["region"], "account_id": test_config["account_id"],
157+
"role_arn": test_config["role"]},
153158
artifacts={},
154-
model_package_group_arn=TEST_CONFIG["model_package_group"],
159+
model_package_group_arn=test_config["model_package_group"],
155160
)
156161

157162
doc = json.loads(ctx["job_config_document_ft_str"])
@@ -166,27 +171,27 @@ def test_bedrock_agent_config_fields(self, mtrl_trainer):
166171
assert "AgentArn" not in agent_cfg.get("BedrockAgentCoreConfig", {})
167172
assert "BedrockAgentCoreQualifier" not in agent_cfg.get("BedrockAgentCoreConfig", {})
168173

169-
def test_lambda_agent_config_fields(self, mtrl_trainer):
174+
def test_lambda_agent_config_fields(self, mtrl_trainer, test_config):
170175
"""Verify Lambda agent uses CustomAgentLambdaConfig (not LambdaConfig)."""
171176
lambda_arn = "arn:aws:lambda:us-east-1:060795915353:function:SageMaker-agent-adapter-gsm8k"
172177
evaluator = MultiTurnRLEvaluator(
173178
model=mtrl_trainer,
174-
dataset=TEST_CONFIG["dataset"],
175-
s3_output_path=f'{TEST_CONFIG["s3_output_path"]}integ-fields-lambda/',
176-
mlflow_resource_arn=TEST_CONFIG["mlflow_resource_arn"],
177-
role=TEST_CONFIG["role"],
178-
region=TEST_CONFIG["region"],
179+
dataset=test_config["dataset"],
180+
s3_output_path=f'{test_config["s3_output_path"]}integ-fields-lambda/',
181+
mlflow_resource_arn=test_config["mlflow_resource_arn"],
182+
role=test_config["role"],
183+
region=test_config["region"],
179184
agent_config=lambda_arn,
180185
)
181186

182187
evaluator._resolve_trainer_defaults()
183188
evaluator._resolve_agent_arn()
184189

185190
ctx = evaluator._build_template_context(
186-
aws_context={"region": TEST_CONFIG["region"], "account_id": _ACCOUNT_ID,
187-
"role_arn": TEST_CONFIG["role"]},
191+
aws_context={"region": test_config["region"], "account_id": test_config["account_id"],
192+
"role_arn": test_config["role"]},
188193
artifacts={},
189-
model_package_group_arn=TEST_CONFIG["model_package_group"],
194+
model_package_group_arn=test_config["model_package_group"],
190195
)
191196

192197
doc = json.loads(ctx["job_config_document_ft_str"])
@@ -198,26 +203,26 @@ def test_lambda_agent_config_fields(self, mtrl_trainer):
198203
# Ensure old field name is NOT present
199204
assert "LambdaConfig" not in agent_cfg
200205

201-
def test_model_package_config_fields(self, mtrl_trainer):
206+
def test_model_package_config_fields(self, mtrl_trainer, test_config):
202207
"""Verify ModelPackageConfig uses InputModelPackageArn only (no OutputModelPackageGroupArn for eval)."""
203208
evaluator = MultiTurnRLEvaluator(
204209
model=mtrl_trainer,
205-
dataset=TEST_CONFIG["dataset"],
206-
s3_output_path=f'{TEST_CONFIG["s3_output_path"]}integ-fields-mpc/',
207-
mlflow_resource_arn=TEST_CONFIG["mlflow_resource_arn"],
208-
role=TEST_CONFIG["role"],
209-
region=TEST_CONFIG["region"],
210-
agent_config=TEST_CONFIG["agent_arn"],
210+
dataset=test_config["dataset"],
211+
s3_output_path=f'{test_config["s3_output_path"]}integ-fields-mpc/',
212+
mlflow_resource_arn=test_config["mlflow_resource_arn"],
213+
role=test_config["role"],
214+
region=test_config["region"],
215+
agent_config=test_config["agent_arn"],
211216
)
212217

213218
evaluator._resolve_trainer_defaults()
214219
evaluator._resolve_agent_arn()
215220

216221
ctx = evaluator._build_template_context(
217-
aws_context={"region": TEST_CONFIG["region"], "account_id": _ACCOUNT_ID,
218-
"role_arn": TEST_CONFIG["role"]},
222+
aws_context={"region": test_config["region"], "account_id": test_config["account_id"],
223+
"role_arn": test_config["role"]},
219224
artifacts={},
220-
model_package_group_arn=TEST_CONFIG["model_package_group"],
225+
model_package_group_arn=test_config["model_package_group"],
221226
)
222227

223228
doc = json.loads(ctx["job_config_document_ft_str"])
@@ -239,41 +244,41 @@ class TestMTRLEvaluatorIntegration:
239244
in accounts with the feature flag enabled (e.g., 742774200982).
240245
"""
241246

242-
def test_evaluator_construction_with_trainer(self, mtrl_trainer):
247+
def test_evaluator_construction_with_trainer(self, mtrl_trainer, test_config):
243248
"""Test that MultiTurnRLEvaluator can be constructed from a trainer."""
244249
evaluator = MultiTurnRLEvaluator(
245250
model=mtrl_trainer,
246-
dataset=TEST_CONFIG["dataset"],
247-
s3_output_path=f'{TEST_CONFIG["s3_output_path"]}integ-construct/',
248-
mlflow_resource_arn=TEST_CONFIG["mlflow_resource_arn"],
249-
role=TEST_CONFIG["role"],
250-
region=TEST_CONFIG["region"],
251-
agent_config=TEST_CONFIG["agent_arn"],
251+
dataset=test_config["dataset"],
252+
s3_output_path=f'{test_config["s3_output_path"]}integ-construct/',
253+
mlflow_resource_arn=test_config["mlflow_resource_arn"],
254+
role=test_config["role"],
255+
region=test_config["region"],
256+
agent_config=test_config["agent_arn"],
252257
)
253258

254259
assert evaluator is not None
255260
assert evaluator.model is mtrl_trainer
256-
assert evaluator.dataset == TEST_CONFIG["dataset"]
257-
assert evaluator.region == TEST_CONFIG["region"]
261+
assert evaluator.dataset == test_config["dataset"]
262+
assert evaluator.region == test_config["region"]
258263

259-
def test_evaluator_construction_with_base_model(self):
264+
def test_evaluator_construction_with_base_model(self, test_config):
260265
"""Test that MultiTurnRLEvaluator can be constructed from a base model string."""
261266
evaluator = MultiTurnRLEvaluator(
262-
model=TEST_CONFIG["base_model"],
263-
dataset=TEST_CONFIG["dataset"],
264-
s3_output_path=f'{TEST_CONFIG["s3_output_path"]}integ-base/',
265-
agent_config=TEST_CONFIG["agent_arn"],
266-
mlflow_resource_arn=TEST_CONFIG["mlflow_resource_arn"],
267-
role=TEST_CONFIG["role"],
268-
region=TEST_CONFIG["region"],
267+
model=test_config["base_model"],
268+
dataset=test_config["dataset"],
269+
s3_output_path=f'{test_config["s3_output_path"]}integ-base/',
270+
agent_config=test_config["agent_arn"],
271+
mlflow_resource_arn=test_config["mlflow_resource_arn"],
272+
role=test_config["role"],
273+
region=test_config["region"],
269274
)
270275

271276
assert evaluator is not None
272-
assert evaluator.model == TEST_CONFIG["base_model"]
277+
assert evaluator.model == test_config["base_model"]
273278

274-
def test_get_all_mtrl_evaluations(self):
279+
def test_get_all_mtrl_evaluations(self, test_config):
275280
"""Test listing all MTRL evaluation executions."""
276-
all_execs = MultiTurnRLEvaluator.get_all(region=TEST_CONFIG["region"])
281+
all_execs = MultiTurnRLEvaluator.get_all(region=test_config["region"])
277282

278283
if hasattr(all_execs, '__iter__'):
279284
all_execs = list(all_execs)

0 commit comments

Comments
 (0)