Skip to content

Commit 3783c84

Browse files
committed
Merge branch 'main' of https://github.com/google-ml-infra/jax-fork into srnitin/task-jax-ci-rework
2 parents de602b5 + f8b2c2b commit 3783c84

File tree

394 files changed

+10829
-8586
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

394 files changed

+10829
-8586
lines changed

.github/ISSUE_TEMPLATE/bug-report.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ body:
2020
* If you prefer a non-templated issue report, click [here][Raw report].
2121
2222
23-
[Discussions]: https://github.com/google/jax/discussions
23+
[Discussions]: https://github.com/jax-ml/jax/discussions
2424
25-
[issue search]: https://github.com/google/jax/search?q=is%3Aissue&type=issues
25+
[issue search]: https://github.com/jax-ml/jax/search?q=is%3Aissue&type=issues
2626
27-
[Raw report]: http://github.com/google/jax/issues/new
27+
[Raw report]: http://github.com/jax-ml/jax/issues/new
2828
- type: textarea
2929
attributes:
3030
label: Description

.github/ISSUE_TEMPLATE/config.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
blank_issues_enabled: false
22
contact_links:
33
- name: Have questions or need support?
4-
url: https://github.com/google/jax/discussions
4+
url: https://github.com/jax-ml/jax/discussions
55
about: Please ask questions on the Discussions tab

.github/workflows/jax-array-api.yml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
name: JAX Array API
22

33
on:
4-
workflow_dispatch: # allows triggering the workflow run manually
5-
pull_request: # Automatically trigger on pull requests affecting particular files
4+
push:
5+
branches:
6+
- main
7+
pull_request:
68
branches:
79
- main
8-
paths:
9-
- '**workflows/jax-array-api.yml'
10-
- '**experimental/array_api/**'
1110

1211
concurrency:
1312
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}

.github/workflows/self_hosted_runner_utils/setup_runner.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ runner_token="$3"
3131
# - sets empty string as default to avoid unbound variable error from set -u
3232
jax_repo_url="${4-}"
3333
if [ -z "${jax_repo_url}" ]; then
34-
jax_repo_url="https://github.com/google/jax"
34+
jax_repo_url="https://github.com/jax-ml/jax"
3535
fi
3636

3737
# Create `runner` user. This user won't have sudo access unless you ssh into the
@@ -67,7 +67,7 @@ cd ~/
6767
6868
git clone ${jax_repo_url}
6969
70-
# Based on https://github.com/google/jax/settings/actions/runners/new
70+
# Based on https://github.com/jax-ml/jax/settings/actions/runners/new
7171
# (will be 404 for github users with insufficient repo permissions)
7272
mkdir actions-runner && cd actions-runner
7373
curl -o actions-runner-linux-x64.tar.gz -L ${actions_runner_download}

.github/workflows/upstream-nightly.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ jobs:
8484
failure()
8585
&& steps.status.outcome == 'failure'
8686
&& github.event_name == 'schedule'
87-
&& github.repository == 'google/jax'
88-
uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # ratchet: actions/upload-artifact@v4
87+
&& github.repository == 'jax-ml/jax'
88+
uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # ratchet: actions/upload-artifact@v4
8989
with:
9090
name: output-${{ matrix.python-version }}-log.jsonl
9191
path: output-${{ matrix.python-version }}-log.jsonl

.github/workflows/wheel_win_x64.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
matrix:
1818
os: [windows-2019-32core]
1919
arch: [AMD64]
20-
pyver: ['3.10', '3.11', '3.12']
20+
pyver: ['3.10', '3.11', '3.12', '3.13.0-rc.2']
2121
name: ${{ matrix.os }} ${{ matrix.pyver }} jaxlib wheel build
2222
runs-on: ${{ matrix.os }}
2323

@@ -45,7 +45,7 @@ jobs:
4545
--bazel_options=--config=win_clang `
4646
--verbose
4747
48-
- uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # ratchet: actions/upload-artifact@v4
48+
- uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # ratchet: actions/upload-artifact@v4
4949
with:
5050
name: wheels-${{ matrix.os }}-${{ matrix.pyver }}
5151
path: ${{ github.workspace }}\dist\*.whl

.github/workflows/windows_ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ jobs:
5353
--bazel_options=--color=yes `
5454
--bazel_options=--config=win_clang
5555
56-
- uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # ratchet: actions/upload-artifact@v4
56+
- uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # ratchet: actions/upload-artifact@v4
5757
with:
5858
name: wheels
5959
path: ${{ github.workspace }}\jax\dist\*.whl

