Skip to content

Commit bbba077

Browse files
authored
Merge pull request #31 from hardik01shah/main
Add JAX Support to NPBench and Implement JAX Benchmarks
2 parents 1a9e6f8 + 5595a87 commit bbba077

File tree

64 files changed

+2105
-1
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+2105
-1
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,6 @@ dmypy.json
127127

128128
# Pyre type checker
129129
.pyre/
130+
131+
# dace
132+
.dacecache/

README.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ python plot_results.py
2222
Currently, the following frameworks are supported (in alphabetical order):
2323
- CuPy
2424
- DaCe
25+
- JAX
2526
- Numba
2627
- NumPy
2728
- Pythran
@@ -55,6 +56,24 @@ However, you may want to install the latest version from the [GitHub repository]
5556
To run NPBench with DaCe, you have to select as framework (see details below)
5657
either `dace_cpu` or `dace_gpu`.
5758

59+
### Jax
60+
61+
JAX can be installed with pip:
62+
- CPU-only (Linux/macOS/Windows)
63+
```sh
64+
pip install -U jax
65+
```
66+
- GPU (NVIDIA, CUDA 12)
67+
```sh
68+
pip install -U "jax[cuda12]"
69+
```
70+
- TPU (Google Cloud TPU VM)
71+
```sh
72+
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
73+
```
74+
For more installation options, please consult the JAX [installation guide](https://jax.readthedocs.io/en/latest/installation.html#installation).
75+
76+
5877
### Numba
5978

6079
Numba can be installed with pip:

framework_info/jax.json

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
{
2+
"framework": {
3+
"simple_name": "jax",
4+
"full_name": "Jax",
5+
"prefix": "jax",
6+
"postfix": "jax",
7+
"class": "JaxFramework",
8+
"arch": "cpu"
9+
}
10+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright 2014 Jérôme Kieffer et al.
2+
# This is an open-access article distributed under the terms of the
3+
# Creative Commons Attribution License, which permits unrestricted use,
4+
# distribution, and reproduction in any medium, provided the original author
5+
# and source are credited.
6+
# http://creativecommons.org/licenses/by/3.0/
7+
# Jérôme Kieffer and Giannis Ashiotis. Pyfai: a python library for
8+
# high performance azimuthal integration on gpu, 2014. In Proceedings of the
9+
# 7th European Conference on Python in Science (EuroSciPy 2014).
10+
11+
import jax
12+
import jax.numpy as jnp
13+
from functools import partial
14+
15+
@partial(jax.jit, static_argnums=(2,))
16+
def azimint_hist(data: jax.Array, radius: jax.Array, npt):
17+
histu = jnp.histogram(radius, npt)[0]
18+
histw = jnp.histogram(radius, npt, weights=data)[0]
19+
return histw / histu
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright 2014 Jérôme Kieffer et al.
2+
# This is an open-access article distributed under the terms of the
3+
# Creative Commons Attribution License, which permits unrestricted use,
4+
# distribution, and reproduction in any medium, provided the original author
5+
# and source are credited.
6+
# http://creativecommons.org/licenses/by/3.0/
7+
# Jérôme Kieffer and Giannis Ashiotis. Pyfai: a python library for
8+
# high performance azimuthal integration on gpu, 2014. In Proceedings of the
9+
# 7th European Conference on Python in Science (EuroSciPy 2014).
10+
11+
import jax
12+
import jax.numpy as jnp
13+
from jax import lax
14+
from functools import partial
15+
16+
17+
@partial(jax.jit, static_argnums=(2,))
18+
def azimint_naive(data, radius, npt):
19+
rmax = radius.max()
20+
res = jnp.zeros(npt, dtype=jnp.float64)
21+
22+
def loop_body(i, res):
23+
r1 = rmax * i / npt
24+
r2 = rmax * (i + 1) / npt
25+
mask_r12 = jnp.logical_and((r1 <= radius), (radius < r2))
26+
mean = jnp.where(mask_r12, data, 0).mean(where=mask_r12)
27+
res = res.at[i].set(mean)
28+
return res
29+
30+
res = lax.fori_loop(0, npt, loop_body, res)
31+
32+
return res
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Barba, Lorena A., and Forsyth, Gilbert F. (2018).
2+
# CFD Python: the 12 steps to Navier-Stokes equations.
3+
# Journal of Open Source Education, 1(9), 21,
4+
# https://doi.org/10.21105/jose.00021
5+
# TODO: License
6+
# (c) 2017 Lorena A. Barba, Gilbert F. Forsyth.
7+
# All content is under Creative Commons Attribution CC-BY 4.0,
8+
# and all code is under BSD-3 clause (previously under MIT, and changed on March 8, 2018).
9+
10+
import jax.numpy as jnp
11+
import jax
12+
from jax import lax
13+
from functools import partial
14+
15+
16+
@partial(jax.jit, static_argnums=(1,))
17+
def build_up_b(b, rho, dt, u, v, dx, dy):
18+
19+
b = b.at[1:-1,
20+
1:-1].set(rho * (1 / dt * ((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx) +
21+
(v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy)) -
22+
((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx))**2 - 2 *
23+
((u[2:, 1:-1] - u[0:-2, 1:-1]) / (2 * dy) *
24+
(v[1:-1, 2:] - v[1:-1, 0:-2]) / (2 * dx)) -
25+
((v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy))**2))
26+
27+
return b
28+
29+
30+
@partial(jax.jit, static_argnums=(0,))
31+
def pressure_poisson(nit, p, dx, dy, b):
32+
def body_func(p, _):
33+
pn = p.copy()
34+
p = p.at[1:-1, 1:-1].set(((pn[1:-1, 2:] + pn[1:-1, 0:-2]) * dy**2 +
35+
(pn[2:, 1:-1] + pn[0:-2, 1:-1]) * dx**2) /
36+
(2 * (dx**2 + dy**2)) - dx**2 * dy**2 /
37+
(2 * (dx**2 + dy**2)) * b[1:-1, 1:-1])
38+
39+
p = p.at[:, -1].set(p[:, -2]) # dp/dx = 0 at x = 2
40+
p = p.at[0, :].set(p[1, :]) # dp/dy = 0 at y = 0
41+
p = p.at[:, 0].set(p[:, 1]) # dp/dx = 0 at x = 0
42+
p = p.at[-1, :].set(0) # p = 0 at y = 2
43+
44+
return p, None
45+
46+
p, _ = lax.scan(body_func, p, jnp.arange(nit))
47+
48+
return p
49+
50+
51+
@partial(jax.jit, static_argnums=(0,1,2,3,10,11,))
52+
def cavity_flow(nx, ny, nt, nit, u, v, dt, dx, dy, p, rho, nu):
53+
b = jnp.zeros((ny, nx))
54+
array_vals = (u, v, p, b)
55+
56+
def body_func(array_vals, _):
57+
58+
u, v, p, b = array_vals
59+
60+
un = u.copy()
61+
vn = v.copy()
62+
63+
b = build_up_b(b, rho, dt, u, v, dx, dy)
64+
p = pressure_poisson(nit, p, dx, dy, b)
65+
66+
u = u.at[1:-1,
67+
1:-1].set(un[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx *
68+
(un[1:-1, 1:-1] - un[1:-1, 0:-2]) -
69+
vn[1:-1, 1:-1] * dt / dy *
70+
(un[1:-1, 1:-1] - un[0:-2, 1:-1]) - dt / (2 * rho * dx) *
71+
(p[1:-1, 2:] - p[1:-1, 0:-2]) + nu *
72+
(dt / dx**2 *
73+
(un[1:-1, 2:] - 2 * un[1:-1, 1:-1] + un[1:-1, 0:-2]) +
74+
dt / dy**2 *
75+
(un[2:, 1:-1] - 2 * un[1:-1, 1:-1] + un[0:-2, 1:-1])))
76+
77+
v = v.at[1:-1,
78+
1:-1].set(vn[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx *
79+
(vn[1:-1, 1:-1] - vn[1:-1, 0:-2]) -
80+
vn[1:-1, 1:-1] * dt / dy *
81+
(vn[1:-1, 1:-1] - vn[0:-2, 1:-1]) - dt / (2 * rho * dy) *
82+
(p[2:, 1:-1] - p[0:-2, 1:-1]) + nu *
83+
(dt / dx**2 *
84+
(vn[1:-1, 2:] - 2 * vn[1:-1, 1:-1] + vn[1:-1, 0:-2]) +
85+
dt / dy**2 *
86+
(vn[2:, 1:-1] - 2 * vn[1:-1, 1:-1] + vn[0:-2, 1:-1])))
87+
88+
u = u.at[0, :].set(0)
89+
u = u.at[:, 0].set(0)
90+
u = u.at[:, -1].set(0)
91+
u = u.at[-1, :].set(1) # set velocity on cavity lid equal to 1
92+
v = v.at[0, :].set(0)
93+
v = v.at[-1, :].set(0)
94+
v = v.at[:, 0].set(0)
95+
v = v.at[:, -1].set(0)
96+
97+
return (u, v, p, b), None
98+
99+
out_vals, _ = lax.scan(body_func, array_vals, jnp.arange(nt))
100+
u, v, p, b = out_vals
101+
102+
return u, v, p
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
# Barba, Lorena A., and Forsyth, Gilbert F. (2018).
2+
# CFD Python: the 12 steps to Navier-Stokes equations.
3+
# Journal of Open Source Education, 1(9), 21,
4+
# https://doi.org/10.21105/jose.00021
5+
# TODO: License
6+
# (c) 2017 Lorena A. Barba, Gilbert F. Forsyth.
7+
# All content is under Creative Commons Attribution CC-BY 4.0,
8+
# and all code is under BSD-3 clause (previously under MIT, and changed on March 8, 2018).
9+
10+
import jax.numpy as jnp
11+
import jax
12+
from jax import lax
13+
from functools import partial
14+
15+
16+
@partial(jax.jit, static_argnums=(0,))
17+
def build_up_b(rho, dt, dx, dy, u, v):
18+
b = jnp.zeros_like(u)
19+
b = b.at[1:-1,
20+
1:-1].set((rho * (1 / dt * ((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx) +
21+
(v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy)) -
22+
((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx))**2 - 2 *
23+
((u[2:, 1:-1] - u[0:-2, 1:-1]) / (2 * dy) *
24+
(v[1:-1, 2:] - v[1:-1, 0:-2]) / (2 * dx)) -
25+
((v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy))**2)))
26+
27+
# Periodic BC Pressure @ x = 2
28+
b = b.at[1:-1, -1].set((rho * (1 / dt * ((u[1:-1, 0] - u[1:-1, -2]) / (2 * dx) +
29+
(v[2:, -1] - v[0:-2, -1]) / (2 * dy)) -
30+
((u[1:-1, 0] - u[1:-1, -2]) / (2 * dx))**2 - 2 *
31+
((u[2:, -1] - u[0:-2, -1]) / (2 * dy) *
32+
(v[1:-1, 0] - v[1:-1, -2]) / (2 * dx)) -
33+
((v[2:, -1] - v[0:-2, -1]) / (2 * dy))**2)))
34+
35+
# Periodic BC Pressure @ x = 0
36+
b = b.at[1:-1, 0].set((rho * (1 / dt * ((u[1:-1, 1] - u[1:-1, -1]) / (2 * dx) +
37+
(v[2:, 0] - v[0:-2, 0]) / (2 * dy)) -
38+
((u[1:-1, 1] - u[1:-1, -1]) / (2 * dx))**2 - 2 *
39+
((u[2:, 0] - u[0:-2, 0]) / (2 * dy) *
40+
(v[1:-1, 1] - v[1:-1, -1]) /
41+
(2 * dx)) - ((v[2:, 0] - v[0:-2, 0]) / (2 * dy))**2)))
42+
43+
return b
44+
45+
@partial(jax.jit, static_argnums=(0,))
46+
def pressure_poisson_periodic(nit, p, dx, dy, b):
47+
48+
def body_func(p, q):
49+
pn = p.copy()
50+
p = p.at[1:-1, 1:-1].set(((pn[1:-1, 2:] + pn[1:-1, 0:-2]) * dy**2 +
51+
(pn[2:, 1:-1] + pn[0:-2, 1:-1]) * dx**2) /
52+
(2 * (dx**2 + dy**2)) - dx**2 * dy**2 /
53+
(2 * (dx**2 + dy**2)) * b[1:-1, 1:-1])
54+
55+
# Periodic BC Pressure @ x = 2
56+
p = p.at[1:-1, -1].set(((pn[1:-1, 0] + pn[1:-1, -2]) * dy**2 +
57+
(pn[2:, -1] + pn[0:-2, -1]) * dx**2) /
58+
(2 * (dx**2 + dy**2)) - dx**2 * dy**2 /
59+
(2 * (dx**2 + dy**2)) * b[1:-1, -1])
60+
61+
# Periodic BC Pressure @ x = 0
62+
p = p.at[1:-1,
63+
0].set((((pn[1:-1, 1] + pn[1:-1, -1]) * dy**2 +
64+
(pn[2:, 0] + pn[0:-2, 0]) * dx**2) / (2 * (dx**2 + dy**2)) -
65+
dx**2 * dy**2 / (2 * (dx**2 + dy**2)) * b[1:-1, 0]))
66+
67+
# Wall boundary conditions, pressure
68+
p = p.at[-1, :].set(p[-2, :]) # dp/dy = 0 at y = 2
69+
p = p.at[0, :].set(p[1, :]) # dp/dy = 0 at y = 0
70+
71+
return p, None
72+
73+
p, _ = lax.scan(body_func, p, jnp.arange(nit))
74+
75+
76+
@partial(jax.jit, static_argnums=(0,7,8,9))
77+
def channel_flow(nit, u, v, dt, dx, dy, p, rho, nu, F):
78+
udiff = 1
79+
stepcount = 0
80+
81+
array_vals = (udiff, stepcount, u, v, p)
82+
83+
def conf_func(array_vals):
84+
udiff, _, _, _ , _ = array_vals
85+
return udiff > .001
86+
87+
def body_func(array_vals):
88+
_, stepcount, u, v, p = array_vals
89+
90+
un = u.copy()
91+
vn = v.copy()
92+
93+
b = build_up_b(rho, dt, dx, dy, u, v)
94+
pressure_poisson_periodic(nit, p, dx, dy, b)
95+
96+
u = u.at[1:-1,
97+
1:-1].set(un[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx *
98+
(un[1:-1, 1:-1] - un[1:-1, 0:-2]) -
99+
vn[1:-1, 1:-1] * dt / dy *
100+
(un[1:-1, 1:-1] - un[0:-2, 1:-1]) - dt / (2 * rho * dx) *
101+
(p[1:-1, 2:] - p[1:-1, 0:-2]) + nu *
102+
(dt / dx**2 *
103+
(un[1:-1, 2:] - 2 * un[1:-1, 1:-1] + un[1:-1, 0:-2]) +
104+
dt / dy**2 *
105+
(un[2:, 1:-1] - 2 * un[1:-1, 1:-1] + un[0:-2, 1:-1])) +
106+
F * dt)
107+
108+
v = v.at[1:-1,
109+
1:-1].set(vn[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx *
110+
(vn[1:-1, 1:-1] - vn[1:-1, 0:-2]) -
111+
vn[1:-1, 1:-1] * dt / dy *
112+
(vn[1:-1, 1:-1] - vn[0:-2, 1:-1]) - dt / (2 * rho * dy) *
113+
(p[2:, 1:-1] - p[0:-2, 1:-1]) + nu *
114+
(dt / dx**2 *
115+
(vn[1:-1, 2:] - 2 * vn[1:-1, 1:-1] + vn[1:-1, 0:-2]) +
116+
dt / dy**2 *
117+
(vn[2:, 1:-1] - 2 * vn[1:-1, 1:-1] + vn[0:-2, 1:-1])))
118+
119+
# Periodic BC u @ x = 2
120+
u = u.at[1:-1, -1].set(
121+
un[1:-1, -1] - un[1:-1, -1] * dt / dx *
122+
(un[1:-1, -1] - un[1:-1, -2]) - vn[1:-1, -1] * dt / dy *
123+
(un[1:-1, -1] - un[0:-2, -1]) - dt / (2 * rho * dx) *
124+
(p[1:-1, 0] - p[1:-1, -2]) + nu *
125+
(dt / dx**2 *
126+
(un[1:-1, 0] - 2 * un[1:-1, -1] + un[1:-1, -2]) + dt / dy**2 *
127+
(un[2:, -1] - 2 * un[1:-1, -1] + un[0:-2, -1])) + F * dt)
128+
129+
# Periodic BC u @ x = 0
130+
u = u.at[1:-1,
131+
0].set(un[1:-1, 0] - un[1:-1, 0] * dt / dx *
132+
(un[1:-1, 0] - un[1:-1, -1]) - vn[1:-1, 0] * dt / dy *
133+
(un[1:-1, 0] - un[0:-2, 0]) - dt / (2 * rho * dx) *
134+
(p[1:-1, 1] - p[1:-1, -1]) + nu *
135+
(dt / dx**2 *
136+
(un[1:-1, 1] - 2 * un[1:-1, 0] + un[1:-1, -1]) + dt / dy**2 *
137+
(un[2:, 0] - 2 * un[1:-1, 0] + un[0:-2, 0])) + F * dt)
138+
139+
# Periodic BC v @ x = 2
140+
v = v.at[1:-1, -1].set(
141+
vn[1:-1, -1] - un[1:-1, -1] * dt / dx *
142+
(vn[1:-1, -1] - vn[1:-1, -2]) - vn[1:-1, -1] * dt / dy *
143+
(vn[1:-1, -1] - vn[0:-2, -1]) - dt / (2 * rho * dy) *
144+
(p[2:, -1] - p[0:-2, -1]) + nu *
145+
(dt / dx**2 *
146+
(vn[1:-1, 0] - 2 * vn[1:-1, -1] + vn[1:-1, -2]) + dt / dy**2 *
147+
(vn[2:, -1] - 2 * vn[1:-1, -1] + vn[0:-2, -1])))
148+
149+
# Periodic BC v @ x = 0
150+
v = v.at[1:-1,
151+
0].set(vn[1:-1, 0] - un[1:-1, 0] * dt / dx *
152+
(vn[1:-1, 0] - vn[1:-1, -1]) - vn[1:-1, 0] * dt / dy *
153+
(vn[1:-1, 0] - vn[0:-2, 0]) - dt / (2 * rho * dy) *
154+
(p[2:, 0] - p[0:-2, 0]) + nu *
155+
(dt / dx**2 *
156+
(vn[1:-1, 1] - 2 * vn[1:-1, 0] + vn[1:-1, -1]) + dt / dy**2 *
157+
(vn[2:, 0] - 2 * vn[1:-1, 0] + vn[0:-2, 0])))
158+
159+
# Wall BC: u,v = 0 @ y = 0,2
160+
u = u.at[0, :].set(0)
161+
u = u.at[-1, :].set(0)
162+
v = v.at[0, :].set(0)
163+
v = v.at[-1, :].set(0)
164+
165+
udiff = (jnp.sum(u) - jnp.sum(un)) / jnp.sum(u)
166+
stepcount += 1
167+
168+
return (udiff, stepcount, u, v, p)
169+
170+
_, stepcount, _, _, _ = lax.while_loop(conf_func, body_func, array_vals)
171+
172+
return stepcount
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# https://cython.readthedocs.io/en/latest/src/userguide/numpy_tutorial.html
2+
3+
import jax.numpy as jnp
4+
import jax
5+
6+
@jax.jit
7+
def compute(array_1, array_2, a, b, c):
8+
return jnp.clip(array_1, 2, 10) * a + array_2 * b + c

0 commit comments

Comments
 (0)