diff --git a/content/notebooks/jax_intro/jax_intro.ipynb b/content/notebooks/jax_intro/jax_intro.ipynb new file mode 100644 index 0000000..3e19b2e --- /dev/null +++ b/content/notebooks/jax_intro/jax_intro.ipynb @@ -0,0 +1,905 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "99cff3e2-951c-4ef4-8a4a-5debabef726f", + "metadata": {}, + "source": [ + "# JAX on TIKE\n", + "When we're running computationally expensive code on the cloud, there are three considerations we might have:\n", + "\n", + "1. We want to *optimize* our code\n", + "2. We want to *parallelize* our code across multiple cores\n", + "3. We want access to *gradients* of our functions to accelerate our algorithms\n", + "\n", + "[JAX](https://docs.jax.dev/en/latest/automatic-differentiation.html) is a Python library that helps with all three goals. In this and the following notebooks, we will briefly introduce ``JAX`` and how to effectively use it on TIKE." + ] + }, + { + "cell_type": "markdown", + "id": "ffbbc28c-0d7e-4f00-96e5-57a82909adc9", + "metadata": {}, + "source": [ + "First, let's install ``JAX`` in this kernel." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "a78fdf1e-aa10-48a1-af26-570fa3389a71", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: jax in /opt/conda/envs/tess/lib/python3.11/site-packages (0.7.2)\n", + "Requirement already satisfied: jaxlib<=0.7.2,>=0.7.2 in /opt/conda/envs/tess/lib/python3.11/site-packages (from jax) (0.7.2)\n", + "Requirement already satisfied: ml_dtypes>=0.5.0 in /opt/conda/envs/tess/lib/python3.11/site-packages (from jax) (0.5.3)\n", + "Collecting numpy>=2.0 (from jax)\n", + " Using cached numpy-2.3.3-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (62 kB)\n", + "Requirement already satisfied: opt_einsum in /opt/conda/envs/tess/lib/python3.11/site-packages (from jax) (3.3.0)\n", + "Requirement already satisfied: scipy>=1.13 in /opt/conda/envs/tess/lib/python3.11/site-packages (from jax) (1.13.1)\n", + " Using cached numpy-2.2.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (62 kB)\n", + "Using cached numpy-2.2.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.8 MB)\n", + "Installing collected packages: numpy\n", + " Attempting uninstall: numpy\n", + " Found existing installation: numpy 1.26.4\n", + " Uninstalling numpy-1.26.4:\n", + " Successfully uninstalled numpy-1.26.4\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "astropy 6.0.0 requires numpy<2,>=1.22, but you have numpy 2.2.6 which is incompatible.\n", + "lcviz 1.0.0 requires numpy<2, but you have numpy 2.2.6 which is incompatible.\n", + "lkprf 1.1.1 requires numpy<2.0.0,>=1.20.0, but you have numpy 2.2.6 which is incompatible.\n", + "lksearch 1.1.0 requires numpy<2.0.0,>=1.26.4, but you have numpy 2.2.6 which is incompatible.\n", + "numba 0.59.1 requires numpy<1.27,>=1.22, but you have numpy 2.2.6 which is incompatible.\n", + "pytensor 2.22.1 requires numpy<2,>=1.17.0, but you have numpy 2.2.6 which is incompatible.\n", + "tensorflow 2.16.1 requires ml-dtypes~=0.3.1, but you have ml-dtypes 0.5.3 which is incompatible.\n", + "tensorflow 2.16.1 requires numpy<2.0.0,>=1.23.5; python_version <= \"3.11\", but you have numpy 2.2.6 which is incompatible.\n", + "tesscube 1.0.5 requires numpy<2.0.0,>=1.26.1, but you have numpy 2.2.6 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0mSuccessfully installed numpy-2.2.6\n" + ] + } + ], + "source": [ + "!pip install -U jax" + ] + }, + { + "cell_type": "markdown", + "id": "bea67ba1-870b-4afb-ba49-bca63f33d1ae", + "metadata": {}, + "source": [ + "Next, let's import the package. We'll also import `numpy` and the `time` module for comparison." + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "1942458c-1710-4e96-a377-045830d6408e", + "metadata": {}, + "outputs": [], + "source": [ + "import jax.numpy as jnp\n", + "import numpy as np\n", + "import time\n", + "\n", + "from jax import grad" + ] + }, + { + "cell_type": "markdown", + "id": "15e36da4-bb01-4750-a248-cc3feb2f957a", + "metadata": {}, + "source": [ + "## Arrays in `JAX`" + ] + }, + { + "cell_type": "markdown", + "id": "650161a9-cdf2-458f-aadb-d656b8687635", + "metadata": {}, + "source": [ + "Let's start by reviewing some array operations in `numpy`. We can create arrays of values like so:" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "2bcade3a-c4fb-449d-9c4a-a87ee440c6ca", + "metadata": {}, + "outputs": [], + "source": [ + "array1 = np.array([1,2,3])\n", + "array2 = np.array([2,4,5])" + ] + }, + { + "cell_type": "markdown", + "id": "d91e6c2a-6654-4fd4-9d4b-f630206d9db1", + "metadata": {}, + "source": [ + "These arrays obey some standard operations, such as addition..." + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "bec1443c-3e9b-4e42-ad08-4864303d2880", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([3, 6, 8])" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "array3 = array1 + array2\n", + "array3" + ] + }, + { + "cell_type": "markdown", + "id": "ddc68d22-2f7f-4ea3-b2a7-099756360ffc", + "metadata": {}, + "source": [ + "...multiplication..." + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "acbc3df4-ad23-4939-b5e4-ba5866e3a3d0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 2, 8, 15])" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "array4 = array1 * array2\n", + "array4" + ] + }, + { + "cell_type": "markdown", + "id": "10cbf916-42bd-4e09-8cd4-278d02e089d7", + "metadata": {}, + "source": [ + "...and indexing." + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "78605e4c-031c-4e13-8809-3699afa156c3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "np.int64(1)" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "val1 = array1[0]\n", + "val1" + ] + }, + { + "cell_type": "markdown", + "id": "fa088c27-a3d1-4ca4-98fa-97967b23ed4c", + "metadata": {}, + "source": [ + "We can also iterate through arrays to sequentially access the values within them:" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "276952d7-b53f-4e88-9b4d-1e93de0d82e7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1\n", + "2\n", + "3\n" + ] + } + ], + "source": [ + "for value in array1:\n", + " print(value)" + ] + }, + { + "cell_type": "markdown", + "id": "999762c8-c5a5-48ca-9dcd-e34f1bbf4e50", + "metadata": {}, + "source": [ + "`JAX` provides objects that are very similar to `numpy` arrays. This is by design, so that there is maximal interoperability between `numpy`-based code and `JAX`-based code. \n", + "\n", + "Let's create some `JAX` arrays!" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "e96b9329-0d2f-4040-8cf2-7aaddbc3154d", + "metadata": {}, + "outputs": [], + "source": [ + "jax_array1 = jnp.array([1,2,3])\n", + "jax_array2 = jnp.array([2,4,5])" + ] + }, + { + "cell_type": "markdown", + "id": "8ca1d0ff-1f5a-4ce8-b802-89d663d6f9b0", + "metadata": {}, + "source": [ + "You can perform many of the same operations on `JAX` arrays as `numpy` arrays:" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "0146af77-0ae6-43c5-9258-480358c6fa70", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([3, 6, 8], dtype=int32)" + ] + }, + "execution_count": 48, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jax_array3 = jax_array1 + jax_array2\n", + "jax_array3" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "b061d7db-fcfe-4603-84cc-aed6bf37f5ac", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([ 2, 8, 15], dtype=int32)" + ] + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jax_array4 = jax_array1 * jax_array2\n", + "jax_array4" + ] + }, + { + "cell_type": "markdown", + "id": "44504aa7-a9f6-4518-b9a9-86a5ece6748f", + "metadata": {}, + "source": [ + "There are some important caveats, though! One of the biggest logical shifts between writing `JAX` code and `numpy` code is that `JAX` arrays are immutable. Let's see what that looks like." + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "edd80119-c3de-4f0d-9927-e109fe4f9a46", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([100, 2, 3])" + ] + }, + "execution_count": 50, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "array1[0] = 100 # changing a numpy array value, no problem\n", + "array1" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "83874f3b-f9ef-4fed-83cb-5943ccb758a7", + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[51], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mjax_array1\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m100\u001b[39m \u001b[38;5;66;03m# changing a JAX array value, and...\u001b[39;00m\n\u001b[1;32m 2\u001b[0m jax_array1\n", + "File \u001b[0;32m/opt/conda/envs/tess/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:617\u001b[0m, in \u001b[0;36m_unimplemented_setitem\u001b[0;34m(self, i, x)\u001b[0m\n\u001b[1;32m 613\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_unimplemented_setitem\u001b[39m(\u001b[38;5;28mself\u001b[39m, i, x):\n\u001b[1;32m 614\u001b[0m msg \u001b[38;5;241m=\u001b[39m (\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mJAX arrays are immutable and do not support in-place item assignment.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 615\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method:\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 616\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 617\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(msg\u001b[38;5;241m.\u001b[39mformat(\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)))\n", + "\u001b[0;31mTypeError\u001b[0m: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html" + ] + } + ], + "source": [ + "jax_array1[0] = 100 # changing a JAX array value, and...\n", + "jax_array1" + ] + }, + { + "cell_type": "markdown", + "id": "d8325d53-3df8-4b96-b305-8f8a10637ad1", + "metadata": {}, + "source": [ + "This seems quite inconvenient! As the error above shows, we have to use more verbose syntax to change the values inside a `JAX` array, like so:" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "4762fa26-4101-4b95-867e-0e81e7123ed1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([100, 2, 3], dtype=int32)" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jax_array1 = jax_array1.at[0].set(100) # changing a JAX array value, and...\n", + "jax_array1" + ] + }, + { + "cell_type": "markdown", + "id": "dadee3ad-8140-4c8a-8497-138ed85ae027", + "metadata": {}, + "source": [ + "As it turns out, we aren't actually changing that same array even when we use that syntax — we're creating a modified copy of the original array.\n", + "\n", + "Why would we accept this seemingly more restricted version of `numpy`? As we'll see, design choices such as this one allow `JAX` to be more efficient down the road." + ] + }, + { + "cell_type": "markdown", + "id": "b222d8fe-a258-445d-9f02-5d65e290ebdd", + "metadata": {}, + "source": [ + "## Functions, `JAX`, and JIT\n", + "One of the easiest ways to see the benefits of `JAX` is by speeding up a single function. Let's begin with a simple function implemented in `numpy`:" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "a438fd76-39dd-43ef-ba36-86883ddbe928", + "metadata": {}, + "outputs": [], + "source": [ + "def numpy_function(x):\n", + " \"\"\"\n", + " A simple numpy-based function that invokes some trigonometric functions.\n", + "\n", + " Inputs\n", + " ------\n", + " :x: (array) input array, of floats.\n", + "\n", + " Outputs\n", + " -------\n", + " :result: (array) result of applying trigonometric functions. Should be equal to 1.\n", + " \"\"\"\n", + " return np.sin(x) ** 2 + np.cos(x) ** 2" + ] + }, + { + "cell_type": "markdown", + "id": "92081d85-6f4e-4fee-b733-a551273786d7", + "metadata": {}, + "source": [ + "Let's create the same function in `JAX`. Again, the only change we'll have to make in our syntax is swapping `np` for `jnp`!" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "ce6715ec-5bad-4078-85dd-6326162bc30e", + "metadata": {}, + "outputs": [], + "source": [ + "def jax_function(x):\n", + " \"\"\"\n", + " A simple JAX-based function that invokes some trigonometric functions.\n", + "\n", + " Inputs\n", + " ------\n", + " :x: (array) input array, of floats.\n", + "\n", + " Outputs\n", + " -------\n", + " :result: (array) result of applying trigonometric functions. Should be equal to 1.\n", + " \"\"\"\n", + " return jnp.sin(x) ** 2 + jnp.cos(x) ** 2" + ] + }, + { + "cell_type": "markdown", + "id": "999c1a4b-b9e7-4e1a-9fbe-ff114c5b15cf", + "metadata": {}, + "source": [ + "Now, let's define large input arrays. `JAX` tends to perform comparatively better over a large number of calculations. We'll create this in both `JAX` and `numpy` arrays." + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "78b0501e-48d4-4163-8e9b-fe8397163d6a", + "metadata": {}, + "outputs": [], + "source": [ + "x_np = np.linspace(0, 100, 10_000_000)\n", + "x_jax = jnp.linspace(0, 100, 10_000_000)" + ] + }, + { + "cell_type": "markdown", + "id": "34dd22d1-51f4-4c41-b776-c5b6ab4282a1", + "metadata": {}, + "source": [ + "Let's see how long it takes to execute the `numpy` function. We can use the ``%%timeit`` magic to time this code over multiple evaluations. Doing so gives us a representative value for how fast the code is." + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "id": "2435da28-6069-4384-8485-f66f520082b2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "311 ms ± 25.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "%%timeit\n", + "y_np = numpy_function(x_np)\n", + "t1 = time.time()" + ] + }, + { + "cell_type": "markdown", + "id": "567678f8-b86b-44a8-a3e5-f61cb850a63d", + "metadata": {}, + "source": [ + "Neat! And now for the `JAX` function ..." + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "id": "2b6a3a89-877d-4df6-b990-abf194d4d11f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "130 ms ± 14.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%%timeit\n", + "y_jax = jax_function(x_jax)" + ] + }, + { + "cell_type": "markdown", + "id": "c1b903ca-2abb-49ae-9ecb-b74cbed23315", + "metadata": {}, + "source": [ + "Wow! This is already much faster (between 2x and 3x speedup).\n", + "\n", + "We can further accelerate things by leveraging `JAX`'s \"just-in-time\", or JIT, compilation. By using the [OpenXLA backend](https://openxla.org/xla), `JAX` will read through your Python function on the first JIT execution, keeping track of Python operations as it goes. It then \"fuses\" the operations together into more efficient, vectorized machine code. This first pass-through can be a bit slower than `numpy`, because the program has to keep track of the speed gains it'd like to make. On subsequent runs, though, the vectorized function is remembered (\"cached\"), so the compilation cost is bypassed!\n", + "\n", + "Let's take a look at this in practice." + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "id": "525d7b22-3d99-4b43-b349-8ac40b15b315", + "metadata": {}, + "outputs": [], + "source": [ + "f_jax_jit = jax.jit(jax_function)" + ] + }, + { + "cell_type": "markdown", + "id": "023cfbea-91ea-4a66-8fad-41116630df9d", + "metadata": {}, + "source": [ + "With the \"jitted\" function now available, let's time it. First, we'll use the `%%time` magic, which only times a cell *once*. This way, we can see how long the first pass (compilation) specifically takes.\n", + "\n", + "When performing timing experiments with `JAX`, we'll want to use the `block_until_ready` method. The reason for this is that JAX has asynchronous dispatch — i.e., Python code can continue past a JAX computation before the JAX computation finishes. This behavior is useful for accelerating programs, but it's not as useful for benchmarking. `block_until_ready` ensures that a program doesn't continue until the computation has actually completed." + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "id": "d42c0583-6d14-4827-8ca5-0f15ae5e5033", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 95 ms, sys: 18.2 ms, total: 113 ms\n", + "Wall time: 59.8 ms\n" + ] + } + ], + "source": [ + "%%time\n", + "# First JAX JIT run (includes compilation cost)\n", + "y_jax = f_jax_jit(x_jax).block_until_ready() # .block_until_ready ensures timing is correct" + ] + }, + { + "cell_type": "markdown", + "id": "b1c33d72-72c7-4ac0-a999-cd734c1446e8", + "metadata": {}, + "source": [ + "Now that the function is compiled, it should run more quickly. Let's find out!" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "id": "dcf6453a-c846-444f-b146-3f5efa249609", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "61.9 ms ± 1.68 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%%timeit\n", + "# Subsequent JAX JIT run (cached, fast!)\n", + "y_jax = f_jax_jit(x_jax).block_until_ready()" + ] + }, + { + "cell_type": "markdown", + "id": "803baf41-6c6b-40ac-a208-8bb209d3a55d", + "metadata": {}, + "source": [ + "Note how the second JIT run was indeed faster than the first. This is now faster than both our un-jitted `JAX` code and the original `numpy` code (by a factor of 5!).\n", + "\n", + "Note that, in line with the JAX coding philosophy, the speed increases from jitting a function come with extra constraints. For example, control flow (such is branching `if` statements) can be tricky in jitted functions. The reason for this is that when compiling, `JAX` uses the first function execution to trace through the stack. But each code execution will only move through one of the `if` blocks if the `if` block is conditional on the input, so `jax.jit` won't be aware of what the other code block is doing! For more information on this topic, see https://docs.jax.dev/en/latest/control-flow.html#control-flow. " + ] + }, + { + "cell_type": "markdown", + "id": "fa32deef-8820-42f4-a6ff-94802eede1aa", + "metadata": {}, + "source": [ + "## Autodifferentiation in `JAX`\n", + "Because `JAX` has the capability to track operations within a Python function, it can also chain-rule those operations together to calculate the derivative of a function. This strategy is known as automatic differentiation, or \"autodiff.\" We can see this play out quite neatly here:" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "id": "5b5511f8-55c2-49aa-a27e-1eba775ca555", + "metadata": {}, + "outputs": [], + "source": [ + "grad_func = grad(jax_function)" + ] + }, + { + "cell_type": "markdown", + "id": "ebf4aaf6-9144-4ff5-b3b1-967667a3fcdf", + "metadata": {}, + "source": [ + "Note that this gradient is only defined for scalar-output functions. So, let's pass a scalar:" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "id": "b5b92470-246d-446f-8535-77528f4efbf6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array(0., dtype=float32, weak_type=True)" + ] + }, + "execution_count": 68, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grad_func(1.0)" + ] + }, + { + "cell_type": "markdown", + "id": "6de22447-4a6b-4c5f-b7a5-796df236285d", + "metadata": {}, + "source": [ + "We have a gradient of 0! Since our function returns 1 for every input value, within machine precision..." + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "id": "6e730dcf-37b3-4d53-b609-97952d54cdc1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([1. , 1. , 1. , ..., 1.0000001 , 0.99999994,\n", + " 1. ], dtype=float32)" + ] + }, + "execution_count": 70, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y_jax" + ] + }, + { + "cell_type": "markdown", + "id": "b4993604-b382-4bb5-8fde-0d614eef6b0e", + "metadata": {}, + "source": [ + "... the gradient should indeed be 0.\n", + "\n", + "This is a bit of a trivial example. What about one in which we know what the gradient should be from calculus? Let's try to implement this." + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "665d9c93-ee02-49b0-9f8e-42b8e65c5bd3", + "metadata": {}, + "outputs": [], + "source": [ + "def func(x):\n", + " return x**2" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "709a43c7-decf-428e-a2ad-0316f8ba6a19", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([ 0., 1., 4., 9., 16.], dtype=float32)" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "input_array = jnp.arange(0.0, 5.0)\n", + "func(input_array)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "048303ab-0e3d-40be-b6d1-bf3f03c024dd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([0., 1., 2., 3., 4.], dtype=float32)" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "input_array" + ] + }, + { + "cell_type": "markdown", + "id": "d9da4930-301e-48be-bff7-df91b30e71ad", + "metadata": {}, + "source": [ + "The gradient of $x^2$ is $2x$. Will our numerical gradient give us the same result?" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "17c0783b-3029-4c04-a8fd-3b5962ceda58", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([0., 2., 4., 6., 8.], dtype=float32)" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grad_func = grad(func)\n", + "jnp.array([grad_func(i) for i in input_array])" + ] + }, + { + "cell_type": "markdown", + "id": "6d0b4a60-43b6-4341-8281-37cbe2ab06c9", + "metadata": {}, + "source": [ + "Perfect! But why didn't pass the whole array to the `grad_func`? Let's see what happens when we do." + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "fc47723c-b325-4d08-bce8-32d65def405f", + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "Gradient only defined for scalar-output functions. Output had shape: (5,).", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[33], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mgrad_func\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_array\u001b[49m\u001b[43m)\u001b[49m\n", + " \u001b[0;31m[... skipping hidden 4 frame]\u001b[0m\n", + "File \u001b[0;32m/opt/conda/envs/tess/lib/python3.11/site-packages/jax/_src/api.py:501\u001b[0m, in \u001b[0;36m_check_scalar\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 499\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(aval, ShapedArray):\n\u001b[1;32m 500\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m aval\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m!=\u001b[39m ():\n\u001b[0;32m--> 501\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(msg(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhad shape: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00maval\u001b[38;5;241m.\u001b[39mshape\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m))\n\u001b[1;32m 502\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 503\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(msg(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhad abstract value \u001b[39m\u001b[38;5;132;01m{\u001b[39;00maval\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m))\n", + "\u001b[0;31mTypeError\u001b[0m: Gradient only defined for scalar-output functions. Output had shape: (5,)." + ] + } + ], + "source": [ + "grad_func(input_array)" + ] + }, + { + "cell_type": "markdown", + "id": "fc58f198-91ab-418d-ad1f-a310219bec23", + "metadata": {}, + "source": [ + "`JAX` gradients are only defined for scalars — the framework can't perform derivatives for vector output." + ] + }, + { + "cell_type": "markdown", + "id": "f4cd86e9-96bf-4805-904f-aee574852661", + "metadata": {}, + "source": [ + "You may be wondering why calculating the gradient is worthwhile. It turns out that this is a very helpful quantity to calculate in optimization problems (such as MCMC, or machine learning—for what JAX was developed) and physical systems (such as evolving the orbits of a binary star)." + ] + }, + { + "cell_type": "markdown", + "id": "967b64e8-7791-432b-bf45-658e24e7ce53", + "metadata": {}, + "source": [ + "Now we have a handle on the basics of JAX! In the next notebook, we'll explore how to scale this up to multiple cores on TIKE." + ] + }, + { + "cell_type": "markdown", + "id": "0deddd6d-9f6d-4f85-be47-f0326fd88162", + "metadata": {}, + "source": [ + "## Summary table: speedups\n", + "| Method | Execution Time (ms) | Speedup vs NumPy |\n", + "|---------------|-------------------------|------------------|\n", + "| NumPy | 311 ± 25.8 | 1× (baseline) |\n", + "| JAX | 130 ± 14.2 | 2.39× faster |\n", + "| JAX (jitted) | 61.9 ± 1.68 | 5.02× faster |\n", + "\n", + "Disclaimer: this notebook assumes you're running in a TIKE kernel with Python 3.11. Running locally will vary performance.\n" + ] + }, + { + "cell_type": "markdown", + "id": "7e312fe0-5c2d-42af-9d4f-f01b9fa1fa3b", + "metadata": {}, + "source": [ + "# Resources\n", + "- [JAX documentation](https://docs.jax.dev/en/latest/index.html)\n", + "- [common \"gotchas\" in JAX](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html)\n", + "- [Introduction to automatic differentiation](https://www.mathworks.com/help/optim/ug/autodiff-background.html)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "TESS Environment", + "language": "python", + "name": "tess" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/content/notebooks/webinar-series/05-statistics/statistics.ipynb b/content/notebooks/webinar-series/05-statistics/statistics.ipynb index 05c2223..85e0260 100644 --- a/content/notebooks/webinar-series/05-statistics/statistics.ipynb +++ b/content/notebooks/webinar-series/05-statistics/statistics.ipynb @@ -678,10 +678,7 @@ " random_seed=[10863087, 10863088],\n", " progressbar=False\n", " )\n", - "\n", - " approx_sample = trace.sample(1000)\n", " \n", - " \n", " period_samples = np.asarray(approx_sample.posterior[\"period\"]).flatten()\n", " np.save(f'period_samples{name}.npy', period_samples)\n", " return period_samples\n", @@ -722,7 +719,8 @@ "rot_rates_gp = []\n", "\n", "# iterate though stars\n", - "# for i, star in tqdm(enumerate(stars[::200]), total=len(stars[::200])):\n", + "rot_rates_gp = []\n", + "for i, star in tqdm(enumerate(stars[::200]), total=len(stars[::200])):\n", " \n", "# flare_rates_duration = flare_rates_durations[i]\n", "# # only fit the stars with \"good\" flare rates.\n", @@ -1480,7 +1478,7 @@ "source": [ "fig, axes = plt.subplots(n_dim, figsize=(10, 3), sharex=True)\n", "samples = sampler.get_chain()\n", - "labels = ['m', 'b', 'sigma']\n", + "labels = ['b', 'sigma']\n", "\n", "for i in range(n_dim):\n", " ax = axes[i]\n",