Skip to content

Commit 6b11961

Browse files
authored
Add arm64 windows support (#678)
* add arm64 windows to python.yml Also, added push event trigger to Python workflow to test on fork. * add arm64 windows handling in tests * Fix quotes * fix typo * Fix typo * use different url for arm64 * Update condition for installing jax and flax * run arm64 on limited set of tests * Rename test step for Windows arm64 * add arm64 build * Modify target condition for windows arm64 * add additional tests Add specific test files to the testing workflow. * update arm64 tests * run all but tf tests * change conditional, move install * change test command sytax to match previous syntax * remove push trigger Remove push trigger from Python workflow
1 parent 3979c8f commit 6b11961

File tree

2 files changed

+27
-5
lines changed

2 files changed

+27
-5
lines changed

.github/workflows/python-release.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ jobs:
9494
target: x64
9595
- runner: windows-latest
9696
target: x86
97+
- runner: windows-11-arm
98+
target: arm64
9799
steps:
98100
- uses: actions/checkout@v4
99101
- uses: actions/setup-python@v5
@@ -103,7 +105,7 @@ jobs:
103105
- name: Build wheels
104106
uses: PyO3/maturin-action@v1
105107
with:
106-
target: ${{ matrix.platform.target }}
108+
target: ${{ matrix.platform.target == 'arm64' && 'aarch64-pc-windows-msvc' || matrix.platform.target }}
107109
args: --release --out dist --manifest-path bindings/python/Cargo.toml
108110
sccache: 'true'
109111
- name: Upload wheels

.github/workflows/python.yml

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ jobs:
4040
python: "3.12"
4141
numpy: numpy
4242
arch: "arm64"
43+
- os: windows-11-arm
44+
version:
45+
torch: torch
46+
python: "3.12"
47+
numpy: numpy
48+
arch: "arm64"
4349
defaults:
4450
run:
4551
working-directory: ./bindings/python
@@ -81,7 +87,7 @@ jobs:
8187
# pip install -U pip
8288

8389
- name: Install (torch)
84-
if: matrix.version.arch != 'x64-freethreaded'
90+
if: matrix.version.arch != 'x64-freethreaded' && matrix.os != 'windows-11-arm'
8591
run: |
8692
pip install ${{ matrix.version.numpy }}
8793
pip install ${{ matrix.version.torch }}
@@ -93,22 +99,29 @@ jobs:
9399
pip install ${{ matrix.version.torch }} --index-url https://download.pytorch.org/whl/cu126
94100
shell: bash
95101

102+
- name: Install (torch windows arm64)
103+
if: matrix.os == 'windows-11-arm'
104+
run: |
105+
pip install ${{ matrix.version.numpy }}
106+
pip install ${{ matrix.version.torch }} --index-url https://download.pytorch.org/whl/cpu
107+
shell: bash
108+
96109
- name: Install (hdf5 non windows)
97110
if: matrix.os == 'ubuntu-latest' && matrix.version.arch != 'x64-freethreaded'
98111
run: |
99112
sudo apt-get update
100113
sudo apt-get install libhdf5-dev
101114
102115
- name: Install (tensorflow)
103-
if: matrix.version.arch != 'x64-freethreaded'
116+
if: matrix.version.arch != 'x64-freethreaded' && matrix.os != 'windows-11-arm'
104117
run: |
105118
pip install .[tensorflow]
106119
# Force reinstall of numpy, tensorflow uses numpy 2 even on 3.9
107120
pip install ${{ matrix.version.numpy }}
108121
shell: bash
109122

110123
- name: Install (jax, flax)
111-
if: matrix.os != 'windows-latest' && matrix.version.arch != 'x64-freethreaded'
124+
if: runner.os != 'Windows' && matrix.version.arch != 'x64-freethreaded'
112125
run:
113126
pip install .[jax]
114127
shell: bash
@@ -125,12 +138,19 @@ jobs:
125138
ruff format --check .
126139
127140
- name: Run tests
128-
if: matrix.version.arch != 'x64-freethreaded'
141+
if: matrix.version.arch != 'x64-freethreaded' && matrix.os != 'windows-11-arm'
129142
run: |
130143
cargo test
131144
pip install ".[testing]"
132145
pytest -sv tests/
133146
147+
- name: Run tests (Windows arm64)
148+
if: matrix.os == 'windows-11-arm'
149+
run: |
150+
cargo test
151+
pip install ".[testing]"
152+
pytest -sv tests/ --ignore=tests/test_tf_comparison.py
153+
134154
- name: Run tests (freethreaded)
135155
if: matrix.version.arch == 'x64-freethreaded'
136156
run: |

0 commit comments

Comments
 (0)