1717 workflow_call :
1818 inputs :
1919 build_jax :
20- description : " Should the jax artifact be built?"
20+ description : " Should the jax artifact be built? (1 to enable, 0 to disable)"
21+ type : string
2122 required : true
22- default : true
23- type : boolean
23+ default : " 1"
2424 build_jaxlib :
25- description : " Should the jaxlib artifact be built?"
25+ description : " Should the jaxlib artifact be built? (1 to enable, 0 to disable)"
26+ type : string
2627 required : true
27- default : true
28- type : boolean
28+ default : " 1"
2929 build_jax_cuda_plugin :
30- description : " Should the jax-cuda-plugin artifact be built?"
30+ description : " Should the jax-cuda-plugin artifact be built? (1 to enable, 0 to disable)"
31+ type : string
3132 required : true
32- default : true
33- type : boolean
33+ default : " 1"
3434 build_jax_cuda_pjrt :
35- description : " Should the jax-cuda-pjrt artifact be built?"
35+ description : " Should the jax-cuda-pjrt artifact be built? (1 to enable, 0 to disable)"
36+ type : string
3637 required : true
37- default : true
38- type : boolean
38+ default : " 1"
3939 clone_main_xla :
4040 description : " Should latest XLA be used? (1 to enable, 0 to disable)"
4141 type : string
@@ -62,13 +62,13 @@ jobs:
6262 - id : set-matrix
6363 run : |
6464 matrix='[]'
65- if ${{ inputs.build_jax }}; then
65+ if [[ ${{ inputs.build_jax }} == "1" ]] ; then
6666 matrix='["jax"]'
67- if ${{ inputs.build_jaxlib }}; then
67+ if [[ ${{ inputs.build_jaxlib }} == "1" ]] ; then
6868 matrix='["jax", "jaxlib"]'
69- if ${{ inputs.build_jax_cuda_pjrt }}; then
69+ if [[ ${{ inputs.build_jax_cuda_pjrt }} == "1" ]] ; then
7070 matrix='["jax", "jaxlib", "jax-cuda-pjrt"]'
71- if ${{ inputs.build_jax_cuda_plugin }}; then
71+ if [[ ${{ inputs.build_jax_cuda_plugin }} == "1" ]] ; then
7272 matrix='["jax", "jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"]'
7373 fi
7474 fi
0 commit comments