Skip to content

Commit eb524f7

Browse files
committed
fix conftest
1 parent c89a8f5 commit eb524f7

File tree

2 files changed

+125
-0
lines changed

2 files changed

+125
-0
lines changed

test/sagemaker_tests/sglang/general/__init__.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,41 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14+
15+
import os
16+
from enum import Enum
17+
18+
import botocore
19+
20+
21+
def _botocore_resolver():
22+
"""
23+
Get the DNS suffix for the given region.
24+
:return: endpoint object
25+
"""
26+
loader = botocore.loaders.create_loader()
27+
return botocore.regions.EndpointResolver(loader.load_data("endpoints"))
28+
29+
30+
def get_ecr_registry(account, region):
31+
"""
32+
Get prefix of ECR image URI
33+
:param account: Account ID
34+
:param region: region where ECR repo exists
35+
:return: AWS ECR registry
36+
"""
37+
endpoint_data = _botocore_resolver().construct_endpoint("ecr", region)
38+
return "{}.dkr.{}".format(account, endpoint_data["hostname"])
39+
40+
def get_efa_test_instance_type(default: list):
41+
"""
42+
Get the instance type to be used for EFA tests from the environment, or default to a given value if the type
43+
isn't specified in the environment.
44+
45+
:param default: list of instance type to be used for tests
46+
:return: list of instance types to be parametrized for a test
47+
"""
48+
configured_instance_type = os.getenv("SM_EFA_TEST_INSTANCE_TYPE")
49+
if configured_instance_type:
50+
return [configured_instance_type]
51+
return default

test/sagemaker_tests/sglang/general/conftest.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616
import os
1717
import sys
1818

19+
import boto3
1920
import pytest
21+
from sagemaker import LocalSession, Session
22+
23+
from . import get_ecr_registry, get_efa_test_instance_type
2024

2125
logger = logging.getLogger(__name__)
2226

@@ -81,3 +85,86 @@ def pytest_collection_modifyitems(session, config, items):
8185

8286
report_generator = TestReportGenerator(items, is_sagemaker=True)
8387
report_generator.generate_coverage_doc(framework="pytorch", job_type="training")
88+
89+
90+
@pytest.fixture(scope="session", name="docker_base_name")
91+
def fixture_docker_base_name(request):
92+
return request.config.getoption("--docker-base-name")
93+
94+
95+
@pytest.fixture(scope="session", name="region")
96+
def fixture_region(request):
97+
return request.config.getoption("--region")
98+
99+
100+
@pytest.fixture(scope="session", name="framework_version")
101+
def fixture_framework_version(request):
102+
return request.config.getoption("--framework-version")
103+
104+
105+
@pytest.fixture(scope="session", name="py_version")
106+
def fixture_py_version(request):
107+
return "py{}".format(int(request.config.getoption("--py-version")))
108+
109+
110+
@pytest.fixture(scope="session", name="processor")
111+
def fixture_processor(request):
112+
return request.config.getoption("--processor")
113+
114+
115+
@pytest.fixture(scope="session", name="sagemaker_regions")
116+
def fixture_sagemaker_regions(request):
117+
sagemaker_regions = request.config.getoption("--sagemaker-regions")
118+
return sagemaker_regions.split(",")
119+
120+
121+
@pytest.fixture(scope="session", name="tag")
122+
def fixture_tag(request, framework_version, processor, py_version):
123+
provided_tag = request.config.getoption("--tag")
124+
default_tag = "{}-{}-{}".format(framework_version, processor, py_version)
125+
return provided_tag if provided_tag else default_tag
126+
127+
128+
@pytest.fixture(scope="session", name="docker_image")
129+
def fixture_docker_image(docker_base_name, tag):
130+
return "{}:{}".format(docker_base_name, tag)
131+
132+
133+
@pytest.fixture(scope="session", name="sagemaker_session")
134+
def fixture_sagemaker_session(region):
135+
return Session(boto_session=boto3.Session(region_name=region))
136+
137+
138+
@pytest.fixture(name="efa_instance_type")
139+
def fixture_efa_instance_type(request):
140+
try:
141+
return request.param
142+
except AttributeError:
143+
return get_efa_test_instance_type(default=["ml.p4d.24xlarge"])[0]
144+
145+
146+
@pytest.fixture(scope="session", name="sagemaker_local_session")
147+
def fixture_sagemaker_local_session(region):
148+
return LocalSession(boto_session=boto3.Session(region_name=region))
149+
150+
151+
@pytest.fixture(name="aws_id", scope="session")
152+
def fixture_aws_id(request):
153+
return request.config.getoption("--aws-id")
154+
155+
156+
@pytest.fixture(name="instance_type", scope="session")
157+
def fixture_instance_type(request, processor):
158+
provided_instance_type = request.config.getoption("--instance-type")
159+
default_instance_type = "local" if processor == "cpu" else "local_gpu"
160+
return provided_instance_type or default_instance_type
161+
162+
163+
@pytest.fixture(name="docker_registry", scope="session")
164+
def fixture_docker_registry(aws_id, region):
165+
return get_ecr_registry(aws_id, region)
166+
167+
168+
@pytest.fixture(name="ecr_image", scope="session")
169+
def fixture_ecr_image(docker_registry, docker_base_name, tag):
170+
return "{}/{}:{}".format(docker_registry, docker_base_name, tag)

0 commit comments

Comments
 (0)