Skip to content

Commit b23e3d7

Browse files
committed
Allow to skip GPU tests with env variable
1 parent 943dc2b commit b23e3d7

2 files changed

Lines changed: 10 additions & 5 deletions

File tree

.github/workflows/unit-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ jobs:
3636
python -m pip install .
3737
- name: Test with pytest
3838
run: |
39-
pytest test/
39+
export KEYSVALS_SKIP_CUDA_TESTS=1; pytest test/
4040
- name: Upload coverage reports to Codecov
4141
uses: codecov/codecov-action@eaaf4bedf32dbdc6b720b63067d99c4d77d6047d # v3.1.4
4242
with:

keys_values/kvcache/test_utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Tuple, Optional, List, Dict, Callable
15-
import math
1614
from functools import partial
1715
from itertools import product
16+
import math
17+
import os
18+
from typing import Tuple, Optional, List, Dict, Callable
1819

1920
import torch
2021

@@ -36,6 +37,8 @@
3637
# Tests run quite slowly for "mps". If this changes, switch this to True
3738
RUN_TESTS_FOR_MPS = False
3839

40+
ENV_VAR_SKIP_CUDA_TESTS = "KEYSVALS_SKIP_CUDA_TESTS"
41+
3942

4043
def create_kv_cache(
4144
name: str,
@@ -275,7 +278,8 @@ def available_backends(do_mps: bool = True) -> List[torch.device]:
275278
result = [torch.device("cpu")]
276279
if do_mps and RUN_TESTS_FOR_MPS and torch.backends.mps.is_available():
277280
result.append(torch.device("mps"))
278-
if torch.cuda.is_available():
281+
run_cuda_tests = os.environ.get(ENV_VAR_SKIP_CUDA_TESTS) is None
282+
if run_cuda_tests and torch.cuda.is_available():
279283
result.append(torch.device("cuda:0"))
280284
return result
281285

@@ -301,7 +305,8 @@ def device_for_cache_name(name: str) -> torch.device:
301305

302306

303307
def filter_cache_names(names: List[str]) -> List[str]:
304-
if torch.cuda.is_available():
308+
run_cuda_tests = os.environ.get(ENV_VAR_SKIP_CUDA_TESTS) is None
309+
if run_cuda_tests and torch.cuda.is_available():
305310
return names
306311
else:
307312
return [name for name in names if not cache_name_gpu_only(name)]

0 commit comments

Comments
 (0)