@@ -100,18 +100,43 @@ jobs:
100100 if : inputs.build_jax || inputs.build_jaxlib || inputs.build_jax_cuda_plugin || inputs.build_jax_cuda_pjrt
101101 steps :
102102 - uses : actions/checkout@v3
103+ if : >-
104+ (inputs.build_jax && matrix.artifact == 'jax') ||
105+ (inputs.build_jaxlib && matrix.artifact == 'jaxlib') ||
106+ (inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin') ||
107+ (inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt')
103108 # Halt for testing
104109 - name : Wait For Connection
105110 uses : google-ml-infra/actions/ci_connection@main
106111 with :
107112 halt-dispatch-input : ${{ inputs.halt-for-connection }}
113+ if : >-
114+ (inputs.build_jax && matrix.artifact == 'jax') ||
115+ (inputs.build_jaxlib && matrix.artifact == 'jaxlib') ||
116+ (inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin') ||
117+ (inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt')
108118 - name : Build ${{ matrix.artifact }}
119+ if : >-
120+ (inputs.build_jax && matrix.artifact == 'jax') ||
121+ (inputs.build_jaxlib && matrix.artifact == 'jaxlib') ||
122+ (inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin') ||
123+ (inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt')
109124 run : ./ci/build_artifacts.sh "${{ matrix.artifact }}"
110125 - name : Set Platform
126+ if : >-
127+ (inputs.build_jax && matrix.artifact == 'jax') ||
128+ (inputs.build_jaxlib && matrix.artifact == 'jaxlib') ||
129+ (inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin') ||
130+ (inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt')
111131 run : |
112132 echo "PLATFORM=$(uname)_$(uname -m)" >> $GITHUB_ENV
113133 - name : Upload artifacts to GCS bucket
114134 # Upload if requested and one of the artifacts was built
115- if : inputs.upload_artifacts
135+ if : >-
136+ inputs.upload_artifacts &&
137+ (inputs.build_jax && matrix.artifact == 'jax') ||
138+ (inputs.build_jaxlib && matrix.artifact == 'jaxlib') ||
139+ (inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin') ||
140+ (inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt')
116141 run : ~/usr/local/bin/google-cloud-sdk/bin/gsutil -m cp -r $(pwd)/dist gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination }}"/$PLATFORM
117142
0 commit comments