|
16 | 16 | import os |
17 | 17 | import sys |
18 | 18 |
|
| 19 | +import boto3 |
19 | 20 | import pytest |
| 21 | +from sagemaker import LocalSession, Session |
| 22 | + |
| 23 | +from . import get_ecr_registry, get_efa_test_instance_type |
20 | 24 |
|
21 | 25 | logger = logging.getLogger(__name__) |
22 | 26 |
|
@@ -81,3 +85,86 @@ def pytest_collection_modifyitems(session, config, items): |
81 | 85 |
|
82 | 86 | report_generator = TestReportGenerator(items, is_sagemaker=True) |
83 | 87 | report_generator.generate_coverage_doc(framework="pytorch", job_type="training") |
| 88 | + |
| 89 | + |
| 90 | +@pytest.fixture(scope="session", name="docker_base_name") |
| 91 | +def fixture_docker_base_name(request): |
| 92 | + return request.config.getoption("--docker-base-name") |
| 93 | + |
| 94 | + |
| 95 | +@pytest.fixture(scope="session", name="region") |
| 96 | +def fixture_region(request): |
| 97 | + return request.config.getoption("--region") |
| 98 | + |
| 99 | + |
| 100 | +@pytest.fixture(scope="session", name="framework_version") |
| 101 | +def fixture_framework_version(request): |
| 102 | + return request.config.getoption("--framework-version") |
| 103 | + |
| 104 | + |
| 105 | +@pytest.fixture(scope="session", name="py_version") |
| 106 | +def fixture_py_version(request): |
| 107 | + return "py{}".format(int(request.config.getoption("--py-version"))) |
| 108 | + |
| 109 | + |
| 110 | +@pytest.fixture(scope="session", name="processor") |
| 111 | +def fixture_processor(request): |
| 112 | + return request.config.getoption("--processor") |
| 113 | + |
| 114 | + |
| 115 | +@pytest.fixture(scope="session", name="sagemaker_regions") |
| 116 | +def fixture_sagemaker_regions(request): |
| 117 | + sagemaker_regions = request.config.getoption("--sagemaker-regions") |
| 118 | + return sagemaker_regions.split(",") |
| 119 | + |
| 120 | + |
| 121 | +@pytest.fixture(scope="session", name="tag") |
| 122 | +def fixture_tag(request, framework_version, processor, py_version): |
| 123 | + provided_tag = request.config.getoption("--tag") |
| 124 | + default_tag = "{}-{}-{}".format(framework_version, processor, py_version) |
| 125 | + return provided_tag if provided_tag else default_tag |
| 126 | + |
| 127 | + |
| 128 | +@pytest.fixture(scope="session", name="docker_image") |
| 129 | +def fixture_docker_image(docker_base_name, tag): |
| 130 | + return "{}:{}".format(docker_base_name, tag) |
| 131 | + |
| 132 | + |
| 133 | +@pytest.fixture(scope="session", name="sagemaker_session") |
| 134 | +def fixture_sagemaker_session(region): |
| 135 | + return Session(boto_session=boto3.Session(region_name=region)) |
| 136 | + |
| 137 | + |
| 138 | +@pytest.fixture(name="efa_instance_type") |
| 139 | +def fixture_efa_instance_type(request): |
| 140 | + try: |
| 141 | + return request.param |
| 142 | + except AttributeError: |
| 143 | + return get_efa_test_instance_type(default=["ml.p4d.24xlarge"])[0] |
| 144 | + |
| 145 | + |
| 146 | +@pytest.fixture(scope="session", name="sagemaker_local_session") |
| 147 | +def fixture_sagemaker_local_session(region): |
| 148 | + return LocalSession(boto_session=boto3.Session(region_name=region)) |
| 149 | + |
| 150 | + |
| 151 | +@pytest.fixture(name="aws_id", scope="session") |
| 152 | +def fixture_aws_id(request): |
| 153 | + return request.config.getoption("--aws-id") |
| 154 | + |
| 155 | + |
| 156 | +@pytest.fixture(name="instance_type", scope="session") |
| 157 | +def fixture_instance_type(request, processor): |
| 158 | + provided_instance_type = request.config.getoption("--instance-type") |
| 159 | + default_instance_type = "local" if processor == "cpu" else "local_gpu" |
| 160 | + return provided_instance_type or default_instance_type |
| 161 | + |
| 162 | + |
| 163 | +@pytest.fixture(name="docker_registry", scope="session") |
| 164 | +def fixture_docker_registry(aws_id, region): |
| 165 | + return get_ecr_registry(aws_id, region) |
| 166 | + |
| 167 | + |
| 168 | +@pytest.fixture(name="ecr_image", scope="session") |
| 169 | +def fixture_ecr_image(docker_registry, docker_base_name, tag): |
| 170 | + return "{}/{}:{}".format(docker_registry, docker_base_name, tag) |
0 commit comments