docs: autodiff docs #1580

Merged (5 commits, Mar 5, 2025)
182 changes: 182 additions & 0 deletions nx/guides/advanced/automatic_differentiation.livemd
@@ -0,0 +1,182 @@
# Automatic Differentiation

```elixir
Mix.install([
{:nx, "~> 0.7"}
])
```

## What is Function Differentiation?
Collaborator: nitpick: let's make sure we are consistent with titles and whether we use "What is function differentiation?" or "What is Function Differentiation?". FWIW, in the Elixir repo itself we use the former, but we can follow another convention here.


Nx, through the `Nx.Defn.grad/2` and `Nx.Defn.value_and_grad/3` functions, allows the user to differentiate functions that were defined through `defn`.
Collaborator: For what it's worth, we don't explain in this section what function differentiation is. We do talk about differentiation, gradient, and derivative, but we don't explain any of the terms. It may be valuable to take a step back and provide a high-level overview, and then either reduce the number of terms or explicitly explain them.

Contributor (author): I added 2 paragraphs explaining derivatives and gradients.


This is really important in Machine Learning settings because, in general, the training process happens through optimization methods that require calculating the gradient of tensor functions.

For those more familiar with mathematical terminology, the gradient of a tensor function is analogous to the derivative of a regular (scalar) function.

Let's take, for example, the following scalar function $f(x)$ and its derivative $f'(x)$:

$$
f(x) = x^3 + x\\
f'(x) = 3x^2 + 1
$$

We can define a similar function-derivative pair for tensor functions:

$$
f(\bold{x}) = \bold{x}^3 + \bold{x}\\
\nabla f(\bold{x}) = 3 \bold{x} ^ 2 + 1
$$

These may look similar, but the difference is that $f(\bold{x})$ takes in $\bold{x}$, which is a tensor argument. This means that we can have the following argument and results for the function and its gradient:

$$
\bold{x} =
\begin{bmatrix}
1 & 1 \\
2 & 3 \\
5 & 8 \\
\end{bmatrix}
$$

$$
f(\bold{x}) = \bold{x}^3 + \bold{x} =
\begin{bmatrix}
2 & 2 \\
10 & 30 \\
130 & 520
\end{bmatrix}
$$

$$
\nabla f(\bold{x}) = 3 \bold{x} ^ 2 + 1 =
\begin{bmatrix}
4 & 4 \\
13 & 28 \\
76 & 193
\end{bmatrix}
$$

## Automatic Differentiation

Now that we have a general sense of what a function and its gradient are, we can talk about how Nx can use `defn` to calculate gradients for us.

In the following code blocks we're going to define the same tensor function as above and then differentiate it using only Nx, without writing the explicit derivative at all.

```elixir
defmodule Math do
import Nx.Defn

defn f(x) do
x ** 3 + x
end

defn grad_f(x) do
Nx.Defn.grad(x, &f/1)
end
end
```

```elixir
x =
Nx.tensor([
[1, 1],
[2, 3],
[5, 8]
])

{
Math.f(x),
Math.grad_f(x)
}
```

As we can see, we get the results we expected, aside from the type of the grad, which will always be a floating-point number, even if you pass an integer tensor as input.
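To see this type promotion, we can compare the input and grad types with `Nx.type/1` (a quick check reusing `x` and `Math.grad_f/1` from the cells above):

```elixir
# x was built from integer literals, so it is a signed-integer tensor;
# the grad, however, comes back as a floating-point tensor
{Nx.type(x), Nx.type(Math.grad_f(x))}
```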

Next, we'll use `Nx.Defn.debug_expr` to see what's happening under the hood.

```elixir
Nx.Defn.debug_expr(&Math.f/1).(x)
```

```elixir
Nx.Defn.debug_expr(&Math.grad_f/1).(x)
```

If we look closely at the returned `Nx.Defn.Expr` representations for `f` and `grad_f`, we can see that they pretty much translate to the mathematical definitions we had originally.

This is possible because Nx holds onto the symbolic representation of a `defn` function while inside `defn`-land, and thus `Nx.Defn.grad` (and similar) can operate on that symbolic representation to return a new symbolic representation (as seen in the second block).

<!-- livebook:{"break_markdown":true} -->

`Nx.Defn.value_and_grad` can be used to calculate both the function value and its gradient at once:

```elixir
Nx.Defn.value_and_grad(x, &Math.f/1)
```

And if we use `debug_expr` again, we can see that the symbolic representation is actually both the function and the grad, returned in a tuple:

```elixir
Nx.Defn.debug_expr(Nx.Defn.value_and_grad(&Math.f/1)).(x)
```

Finally, we can talk about functions that receive many arguments, such as the following `add_multiply` function:

```elixir
add_multiply = fn x, y, z ->
addition = Nx.add(x, y)
Nx.multiply(z, addition)
end
```

At first, you may think that to differentiate it we need to wrap it in a single-argument function, so that we can differentiate with respect to a specific argument while treating the other arguments as constants, as we can see below:

```elixir
x = Nx.tensor([1, 2])
y = Nx.tensor([3, 4])
z = Nx.tensor([5, 6])

{
Nx.Defn.grad(x, fn t -> add_multiply.(t, y, z) end),
Nx.Defn.grad(y, fn t -> add_multiply.(x, t, z) end),
Nx.Defn.grad(z, fn t -> add_multiply.(x, y, t) end)
}
```

However, Nx is smart enough to deal with functions of multiple arguments through `Nx.Container` representations, such as a tuple or a map:

```elixir
Nx.Defn.grad({x, y, z}, fn {x, y, z} -> add_multiply.(x, y, z) end)
```
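The same idea applies to a map (or any other `Nx.Container`). Here is a small sketch using illustrative keys of our own choosing; by analogy with the tuple example above, the grads come back in a container mirroring the input:

```elixir
# sketch: differentiating with respect to a map of inputs; :x, :y and :z
# are arbitrary key names, and the result mirrors the input container
Nx.Defn.grad(%{x: x, y: y, z: z}, fn %{x: x, y: y, z: z} ->
  add_multiply.(x, y, z)
end)
```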

Likewise, we can also deal with functions that return multiple values.

`Nx.Defn.grad` requires the function to return a scalar (that is, a tensor of shape `{}`).
However, there are instances where we might want to use `value_and_grad` to get a tuple out of our function, while still calculating its gradient.

For this, we have the `value_and_grad/3` arity, which accepts a transformation argument.

```elixir
x =
Nx.tensor([
[1, 1],
[2, 3],
[5, 8]
])

# Notice that the returned values are the 2 addition terms from `Math.f/1`
multi_valued_return_fn =
fn x ->
{Nx.pow(x, 3), x}
end

transform_fn = fn {x_cubed, x} -> Nx.add(x_cubed, x) end

{{x_cubed, x}, grad} = Nx.Defn.value_and_grad(x, multi_valued_return_fn, transform_fn)
```

If we go back to the start of this livebook, we can see that `grad` holds exactly the result of `Math.grad_f`, but now we also have access to `x ** 3`, which wasn't accessible before, as originally we could only obtain `x ** 3 + x`.
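As a quick sanity check (a sketch that reuses the `x`, `x_cubed`, and `grad` bindings from the previous cell and the `Math` module from earlier), we can verify that `grad` matches `Math.grad_f/1` and that adding the two returned terms reproduces `Math.f/1`:

```elixir
# both comparisons should yield a scalar tensor equal to 1 (true)
{
  Nx.all_close(grad, Math.grad_f(x)),
  Nx.all_close(Nx.add(x_cubed, x), Math.f(x))
}
```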

1 change: 1 addition & 0 deletions nx/mix.exs
@@ -60,6 +60,7 @@ defmodule Nx.MixProject do
"guides/intro-to-nx.livemd",
"guides/advanced/vectorization.livemd",
"guides/advanced/aggregation.livemd",
"guides/advanced/automatic_differentiation.livemd",
"guides/exercises/exercises-1-20.livemd"
],
skip_undefined_reference_warnings_on: ["CHANGELOG.md"],