CHANGELOG.md

Lines changed: 121 additions & 103 deletions
Large diffs are not rendered by default.

CITATION.bib

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
@software{jax2018github,
22
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
33
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
4-
url = {http://github.com/google/jax},
4+
url = {http://github.com/jax-ml/jax},
55
version = {0.3.13},
66
year = {2018},
77
}

README.md

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
> for the offical JAX repo
44
55
<div align="center">
6-
<img src="https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png" alt="logo"></img>
6+
<img src="https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png" alt="logo"></img>
77
</div>
88

99
# Transformable numerical computing at scale
1010

11-
![Continuous integration](https://github.com/google/jax/actions/workflows/ci-build.yaml/badge.svg)
11+
![Continuous integration](https://github.com/jax-ml/jax/actions/workflows/ci-build.yaml/badge.svg)
1212
![PyPI version](https://img.shields.io/pypi/v/jax)
1313

1414
[**Quickstart**](#quickstart-colab-in-the-cloud)
@@ -54,7 +54,7 @@ parallel programming of multiple accelerators, with more to come.
5454
This is a research project, not an official Google product. Expect bugs and
5555
[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
5656
Please help by trying it out, [reporting
57-
bugs](https://github.com/google/jax/issues), and letting us know what you
57+
bugs](https://github.com/jax-ml/jax/issues), and letting us know what you
5858
think!
5959

6060
```python
@@ -88,16 +88,16 @@ perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example gra
8888
Jump right in using a notebook in your browser, connected to a Google Cloud GPU.
8989
Here are some starter notebooks:
9090
- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/quickstart.html)
91-
- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)
91+
- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)
9292

9393
**JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU
94-
Colabs](https://github.com/google/jax/tree/main/cloud_tpu_colabs).
94+
Colabs](https://github.com/jax-ml/jax/tree/main/cloud_tpu_colabs).
9595

9696
For a deeper dive into JAX:
9797
- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
9898
- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)
9999
- See the [full list of
100-
notebooks](https://github.com/google/jax/tree/main/docs/notebooks).
100+
notebooks](https://github.com/jax-ml/jax/tree/main/docs/notebooks).
101101

102102
## Transformations
103103

@@ -304,7 +304,7 @@ print(normalize(jnp.arange(4.)))
304304
# prints [0. 0.16666667 0.33333334 0.5 ]
305305
```
306306

307-
You can even [nest `pmap` functions](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=MdRscR5MONuN) for more
307+
You can even [nest `pmap` functions](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=MdRscR5MONuN) for more
308308
sophisticated communication patterns.
309309

310310
It all composes, so you're free to differentiate through parallel computations:
@@ -337,9 +337,9 @@ When reverse-mode differentiating a `pmap` function (e.g. with `grad`), the
337337
backward pass of the computation is parallelized just like the forward pass.
338338

339339
See the [SPMD
340-
Cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)
340+
Cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)
341341
and the [SPMD MNIST classifier from scratch
342-
example](https://github.com/google/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py)
342+
example](https://github.com/jax-ml/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py)
343343
for more.
344344

345345
## Current gotchas
@@ -353,7 +353,7 @@ Some standouts:
353353
1. [In-place mutating updates of
354354
arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically.
355355
1. [Random numbers are
356-
different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/google/jax/blob/main/docs/jep/263-prng.md).
356+
different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md).
357357
1. If you're looking for [convolution
358358
operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html),
359359
they're in the `jax.lax` package.
@@ -441,7 +441,7 @@ To cite this repository:
441441
@software{jax2018github,
442442
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
443443
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
444-
url = {http://github.com/google/jax},
444+
url = {http://github.com/jax-ml/jax},
445445
version = {0.3.13},
446446
year = {2018},
447447
}

0 commit comments

Comments
 (0)