Skip to content

Commit 04c4de8

Browse files
committed
update
update update update update update update
1 parent 089f499 commit 04c4de8

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

.github/actions/setup/action.yml

+12-5
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,27 @@ runs:
2828
if: ${{ inputs.torch-version != 'nightly' }}
2929
run: |
3030
pip install torch==${{ inputs.torch-version }}.* --extra-index-url https://download.pytorch.org/whl/${{ inputs.cuda-version }}
31-
python -c "import torch; print('PyTorch:', torch.__version__)"
32-
python -c "import torch; print('CUDA:', torch.version.cuda)"
3331
shell: bash
3432

3533
- name: Install PyTorch ${{ inputs.torch-version }}+${{ inputs.cuda-version }}
3634
if: ${{ inputs.torch-version == 'nightly' }}
3735
run: |
3836
pip install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/${{ inputs.cuda-version }}
39-
python -c "import torch; print('PyTorch:', torch.__version__)"
40-
python -c "import torch; print('CUDA:', torch.version.cuda)"
37+
shell: bash
38+
39+
- name: Downgrade NumPy if necessary
40+
run: |
41+
torch_version=$(python -c "import torch; print(torch.__version__)")
42+
if [[ $torch_version == *"2.0"* || $torch_version == *"2.1"* || $torch_version == *"2.2"* ]]; then
43+
pip install 'numpy<2'
44+
fi
4145
shell: bash
4246

4347
- name: List installed packages
44-
run: pip list
48+
run: |
49+
pip list
50+
python -c "import torch; print('PyTorch:', torch.__version__)"
51+
python -c "import torch; print('CUDA:', torch.version.cuda)"
4552
shell: bash
4653

4754
# TODO: Include catboost in Python 3.13 CI when catboost supports it:

0 commit comments

Comments
 (0)