Skip to content

Commit 8bdbd90

Browse files
author
Orbax Authors
committed
Refactor multiprocess test execution in Orbax.
PiperOrigin-RevId: 867416598
1 parent fa287e0 commit 8bdbd90

File tree

12 files changed

+3347
-795
lines changed

12 files changed

+3347
-795
lines changed

.github/workflows/build.yml

Lines changed: 53 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -64,9 +64,9 @@ jobs:
6464
pip install "jax>=${{ matrix.jax-version }}" "jaxlib>=${{ matrix.jax-version }}"
6565
fi
6666
- name: Test with pytest
67-
# TODO(yaning): Move these to an exclude target within pytest.ini.
67+
# TODO(nikhilbansall): Move these to an exclude target within pytest.ini.
6868
run: |
69-
python -m pytest --import-mode=importlib --ignore=orbax/checkpoint/experimental/emergency/broadcast_multislice_test.py --ignore=orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/single_slice_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/multihost_test.py --ignore=orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py --ignore=orbax/checkpoint/_src/testing/multiprocess_test.py --ignore=orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py --ignore=orbax/checkpoint/checkpoint_manager_test.py
69+
python -m pytest --import-mode=importlib --ignore=orbax/checkpoint/experimental/emergency/broadcast_multislice_test.py --ignore=orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/single_slice_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/multihost_test.py --ignore=orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py --ignore=orbax/checkpoint/_src/testing/multiprocess_test.py --ignore=orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py --ignore=orbax/checkpoint/checkpoint_manager_test.py --ignore=orbax/checkpoint/_src/handlers/pytree_checkpoint_handler_test.py
7070
# The below step just reports the success or failure of tests as a "commit status".
7171
# This is needed for copybara integration.
7272
- name: Report success or failure as github status
@@ -260,9 +260,9 @@ jobs:
260260
pip install -e .
261261
pip install -e .[testing,gcs] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
262262
pip uninstall -y orbax
263-
if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
263+
if [ "${{ matrix.jax-version }}" = "newest" ]; then
264264
pip install -U jax[k8s,cuda12] jaxlib
265-
elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then
265+
elif [ "${{ matrix.jax-version }}" = "nightly" ]; then
266266
pip install -U --pre jax[k8s,cuda12] jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
267267
else
268268
pip install "jax[k8s,cuda12]>=${{ matrix.jax-version }}" "jaxlib>=${{ matrix.jax-version }}"
@@ -296,20 +296,61 @@ jobs:
296296
echo "The following benchmarks failed:$failed_benchmarks"
297297
exit 1
298298
fi
299-
# cd orbax/checkpoint/_src/testing/benchmarks && python -c "import sys; import jax; jax.distributed.initialize(); print(jax.devices()); from absl import app; import run_benchmarks; sys.argv = ['run_benchmarks.py', '--config_file=configs/pytree_checkpoint_benchmark.yaml', '--output_directory=$GCS_BUCKET_PATH']; app.run(run_benchmarks.main)"
300-
# cd ../../../../..
301-
# python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py
302-
# cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
303299
# The below step just reports the success or failure of tests as a "commit status".
304300
# This is needed for copybara integration.
301+
- name: Report success or failure as github status
302+
if: always()
303+
shell: bash
304+
run: |
305+
status="${{ job.status }}"
306+
lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]')
307+
curl -sS --request POST \
308+
--url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \
309+
--header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
310+
--header 'content-type: application/json' \
311+
--data '{
312+
"state": "'$lowercase_status'",
313+
"target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
314+
"description": "'$status'",
315+
"context": "github-actions/build"
316+
}'
317+
318+
multiprocess-unit-tests:
319+
name: "multiprocess-unit-tests (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
320+
runs-on: linux-x86-ct5lp-4tpu-x2
321+
container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e
322+
defaults:
323+
run:
324+
working-directory: checkpoint
325+
strategy:
326+
matrix:
327+
python-version: ["3.11"]
328+
jax-version: ["newest"]
329+
steps:
330+
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
331+
- name: Set up Python ${{ matrix.python-version }}
332+
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
333+
with:
334+
python-version: ${{ matrix.python-version }}
335+
- name: Install dependencies
336+
run: |
337+
pip install -e .
338+
pip install -e .[testing,gcs] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
339+
pip uninstall -y orbax
340+
pip install gcsfs
341+
pip install portpicker pytest chex pyyaml
342+
if [ "${{ matrix.jax-version }}" = "newest" ]; then
343+
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
344+
elif [ "${{ matrix.jax-version }}" = "nightly" ]; then
345+
pip install -U --pre "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
346+
else
347+
pip install "jax[tpu]==${{ matrix.jax-version }}" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
348+
fi
305349
- name: Run multiprocess tests
306350
env:
307351
TEST_TMPDIR: gs://orbax-benchmarks/unit-tests/${{ github.run_id }}
308352
run: |
309-
python -c "import jax; jax.distributed.initialize(); print(jax.devices()); import pytest; test_files = [line.strip() for line in open('orbax/checkpoint/_src/testing/multiprocess_tests.txt') if line.strip()]; pytest.main(['-c', '/dev/null'] + test_files)"
310-
# python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py;"
311-
# cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
312-
# python -m pytest orbax/checkpoint/checkpoint_manager_test.py
353+
python orbax/checkpoint/_src/testing/multiprocess_unittests/run_tests.py --filename=orbax/checkpoint/_src/testing/multiprocess_unittests/tagged_tests.yaml --processes=4
313354
- name: Report success or failure as github status
314355
if: always()
315356
shell: bash

0 commit comments

Comments
 (0)