Skip to content

Jax intro notebook#67

Open
arjunsavel wants to merge 6 commits into
spacetelescope:mainfrom
arjunsavel:jax_intro
Open

Jax intro notebook#67
arjunsavel wants to merge 6 commits into
spacetelescope:mainfrom
arjunsavel:jax_intro

Conversation

@arjunsavel

Copy link
Copy Markdown
Contributor

I cover...

  • jax arrays
  • JIT compilation
  • autodifferentiation

Curious to hear thoughts on whether I move through these topics too quickly!

@arjunsavel

Copy link
Copy Markdown
Contributor Author

@snbianco first notebook here if you happen to be interested in taking a look!

@snbianco

snbianco commented Sep 3, 2025

Copy link
Copy Markdown
Contributor

Hey Arjun, sorry that I'm late in reviewing this! I took some vacation time over Labor Day. Overall, I think this is a strong notebook that's clear, to-the-point, and progresses logically. I also appreciate you being transparent about certain caveats.

Here are a few notes I had for review:

  • In the section that talks about Jax arrays being immutable, can we add an example of the proper syntax to use to actually change a Jax array? The error message is a little confusing.
  • Maybe add a disclaimer that the notebook assumes you're running in a TIKE kernel with Python 3.11, and running locally will vary performance.
  • Can you say more about the block_until_ready() function and what it's doing?
  • Typo in "compiled" right after the first JAX JIT run.
  • Typo in "that seems a bit odd" right after you call grad_func(1.0)
  • In the autodifferentiation section, maybe include another example of a function that has a non-trivial gradient.
  • Nitpick: Capitalize first letter in the "Resources" bullets.
  • At the end of the notebook, can you add some sort of summary that compares the 3 methods, the time for an individual run, and the relative speed compared to numpy?

@ttdu

ttdu commented Sep 17, 2025

Copy link
Copy Markdown
Collaborator

I pushed a few of the typo fixes to the latest commit. I'll also emphasize three of Sam's comments:

  1. adding a summary table of the speedup at the end would be informative
  2. we should say more about block_until_ready(), or link to information about it
  3. doing a non-trivial gradient function would be more instructive. x**2 could be a good example since that derivative is easy to calculate -- I'd be much more confident in my setup if I saw [0, 2, 4, 6, 8] than getting an all zeros array

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants