@@ -29,27 +29,38 @@ jobs:
2929 JAXCI_HERMETIC_PYTHON_VERSION : 3.11
3030
3131 steps :
32- - uses : actions/checkout@v3
33- # - name: Build jaxlib
34- # env :
35- # JAXCI_CLONE_MAIN_XLA: 1
36- # run: ./ci/build_artifacts.sh "jaxlib"
37- # - name: Build jax-cuda-plugin
38- # env:
39- # JAXCI_CLONE_MAIN_XLA: 1
40- # run: ./ci/build_artifacts.sh "jax-cuda-plugin"
41- # - name: Build jax-cuda-pjrt
42- # env:
43- # JAXCI_CLONE_MAIN_XLA: 1
44- # run: ./ci/build_artifacts.sh "jax-cuda-pjrt"
45- - name : Install dependencies
32+ - name : Checkout JAX Fork
33+ uses : actions/checkout@v3
34+ with :
35+ repository : ' google-ml-infra/jax-fork '
36+ path : jax-fork
37+
38+ - name : Checkout XLA
39+ uses : actions/checkout@v3
40+ with :
41+ repository : ' openxla/xla ' # Or your XLA fork
42+ path : xla
43+
44+ - name : Install JAX Dependencies
45+ working-directory : jax-fork
4646 run : |
4747 python -m pip install --upgrade pip
4848 pip install pytest
4949 pip install absl-py
50- # pip install -U jax
51- pip install -U "jax[cuda12]"
50+ pip install "jax[cuda12_pip]" # Adjust CUDA version if needed
5251 pip install google-benchmark
53- - name : Run Multiprocess GPU Test
54- run : |
55- python -m pytest tests/multiprocess_gpu_test.py
52+
53+ - name : Run JAX Multiprocess GPU Test
54+ working-directory : jax-fork
55+ continue-on-error : true
56+ run : python -m pytest tests/multiprocess_gpu_test.py
57+
58+ - name : Build XLA GPU Atomic Test
59+ working-directory : xla
60+ continue-on-error : true
61+ run : bazel build -c opt --config=cuda //xla/service/gpu/tests:gpu_atomic_test
62+
63+ - name : Run XLA GPU Atomic Test
64+ working-directory : xla
65+ continue-on-error : true
66+ run : bazel test -c opt --config=cuda //xla/service/gpu/tests:gpu_atomic_test
0 commit comments