|
13 | 13 | from __future__ import absolute_import
|
14 | 14 | import pytest
|
15 | 15 | from unittest.mock import patch
|
| 16 | +from sagemaker.jumpstart.constants import JUMPSTART_MODEL_HUB_NAME |
16 | 17 | from sagemaker.jumpstart.factory.estimator import (
|
17 | 18 | _add_model_uri_to_kwargs,
|
18 | 19 | get_model_info_default_kwargs,
|
@@ -119,3 +120,43 @@ def test_add_model_uri_to_kwargs_no_training_support(self, mock_supports_trainin
|
119 | 120 |
|
120 | 121 | mock_supports_training.assert_called_once()
|
121 | 122 | assert result.model_uri is None
|
| 123 | + |
| 124 | + @patch( |
| 125 | + "sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri", |
| 126 | + return_value=False, |
| 127 | + ) |
| 128 | + @patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve") |
| 129 | + def test_add_model_uri_to_kwargs_private_hub( |
| 130 | + self, mock_retrieve, mock_supports_training, mock_kwargs |
| 131 | + ): |
| 132 | + """Test when model is from a private hub.""" |
| 133 | + default_uri = "s3://jumpstart-models/training/test-model/1.0.0" |
| 134 | + mock_retrieve.return_value = default_uri |
| 135 | + mock_kwargs.hub_arn = "arn:aws:sagemaker:us-west-2:123456789012:hub/private-hub" |
| 136 | + |
| 137 | + result = _add_model_uri_to_kwargs(mock_kwargs) |
| 138 | + |
| 139 | + # Should not check if model supports training model URI for private hub |
| 140 | + mock_supports_training.assert_not_called() |
| 141 | + mock_retrieve.assert_called_once() |
| 142 | + assert result.model_uri == default_uri |
| 143 | + |
| 144 | + @patch( |
| 145 | + "sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri", |
| 146 | + return_value=False, |
| 147 | + ) |
| 148 | + @patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve") |
| 149 | + def test_add_model_uri_to_kwargs_public_hub( |
| 150 | + self, mock_retrieve, mock_supports_training, mock_kwargs |
| 151 | + ): |
| 152 | + """Test when model is from the public hub.""" |
| 153 | + mock_kwargs.hub_arn = ( |
| 154 | + f"arn:aws:sagemaker:us-west-2:123456789012:hub/{JUMPSTART_MODEL_HUB_NAME}" |
| 155 | + ) |
| 156 | + |
| 157 | + result = _add_model_uri_to_kwargs(mock_kwargs) |
| 158 | + |
| 159 | + # Should check if model supports training model URI for public hub |
| 160 | + mock_supports_training.assert_called_once() |
| 161 | + mock_retrieve.assert_not_called() |
| 162 | + assert result.model_uri is None |
0 commit comments