Skip to content

Commit c3e23a0

Browse files
committed
Merge branch 'main' of https://github.com/google-ml-infra/jax-fork into srnitin/task-jax-ci-rework
2 parents d8efebb + 97c8d5d commit c3e23a0

File tree

139 files changed

+6435
-3306
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

139 files changed

+6435
-3306
lines changed

.github/workflows/ci-build.yaml

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ jobs:
2929
runs-on: ubuntu-latest
3030
timeout-minutes: 5
3131
steps:
32-
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
32+
- uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
3333
- name: Set up Python 3.11
34-
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
34+
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
3535
with:
3636
python-version: 3.11
37-
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # ratchet: pre-commit/action@v3.0.1
37+
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
3838

3939
build:
4040
name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})"
@@ -57,9 +57,9 @@ jobs:
5757
prng-upgrade: 0
5858
num_generated_cases: 1
5959
steps:
60-
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
60+
- uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
6161
- name: Set up Python ${{ matrix.python-version }}
62-
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
62+
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
6363
with:
6464
python-version: ${{ matrix.python-version }}
6565
- name: Get pip cache dir
@@ -68,7 +68,7 @@ jobs:
6868
python -m pip install --upgrade pip wheel
6969
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
7070
- name: pip cache
71-
uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # ratchet: actions/cache@v4
71+
uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2
7272
with:
7373
path: ${{ steps.pip-cache.outputs.dir }}
7474
key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@@ -104,9 +104,9 @@ jobs:
104104
matrix:
105105
python-version: ['3.10']
106106
steps:
107-
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
107+
- uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
108108
- name: Set up Python ${{ matrix.python-version }}
109-
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
109+
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
110110
with:
111111
python-version: ${{ matrix.python-version }}
112112
- name: Get pip cache dir
@@ -115,7 +115,7 @@ jobs:
115115
python -m pip install --upgrade pip wheel
116116
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
117117
- name: pip cache
118-
uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # ratchet: actions/cache@v4
118+
uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2
119119
with:
120120
path: ${{ steps.pip-cache.outputs.dir }}
121121
key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@@ -141,9 +141,9 @@ jobs:
141141
matrix:
142142
python-version: ['3.10']
143143
steps:
144-
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
144+
- uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
145145
- name: Set up Python ${{ matrix.python-version }}
146-
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
146+
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
147147
with:
148148
python-version: ${{ matrix.python-version }}
149149
- name: Get pip cache dir
@@ -152,7 +152,7 @@ jobs:
152152
python -m pip install --upgrade pip wheel
153153
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
154154
- name: pip cache
155-
uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # ratchet: actions/cache@v4
155+
uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2
156156
with:
157157
path: ${{ steps.pip-cache.outputs.dir }}
158158
key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@@ -177,9 +177,9 @@ jobs:
177177
enable-x64: 0
178178
num_generated_cases: 10
179179
steps:
180-
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
180+
- uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
181181
- name: Set up Python ${{ matrix.python-version }}
182-
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
182+
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
183183
with:
184184
python-version: ${{ matrix.python-version }}
185185
- name: Get pip cache dir
@@ -188,7 +188,7 @@ jobs:
188188
python -m pip install --upgrade pip wheel
189189
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
190190
- name: pip cache
191-
uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # ratchet: actions/cache@v4
191+
uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2
192192
with:
193193
path: ${{ steps.pip-cache.outputs.dir }}
194194
key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@@ -210,4 +210,37 @@ jobs:
210210
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
211211
echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
212212
pytest -n auto --tb=short --maxfail=20 jax/experimental/jax2tf/tests/jax2tf_test.py
213-
213+
214+
ffi:
215+
name: FFI example
216+
runs-on: ubuntu-latest
217+
timeout-minutes: 5
218+
steps:
219+
- uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
220+
- name: Set up Python 3.11
221+
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
222+
with:
223+
python-version: 3.11
224+
- name: Get pip cache dir
225+
id: pip-cache
226+
run: |
227+
python -m pip install --upgrade pip wheel
228+
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
229+
- name: pip cache
230+
uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2
231+
with:
232+
path: ${{ steps.pip-cache.outputs.dir }}
233+
key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt', 'examples/**/pyproject.toml') }}
234+
- name: Install JAX
235+
run: pip install .
236+
- name: Build and install example project
237+
run: python -m pip install -v ./examples/ffi[test]
238+
env:
239+
# We test building using GCC instead of clang. All other JAX builds use
240+
# clang, but it is useful to make sure that FFI users can compile using
241+
# a different toolchain. GCC is the default compiler on the
242+
# 'ubuntu-latest' runner, but we still set this explicitly just to be
243+
# clear.
244+
CMAKE_ARGS: -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++
245+
- name: Run tests
246+
run: python -m pytest examples/ffi/tests

