1616 - ' no'
1717
1818jobs :
19- jax-build-and-test :
20- runs-on : linux-x86-g2-48-l4-4gpu # Use a GPU-enabled runner
21- container :
22- image : " gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
23-
24- env :
25- JAXCI_HERMETIC_PYTHON_VERSION : 3.11
26-
27- steps :
28- - name : Checkout JAX Fork
29- uses : actions/checkout@v3
30- with :
31- repository : ' google-ml-infra/jax-fork'
32- path : jax-fork
33-
34- - name : Install JAX Dependencies
35- working-directory : jax-fork
36- run : |
37- python -m pip install --upgrade pip
38- pip install pytest
39- pip install absl-py
40- pip install "jax[cuda12_pip]" # Adjust CUDA version if needed
41- pip install google-benchmark
42- - name : Run JAX Multiprocess GPU Test
43- working-directory : jax-fork
44- continue-on-error : true
45- run : python -m pytest tests/multiprocess_gpu_test.py
46-
4719 build-xla-gpu-and-test :
4820 runs-on : linux-x86-g2-48-l4-4gpu # Use a GPU-enabled runner
4921 container :
@@ -113,29 +85,8 @@ jobs:
11385 - name : Run specific HLO file
11486 working-directory : xla
11587 run : |
116- echo "Before hlo_runner_main"
117- pwd # Print the current working directory
118- ls -l # list files in the directory
11988 ./bazel-bin/xla/tools/multihost_hlo_runner/hlo_runner_main --device_type=gpu --use_spmd_partitioning gemm_00881937d6d49056045c3325a12b108b.hlo &> results/gemm_00881937d6d49056045c3325a12b108b.hlo.log
120- echo "After hlo_runner_main"
121- ls -l results # List files in the results directory
122- cat results/gemm_00881937d6d49056045c3325a12b108b.hlo.log
12389
124-
125- # - name: Run HLO Module Benchmarks with GPU in xla/tests/fuzz
126- # working-directory: xla
127- # continue-on-error: true
128- # run: |
129- # for file in xla/tests/fuzz/*.hlo; do
130- # filename=$(basename "$file")
131- # # Skip expected failed hlo files.
132- # if [[ "$filename" == "rand_000060.hlo" || "$filename" == "rand_000067.hlo" || "$filename" == "rand_000072.hlo" ]]; then
133- # echo "Skipping benchmark on $file"
134- # continue
135- # fi
136- # echo "Running benchmark on $file"
137- # ./bazel-bin/xla/tools/multihost_hlo_runner/hlo_runner_main --device_type=gpu --use_spmd_partitioning "$file" &> results/"$filename".log
138- # done
13990 - name : Wait For Connection
14091 uses : google-ml-infra/actions/ci_connection@main
14192 with :
@@ -148,20 +99,39 @@ jobs:
14899 working-directory : xla
149100 run : python parse_xla_logs.py results/gemm_00881937d6d49056045c3325a12b108b.hlo.log
150101
151- - name : Wait For Connection
152- uses : google-ml-infra/actions/ci_connection@main
153- with :
154- halt-dispatch-input : ${{ inputs.halt-for-connection }}
155-
156102 - name : Upload Results
157103 uses : actions/upload-artifact@v4
158104 with :
159105 name : gpu-xla-benchmarks
160106 path : xla/results
161- # - name: Wait For Connection
162- # uses: google-ml-infra/actions/ci_connection@main
163- # with:
164- # halt-dispatch-input: ${{ inputs.halt-for-connection }}
107+ # jax-build-and-test:
108+ # runs-on: linux-x86-g2-48-l4-4gpu # Use a GPU-enabled runner
109+ # container:
110+ # image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
111+
112+ # env:
113+ # JAXCI_HERMETIC_PYTHON_VERSION: 3.11
114+
115+ # steps:
116+ # - name: Checkout JAX Fork
117+ # uses: actions/checkout@v3
118+ # with:
119+ # repository: 'google-ml-infra/jax-fork'
120+ # path: jax-fork
121+
122+ # - name: Install JAX Dependencies
123+ # working-directory: jax-fork
124+ # run: |
125+ # python -m pip install --upgrade pip
126+ # pip install pytest
127+ # pip install absl-py
128+ # pip install "jax[cuda12_pip]" # Adjust CUDA version if needed
129+ # pip install google-benchmark
130+ # - name: Run JAX Multiprocess GPU Test
131+ # working-directory: jax-fork
132+ # continue-on-error: true
133+ # run: python -m pytest tests/multiprocess_gpu_test.py
134+
165135
166136 # - name: Run HLO Module Benchmarks with GPU in xla/tests/fuzz
167137 # working-directory: xla
0 commit comments