Added support for JAX for Neumann and Robin BC #2015
base: master
```diff
@@ -53,7 +53,12 @@ def collocation_points(self, X):
         return self.filter(X)

     def normal_derivative(self, X, inputs, outputs, beg, end):
-        dydx = grad.jacobian(outputs, inputs, i=self.component, j=None)[beg:end]
+        if backend_name == "jax":
+            dydx = grad.jacobian(
+                (outputs, self.func), inputs, i=self.component, j=None
+            )[0][beg:end]
+        else:
+            dydx = grad.jacobian(outputs, inputs, i=self.component, j=None)[beg:end]
         n = self.boundary_normal(X, beg, end, None)
         return bkd.sum(dydx * n, 1, keepdims=True)
```
Owner: Why?
Author: Inside `gradients`, the JAX Jacobian calculation needs a pair of two elements: the first is the same as in the other backends (the tensor we want to differentiate), and the second needs to be a function used to compute that gradient (the BC function, in this case). Previously, only the output was fed to the Jacobian. Because of that, the error raised in those lines was different from the usual one, which made it difficult to pinpoint; my conclusion was that the function to compute the gradient against needed to be passed forward to the Jacobian calculation.

```python
elif backend_name == "jax":
    tangent = jax.numpy.zeros(self.dim_x).at[j].set(1)
    grad_fn = lambda x: jax.jvp(self.ys[1], (x,), (tangent,))[1]
    self.J[j] = (jax.vmap(grad_fn)(self.xs), grad_fn)
```

This is what was causing the error: when computing the JVP, `self.ys[1]` has to be a callable, but previously only the output tensor was available.
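A minimal, self-contained sketch of the pattern described above, with a made-up function `f`, inputs `xs`, and dimensions (this is not DeepXDE code): `jax.jvp` needs the callable itself, not just its output array, which is why the JAX branch carries a `(values, function)` pair.

```python
import jax
import jax.numpy as jnp

# Made-up stand-in for the network/BC function: one point x in R^2 -> two outputs.
def f(x):
    return jnp.stack([jnp.sin(x[0]) * x[1], x[0] ** 2])

xs = jnp.ones((5, 2))          # batch of 5 points in R^2
ys = (jax.vmap(f)(xs), f)      # (values, function) pair, as in the snippet above

dim_x, j = 2, 0                # differentiate w.r.t. input component j
tangent = jnp.zeros(dim_x).at[j].set(1.0)

# jvp requires the callable ys[1]; the array ys[0] alone is not enough.
grad_fn = lambda x: jax.jvp(ys[1], (x,), (tangent,))[1]
J_j = (jax.vmap(grad_fn)(xs), grad_fn)   # column j of the Jacobian, plus its function
```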
Owner: Yes, for JAX we need another element, the function. My question is why it is `self.func`. Shouldn't it be the network?
Author: I don't think so. We want to compute the boundary conditions. That part of the code is called when the losses are computed. In short, what I mean is that we do not want to compute the gradients w.r.t. the network; what we want is to compute a term for the loss, not a gradient of the loss that can be used for backpropagation.
Owner: Neumann BC is dy/dn = f(x), where dy/dn is dy/dx in the direction of the normal n.
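A small numeric sketch of that formula, with made-up gradients and unit normals: the normal derivative is the gradient projected onto the normal, which mirrors `bkd.sum(dydx * n, 1, keepdims=True)` in the diff above, and the Neumann residual dy/dn - f(x) is the term that enters the loss.

```python
import numpy as np

# Made-up example: gradient dy/dx at 3 boundary points in R^2, and their unit normals.
dydx = np.array([[0.5, 1.0], [2.0, 0.0], [1.0, 1.0]])
n = np.array([[1.0, 0.0], [0.0, 1.0], [0.7071, 0.7071]])

dydn = np.sum(dydx * n, axis=1, keepdims=True)   # dy/dn = grad(y) . n, per point

# Neumann BC dy/dn = f(x): the residual below is the term that goes into the loss.
f_x = np.array([[0.5], [0.0], [1.4142]])
residual = dydn - f_x
```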
```diff
@@ -282,7 +287,9 @@ def collocation_points(self, X):

     def error(self, X, inputs, outputs, beg, end, aux_var=None):
         if self.batch_size is not None:
-            return self.func(inputs, outputs, X)[beg:end] - self.values[self.batch_indices]
+            return (
+                self.func(inputs, outputs, X)[beg:end] - self.values[self.batch_indices]
+            )
         return self.func(inputs, outputs, X)[beg:end] - self.values
```

Owner: Don't modify this.
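For context, a usage sketch of the boundary conditions this PR targets, assuming DeepXDE's public `dde.geometry.Interval`, `dde.icbc.NeumannBC`, and `dde.icbc.RobinBC` APIs and the JAX backend selected via the `DDE_BACKEND=jax` environment variable before import; it only shows how the BCs would be declared, not a full training script.

```python
import numpy as np
import deepxde as dde

geom = dde.geometry.Interval(0, 1)

def on_right(x, on_boundary):
    # Select the right end of the interval as the boundary for both BCs.
    return on_boundary and np.isclose(x[0], 1.0)

# Neumann BC: dy/dn = 1 on the right boundary.
bc_neumann = dde.icbc.NeumannBC(geom, lambda x: np.ones((len(x), 1)), on_right)

# Robin BC: dy/dn = g(x, y); here g is a made-up linear function of y.
bc_robin = dde.icbc.RobinBC(geom, lambda x, y: -y, on_right)
```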