.github/workflows/cloud-tpu-ci-nightly.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ jobs:
4343
# https://opensource.google/documentation/reference/github/services#actions
4444
# mandates using a specific commit for non-Google actions. We use
4545
# https://github.com/sethvargo/ratchet to pin specific versions.
46-
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
46+
- uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
4747
- name: Install JAX test requirements
4848
run: |
4949
pip install -U -r build/test-requirements.txt

.github/workflows/jax-array-api.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,17 @@ jobs:
2222

2323
steps:
2424
- name: Checkout jax
25-
uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet: actions/checkout@v4
25+
uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
2626
- name: Checkout array-api-tests
27-
uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet: actions/checkout@v4
27+
uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
2828
with:
2929
repository: data-apis/array-api-tests
3030
# TODO(jakevdp) update this to a stable release/tag when available.
3131
ref: 'b4c0823469c02d6ce6e512ad4c2bd8ba42b1b4b2' # Latest commit as of 2024-09-09
3232
submodules: 'true'
3333
path: 'array-api-tests'
3434
- name: Set up Python ${{ matrix.python-version }}
35-
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
35+
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
3636
with:
3737
python-version: ${{ matrix.python-version }}
3838
- name: Install dependencies

.github/workflows/metal_plugin_ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ jobs:
2727

2828
steps:
2929
- name: Get repo
30-
uses: actions/checkout@v4
30+
uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
3131
with:
3232
path: jax
3333
- name: Setup build and test enviroment

