1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14- from typing import Tuple , Optional , List , Dict , Callable
15- import math
1614from functools import partial
1715from itertools import product
16+ import math
17+ import os
18+ from typing import Tuple , Optional , List , Dict , Callable
1819
1920import torch
2021
3637# Tests run quite slowly for "mps". If this changes, switch this to True
3738RUN_TESTS_FOR_MPS = False
3839
40+ ENV_VAR_SKIP_CUDA_TESTS = "KEYSVALS_SKIP_CUDA_TESTS"
41+
3942
4043def create_kv_cache (
4144 name : str ,
@@ -275,7 +278,8 @@ def available_backends(do_mps: bool = True) -> List[torch.device]:
275278 result = [torch .device ("cpu" )]
276279 if do_mps and RUN_TESTS_FOR_MPS and torch .backends .mps .is_available ():
277280 result .append (torch .device ("mps" ))
278- if torch .cuda .is_available ():
281+ run_cuda_tests = os .environ .get (ENV_VAR_SKIP_CUDA_TESTS ) is None
282+ if run_cuda_tests and torch .cuda .is_available ():
279283 result .append (torch .device ("cuda:0" ))
280284 return result
281285
@@ -301,7 +305,8 @@ def device_for_cache_name(name: str) -> torch.device:
301305
302306
303307def filter_cache_names (names : List [str ]) -> List [str ]:
304- if torch .cuda .is_available ():
308+ run_cuda_tests = os .environ .get (ENV_VAR_SKIP_CUDA_TESTS ) is None
309+ if run_cuda_tests and torch .cuda .is_available ():
305310 return names
306311 else :
307312 return [name for name in names if not cache_name_gpu_only (name )]
0 commit comments