
Conversation

@Pheithar

I made a few modifications in gradients_reverse.py and boundary_conditions.py to enable out-of-the-box support for JAX, at least for the example deepxde/examples/pinn_forward/Poisson_Neumann_1d.py.

I have tested the changes on a few examples, and it seems to work without issues for both PyTorch and JAX, but the testing has not been exhaustive.
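
For reference, this is roughly how I switched between backends when testing (a minimal sketch, assuming DeepXDE picks up the DDE_BACKEND environment variable at import time; the example script itself is unchanged):

# Minimal sketch: select the backend before importing deepxde.
# Assumes DeepXDE reads the DDE_BACKEND environment variable at import time.
import os

os.environ["DDE_BACKEND"] = "jax"  # or "pytorch"

import deepxde as dde  # the chosen backend is fixed from this point on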

def error(self, X, inputs, outputs, beg, end, aux_var=None):
    if self.batch_size is not None:
        return self.func(inputs, outputs, X)[beg:end] - self.values[self.batch_indices]
    return (
Owner


Don't modify this.

raise NotImplementedError(
    "Reverse-mode autodiff doesn't support 3D output"
)
raise NotImplementedError("Reverse-mode autodiff doesn't support 3D output")
Owner


Don't modify this.

dydx = grad.jacobian(outputs, inputs, i=self.component, j=None)[beg:end]
if backend_name == "jax":
    dydx = grad.jacobian(
        (outputs, self.func), inputs, i=self.component, j=None
Owner


Why self.func? Here we should compute the derivative of the network outputs w.r.t. the network inputs. self.func is the function in the BC.

Author


Inside gradients, the Jacobian calculation needs a 2-element tuple, where the first element is the same as in the other backends (what we want to compute the gradient of), but the second element needs to be a function used to compute the gradient (the BC in this case). Previously, only the output was fed to grad.jacobian, and that was producing an error. It should have been an index-out-of-range error, in line 50 of deepxde/gradients/gradients_reverse.py and line 76 of deepxde/gradients/gradients_forward.py, but JAX does not throw an error in those instances.

Therefore, the error raised in those lines was a different one, which made it difficult to pinpoint, and my conclusion was that the function the gradient is computed from needs to be passed along to the Jacobian calculation.

elif backend_name == "jax":
    # One-hot tangent vector selecting input dimension j
    tangent = jax.numpy.zeros(self.dim_x).at[j].set(1)
    # self.ys[1] must be a callable; jvp gives the directional derivative along `tangent`
    grad_fn = lambda x: jax.jvp(self.ys[1], (x,), (tangent,))[1]
    self.J[j] = (jax.vmap(grad_fn)(self.xs), grad_fn)

This is what was causing the error. When evaluating self.ys[1], instead of an out-of-bounds error, it was throwing a different one. (I believe it was something along the lines of TypeError: 'jaxlib._jax.ArrayImpl' object is not callable.)
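
To make the silent failure concrete, here is a tiny sketch of the JAX behaviour as I understand it (not code from the repo): out-of-range indexing on a JAX array is clamped rather than raising, so the problem only surfaces later, when the array element is called as if it were a function.

import jax.numpy as jnp

ys = jnp.arange(3.0)   # stand-in for self.ys when it is a bare array instead of (array, fn)
print(ys[10])          # no IndexError: JAX clamps out-of-range indices and returns ys[2]
maybe_fn = ys[1]       # "succeeds" silently, but yields an array, not a callable
# maybe_fn(0.0)        # would raise: TypeError: ... object is not callable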

Owner


Yes, for JAX we need the function as an additional element. My question is why it is self.func? Should it be the function of the network forward pass?

Author


I don't think so. We want to compute the boundary conditions. That part of the code is called in the NeumannBC to compute the loss term, and if I am not mistaken, self.func is $f(X)$, and we want our loss term for Neumann boundaries to be $\mathcal{L}_{\text{Neumann}} = \Delta f(X) - Y$. (I hope I am explaining myself here.)

In short, what I mean is that we do not want to compute the gradients w.r.t. the network; what we want is a term for the loss, not the gradient of the loss to be used for backpropagation.

Owner


Neumann BC is dy/dn = f(x), where dy/dn is dy/dx in the n direction.
https://en.wikipedia.org/wiki/Neumann_boundary_condition
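
To spell it out with a toy example (a sketch of the definition only, not the DeepXDE implementation; the network and f below are made-up stand-ins): the derivative in the Neumann residual is taken of the network output y with respect to the inputs x, in the direction of the boundary normal n, and self.func only supplies the target f(x).

import jax
import jax.numpy as jnp

net = lambda x: jnp.sin(x)   # hypothetical stand-in for the network forward pass y(x)
f = lambda x: jnp.cos(x)     # hypothetical prescribed Neumann data f(x)

x = jnp.array([0.0])         # a boundary point
n = jnp.array([1.0])         # outward unit normal at that point

dy_dn = jax.jvp(net, (x,), (n,))[1]  # directional derivative dy/dn of the network output
residual = dy_dn - f(x)              # Neumann BC residual: dy/dn - f(x)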