.github/workflows/upstream-nightly.yml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ jobs:
3636
outputs:
3737
artifacts_availability: ${{ steps.status.outputs.ARTIFACTS_AVAILABLE }}
3838
steps:
39-
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
39+
- uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
4040
- name: Set up Python ${{ matrix.python-version }}
41-
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
41+
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
4242
with:
4343
python-version: ${{ matrix.python-version }}
4444
- name: Install JAX test requirements
@@ -85,7 +85,7 @@ jobs:
8585
&& steps.status.outcome == 'failure'
8686
&& github.event_name == 'schedule'
8787
&& github.repository == 'jax-ml/jax'
88-
uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # ratchet: actions/upload-artifact@v4
88+
uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
8989
with:
9090
name: output-${{ matrix.python-version }}-log.jsonl
9191
path: output-${{ matrix.python-version }}-log.jsonl
@@ -106,11 +106,11 @@ jobs:
106106
run:
107107
shell: bash
108108
steps:
109-
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
110-
- uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
109+
- uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
110+
- uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
111111
with:
112112
python-version: "3.x"
113-
- uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # ratchet:actions/download-artifact@v4
113+
- uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
114114
with:
115115
path: /tmp/workspace/logs
116116
- name: install requirements
@@ -123,7 +123,7 @@ jobs:
123123
cat logs/*.jsonl > pytest-logs.txt
124124
python .github/workflows/parse_logs.py pytest-logs.txt --outfile=parsed-logs.txt
125125
- name: Report failures
126-
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # ratchet:actions/github-script@v7
126+
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
127127
with:
128128
github-token: ${{ secrets.GITHUB_TOKEN }}
129129
script: |

.github/workflows/wheel_win_x64.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ jobs:
2525
- name: Install LLVM/Clang
2626
run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade
2727

28-
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
28+
- uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
2929

30-
- uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
30+
- uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
3131
with:
3232
python-version: ${{ matrix.pyver }}
3333
cache: 'pip'
@@ -45,7 +45,7 @@ jobs:
4545
--bazel_options=--config=win_clang `
4646
--verbose
4747
48-
- uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # ratchet: actions/upload-artifact@v4
48+
- uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
4949
with:
5050
name: wheels-${{ matrix.os }}-${{ matrix.pyver }}
5151
path: ${{ github.workspace }}\dist\*.whl

.github/workflows/windows_ci.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@ jobs:
3131
- name: Install LLVM/Clang
3232
run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade
3333

34-
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
34+
- uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
3535
with:
3636
path: jax
3737

38-
- uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
38+
- uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
3939
with:
4040
python-version: ${{ matrix.pyver }}
4141
cache: 'pip'
@@ -53,7 +53,7 @@ jobs:
5353
--bazel_options=--color=yes `
5454
--bazel_options=--config=win_clang
5555
56-
- uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # ratchet: actions/upload-artifact@v4
56+
- uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
5757
with:
5858
name: wheels
5959
path: ${{ github.workspace }}\jax\dist\*.whl
@@ -66,7 +66,7 @@ jobs:
6666
PY_COLORS: 1
6767
run: |
6868
cd jax
69+
python -m pip install --pre --find-links ${{ github.workspace }}\jax\dist jaxlib
6970
python -m pip install -e ${{ github.workspace }}\jax
70-
python -m pip install --no-index --find-links ${{ github.workspace }}\jax\dist jaxlib
7171
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
7272
pytest -n auto --tb=short tests examples

CHANGELOG.md

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,37 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
1515
* New Functionality
1616
* This release includes wheels for Python 3.13. Free-threading mode is not yet
1717
supported.
18+
* `jax.errors.JaxRuntimeError` has been added as a public alias for the
19+
formerly private `XlaRuntimeError` type.
20+
21+
* Breaking changes
22+
* `jax_pmap_no_rank_reduction` flag is set to `True` by default.
23+
* array[0] on a pmap result now introduces a reshape (use array[0:1]
24+
instead).
25+
* The per-shard shape (accessable via jax_array.addressable_shards or
26+
jax_array.addressable_data(0)) now has a leading (1, ...). Update code
27+
that directly accesses shards accordingly. The rank of the per-shard-shape
28+
now matches that of the global shape which is the same behavior as jit.
29+
This avoids costly reshapes when passing results from pmap into jit.
30+
* `jax.experimental.host_callback` has been deprecated since March 2024, with
31+
JAX version 0.4.26. Now we set the default value of the
32+
`--jax_host_callback_legacy` configuration value to `True`, which means that
33+
if your code uses `jax.experimental.host_callback` APIs, those API calls
34+
will be implemented in terms of the new `jax.experimental.io_callback` API.
35+
If this breaks your code, for a very limited time, you can set the
36+
`--jax_host_callback_legacy` to `True`. Soon we will remove that
37+
configuration option, so you should instead transition to using the
38+
new JAX callback APIs. See {jax-issue}`#20385` for a discussion.
1839

1940
* Deprecations
2041
* In {func}`jax.numpy.trim_zeros`, non-arraylike arguments or arraylike
2142
arguments with `ndim != 1` are now deprecated, and in the future will result
2243
in an error.
44+
* Internal pretty-printing tools `jax.core.pp_*` have been removed, after
45+
being deprecated in JAX v0.4.30.
46+
* `jax.lib.xla_client.Device` is deprecated; use `jax.Device` instead.
47+
* `jax.lib.xla_client.XlaRuntimeError` has been deprecated. Use
48+
`jax.errors.JaxRuntimeError` instead.
2349

2450
* Deletion:
2551
* `jax.xla_computation` is deleted. It's been 3 months since it's deprecation
@@ -34,10 +60,18 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
3460
`jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.
3561
* {class}`jax.ShapeDtypeStruct` no longer accepts the `named_shape` argument.
3662
The argument was only used by `xmap` which was removed in 0.4.31.
63+
* `jax.tree.map(f, None, non-None)`, which previously emitted a
64+
`DeprecationWarning`, now raises an error in a future version of jax. `None`
65+
is only a tree-prefix of itself. To preserve the current behavior, you can
66+
ask `jax.tree.map` to treat `None` as a leaf value by writing:
67+
`jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`.
68+
* `jax.sharding.XLACompatibleSharding` has been removed. Please use
69+
`jax.sharding.Sharding`.
3770

3871
* Bug fixes
3972
* Fixed a bug where {func}`jax.numpy.cumsum` would produce incorrect outputs
4073
if a non-boolean input was provided and `dtype=bool` was specified.
74+
* Edit implementation of {func}`jax.numpy.ldexp` to get correct gradient.
4175

4276
## jax 0.4.33 (September 16, 2024)
4377

@@ -62,14 +96,6 @@ See the 0.4.33 release notes for more details.
6296
C++ and CUDA code from JAX.
6397

6498
* Changes
65-
* `jax_pmap_no_rank_reduction` flag is set to `True` by default.
66-
* array[0] on a pmap result now introduces a reshape (use array[0:1]
67-
instead).
68-
* The per-shard shape (accessable via jax_array.addressable_shards or
69-
jax_array.addressable_data(0)) now has a leading (1, ...). Update code
70-
that directly accesses shards accordingly. The rank of the per-shard-shape
71-
now matches that of the global shape which is the same behavior as jit.
72-
This avoids costly reshapes when passing results from pmap into jit.
7399
* `jax_enable_memories` flag is set to `True` by default.
74100
* {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard.
75101
See {ref}`python-array-api` for more information.

benchmarks/mosaic/BUILD

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,25 +28,11 @@ package(
2828

2929
jax_generate_backend_suites()
3030

31-
DISABLED_BACKENDS = [
32-
"cpu",
33-
"tpu",
34-
]
35-
36-
DISABLED_CONFIGS = [
37-
"gpu",
38-
"gpu_a100",
39-
"gpu_p100",
40-
"gpu_p100_x32",
41-
"gpu_x32",
42-
"gpu_pjrt_c_api",
43-
]
44-
4531
jax_multiplatform_test(
4632
name = "matmul_bench",
4733
srcs = ["matmul_bench.py"],
48-
disable_backends = DISABLED_BACKENDS,
49-
disable_configs = DISABLED_CONFIGS,
34+
enable_backends = [],
35+
enable_configs = ["gpu_h100"],
5036
tags = ["notap"],
5137
deps = [
5238
"//jax:mosaic_gpu",

build/build.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,11 +284,13 @@ def write_bazelrc(*, remote_build,
284284
f.write("build --config=mkl_open_source_only\n")
285285
if enable_cuda:
286286
f.write("build --config=cuda\n")
287+
if use_cuda_nvcc:
288+
f.write("build --config=cuda_nvcc\n")
289+
else:
290+
f.write("build --config=cuda_clang\n")
287291
f.write(f"build --action_env=CLANG_CUDA_COMPILER_PATH={clang_path}\n")
288292
if not enable_nccl:
289293
f.write("build --config=nonccl\n")
290-
if use_cuda_nvcc:
291-
f.write("build --config=cuda_nvcc\n")
292294
if cuda_version:
293295
f.write("build --repo_env HERMETIC_CUDA_VERSION=\"{cuda_version}\"\n"
294296
.format(cuda_version=cuda_version))

0 commit comments

Comments
 (0)