|
64 | 64 | pip install "jax>=${{ matrix.jax-version }}" "jaxlib>=${{ matrix.jax-version }}" |
65 | 65 | fi |
66 | 66 | - name: Test with pytest |
67 | | - # TODO(yaning): Move these to an exclude target within pytest.ini. |
| 67 | + # TODO(nikhilbansall): Move these to an exclude target within pytest.ini. |
68 | 68 | run: | |
69 | | - python -m pytest --import-mode=importlib --ignore=orbax/checkpoint/experimental/emergency/broadcast_multislice_test.py --ignore=orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/single_slice_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/multihost_test.py --ignore=orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py --ignore=orbax/checkpoint/_src/testing/multiprocess_test.py --ignore=orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py --ignore=orbax/checkpoint/checkpoint_manager_test.py |
| 69 | + python -m pytest --import-mode=importlib --ignore=orbax/checkpoint/experimental/emergency/broadcast_multislice_test.py --ignore=orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/single_slice_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/multihost_test.py --ignore=orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py --ignore=orbax/checkpoint/_src/testing/multiprocess_test.py --ignore=orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py --ignore=orbax/checkpoint/checkpoint_manager_test.py --ignore=orbax/checkpoint/_src/handlers/pytree_checkpoint_handler_test.py |
70 | 70 | # The below step just reports the success or failure of tests as a "commit status". |
71 | 71 | # This is needed for copybara integration. |
72 | 72 | - name: Report success or failure as github status |
@@ -260,9 +260,9 @@ jobs: |
260 | 260 | pip install -e . |
261 | 261 | pip install -e .[testing,gcs] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html |
262 | 262 | pip uninstall -y orbax |
263 | | - if [[ "${{ matrix.jax-version }}" == "newest" ]]; then |
| 263 | + if [ "${{ matrix.jax-version }}" = "newest" ]; then |
264 | 264 | pip install -U jax[k8s,cuda12] jaxlib |
265 | | - elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then |
| 265 | + elif [ "${{ matrix.jax-version }}" = "nightly" ]; then |
266 | 266 | pip install -U --pre jax[k8s,cuda12] jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ |
267 | 267 | else |
268 | 268 | pip install "jax[k8s,cuda12]>=${{ matrix.jax-version }}" "jaxlib>=${{ matrix.jax-version }}" |
@@ -296,20 +296,61 @@ jobs: |
296 | 296 | echo "The following benchmarks failed:$failed_benchmarks" |
297 | 297 | exit 1 |
298 | 298 | fi |
299 | | - # cd orbax/checkpoint/_src/testing/benchmarks && python -c "import sys; import jax; jax.distributed.initialize(); print(jax.devices()); from absl import app; import run_benchmarks; sys.argv = ['run_benchmarks.py', '--config_file=configs/pytree_checkpoint_benchmark.yaml', '--output_directory=$GCS_BUCKET_PATH']; app.run(run_benchmarks.main)" |
300 | | - # cd ../../../../.. |
301 | | - # python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py |
302 | | - # cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH |
303 | 299 | # The below step just reports the success or failure of tests as a "commit status". |
304 | 300 | # This is needed for copybara integration. |
| 301 | + - name: Report success or failure as github status |
| 302 | + if: always() |
| 303 | + shell: bash |
| 304 | + run: | |
| 305 | + status="${{ job.status }}" |
| 306 | + lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') |
| 307 | + curl -sS --request POST \ |
| 308 | + --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ |
| 309 | + --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ |
| 310 | + --header 'content-type: application/json' \ |
| 311 | + --data '{ |
| 312 | + "state": "'$lowercase_status'", |
| 313 | + "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", |
| 314 | + "description": "'$status'", |
| 315 | + "context": "github-actions/build" |
| 316 | + }' |
| 317 | +
|
| 318 | + multiprocess-unit-tests: |
| 319 | + name: "multiprocess-unit-tests (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" |
| 320 | + runs-on: linux-x86-ct5lp-4tpu-x2 |
| 321 | + container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e |
| 322 | + defaults: |
| 323 | + run: |
| 324 | + working-directory: checkpoint |
| 325 | + strategy: |
| 326 | + matrix: |
| 327 | + python-version: ["3.11"] |
| 328 | + jax-version: ["newest"] |
| 329 | + steps: |
| 330 | + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 |
| 331 | + - name: Set up Python ${{ matrix.python-version }} |
| 332 | + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 |
| 333 | + with: |
| 334 | + python-version: ${{ matrix.python-version }} |
| 335 | + - name: Install dependencies |
| 336 | + run: | |
| 337 | + pip install -e . |
| 338 | + pip install -e .[testing,gcs] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html |
| 339 | + pip uninstall -y orbax |
| 340 | + pip install gcsfs |
| 341 | + pip install portpicker pytest chex pyyaml |
| 342 | + if [ "${{ matrix.jax-version }}" = "newest" ]; then |
| 343 | + pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html |
| 344 | + elif [ "${{ matrix.jax-version }}" = "nightly" ]; then |
| 345 | + pip install -U --pre "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ |
| 346 | + else |
| 347 | + pip install "jax[tpu]==${{ matrix.jax-version }}" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html |
| 348 | + fi |
305 | 349 | - name: Run multiprocess tests |
306 | 350 | env: |
307 | 351 | TEST_TMPDIR: gs://orbax-benchmarks/unit-tests/${{ github.run_id }} |
308 | 352 | run: | |
309 | | - python -c "import jax; jax.distributed.initialize(); print(jax.devices()); import pytest; test_files = [line.strip() for line in open('orbax/checkpoint/_src/testing/multiprocess_tests.txt') if line.strip()]; pytest.main(['-c', '/dev/null'] + test_files)" |
310 | | - # python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py;" |
311 | | - # cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH |
312 | | - # python -m pytest orbax/checkpoint/checkpoint_manager_test.py |
| 353 | + python orbax/checkpoint/_src/testing/multiprocess_unittests/run_tests.py --filename=orbax/checkpoint/_src/testing/multiprocess_unittests/tagged_tests.yaml --processes=4 |
313 | 354 | - name: Report success or failure as github status |
314 | 355 | if: always() |
315 | 356 | shell: bash |
|
0 commit comments