1515 - ' yes'
1616 - ' no'
1717 workflow_call :
18+ inputs :
19+ build_jax :
20+ description : " Should the jax artifact be built?"
21+ required : true
22+ default : true
23+ type : boolean
24+ build_jaxlib :
25+ description : " Should the jaxlib artifact be built?"
26+ required : true
27+ default : true
28+ type : boolean
29+ build_jax_cuda_plugin :
30+ description : " Should the jax-cuda-plugin artifact be built?"
31+ required : true
32+ default : true
33+ type : boolean
34+ build_jax_cuda_pjrt :
35+ description : " Should the jax-cuda-pjrt artifact be built?"
36+ required : true
37+ default : true
38+ type : boolean
39+ clone_main_xla :
40+ description : " Should latest XLA be used? (1 to enable, 0 to disable)"
41+ type : string
42+ required : true
43+ default : " 0"
44+ upload_artifacts :
45+ description : " Should the artifacts be uploaded to a GCS bucket?"
46+ required : true
47+ default : false
48+ type : boolean
49+ upload_destination :
50+ description : " GCS location to where the artifacts should be uploaded"
51+ required : true
52+ default : ' ${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
53+ type : string
1854
1955jobs :
2056 build :
2763 matrix :
2864 runner : ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"]
2965 artifact : ["jax", "jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"]
30- python : ["3.10", "3.11", "3.12"]
66+ python : ["3.10"] # , "3.11", "3.12"]
3167 # jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each
3268 # Python version.
3369 exclude :
5894 (contains(matrix.runner, 'windows-x86') && null) }}
5995
6096 env :
61- # Do not run Docker container for Linux runners. Linux runners already run in a Docker container.
62- JAXCI_RUN_DOCKER_CONTAINER : 0
97+ JAXCI_HERMETIC_PYTHON_VERSION : " ${{ matrix.python }} "
98+ JAXCI_CLONE_MAIN_XLA : " ${{ inputs.clone_main_xla }} "
6399
64100 steps :
65101 - uses : actions/checkout@v3
@@ -68,7 +104,19 @@ jobs:
68104 uses : google-ml-infra/actions/ci_connection@main
69105 with :
70106 halt-dispatch-input : ${{ inputs.halt-for-connection }}
71- - name : Build ${{ matrix.artifact }}
72- env :
73- JAXCI_HERMETIC_PYTHON_VERSION : " ${{ matrix.python }}"
74- run : ./ci/build_artifacts.sh "${{ matrix.artifact }}"
107+ - name : Build jax
108+ if : inputs.build_jax && matrix.artifact == 'jax'
109+ run : ./ci/build_artifacts.sh "jax"
110+ - name : Build jaxlib
111+ if : inputs.build_jaxlib && matrix.artifact == 'jaxlib'
112+ run : ./ci/build_artifacts.sh "jaxlib"
113+ - name : Build jax-cuda-plugin
114+ if : inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin'
115+ run : ./ci/build_artifacts.sh "jax-cuda-plugin"
116+ - name : Build jax-cuda-pjrt
117+ if : inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt'
118+ run : ./ci/build_artifacts.sh "jax-cuda-pjrt"
119+ - name : Upload artifacts to GCS bucket
120+ if : inputs.upload_artifacts
121+ run : ~/usr/local/bin/google-cloud-sdk/bin/gsutil -m cp -r $(pwd)/dist gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination }}"
122+
0 commit comments