Thank you for sharing your work; it's very interesting! The new JAX version is indeed much faster, but I'm not very familiar with JAX (I mostly use PyTorch). Recently, while solving a PDE, I ran into the following problem:
$\mathrm{Loss}_1=\frac{\partial u}{\partial x}+\frac{\partial v}{\partial y}$
$\mathrm{Loss}_2=\frac{\partial}{\partial y}\left[ \left( v+\frac{v_t}{\sigma _k} \right) \frac{\partial k}{\partial y} \right] $
Here $\sigma_k$ is given, the inputs of the neural network are $x$ and $y$, and its outputs are $u$, $v$, and $k$.
When constructing the physics loss $\mathrm{Loss}_2$ for the equation above, $\frac{\partial k}{\partial y}$ is needed. The current FBPINN framework uses `required_ujs_phys` to request gradients, as in the following code skeleton:

    def sample_constraints(all_params, domain, key, sampler, batch_shapes):
        # physics loss
        y_batch_phys = domain.sample_interior(all_params, key, sampler, batch_shapes[0])
        required_ujs_phys = (
            (0,()),  # u
            (1,()),  # v
            (2,()),  # k
            (2,(1,)),  # k_y
        )
        return [[y_batch_phys, required_ujs_phys]]

This causes a problem: I can't compute $\frac{\partial}{\partial y}\left[ \left( v+\frac{v_t}{\sigma _k} \right) \frac{\partial k}{\partial y} \right]$, because it is the $y$-derivative of a product that itself contains the first-order derivative $\frac{\partial k}{\partial y}$. A composite term like this can't be requested directly through `required_ujs_phys`.
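One workaround I have considered, as a rough sketch rather than anything confirmed by the framework, is to expand the term with the product rule so that only derivatives of individual outputs are needed. Assuming $\sigma_k$ is a constant and $v_t$ is a known field whose $y$-derivative is available (zero if $v_t$ is constant), this gives

$\frac{\partial}{\partial y}\left[ \left( v+\frac{v_t}{\sigma _k} \right) \frac{\partial k}{\partial y} \right] = \left( \frac{\partial v}{\partial y}+\frac{1}{\sigma _k}\frac{\partial v_t}{\partial y} \right) \frac{\partial k}{\partial y} + \left( v+\frac{v_t}{\sigma _k} \right) \frac{\partial ^2 k}{\partial y^2}$

The helper `loss2_residual` below is my own hypothetical function, not part of FBPINNs. It only assembles the expanded expression from values that would have to be requested separately, and checks it against a finite difference on made-up analytic fields.

```python
import math

SIGMA_K = 1.3  # made-up constant, for this check only


def loss2_residual(v, v_y, k_y, k_yy, v_t=0.0, v_t_y=0.0, sigma_k=1.0):
    """Product-rule expansion of d/dy[(v + v_t/sigma_k) * k_y].

    Assumes sigma_k is constant. Only arithmetic is used, so the
    inputs may be floats or arrays of matching shape.
    """
    coeff = v + v_t / sigma_k        # (v + v_t/sigma_k)
    coeff_y = v_y + v_t_y / sigma_k  # its y-derivative
    return coeff_y * k_y + coeff * k_yy


# Made-up analytic fields in y, used only to sanity-check the expansion.
def v(y): return math.sin(y)
def v_t(y): return 0.5 * y ** 2
def k_y(y): return 3.0 * y ** 2  # derivative of k(y) = y**3


def flux(y):
    """The bracketed term (v + v_t/sigma_k) * k_y."""
    return (v(y) + v_t(y) / SIGMA_K) * k_y(y)


def fd_derivative(f, y, h=1e-5):
    """Central finite difference of f at y."""
    return (f(y + h) - f(y - h)) / (2.0 * h)


y0 = 0.7
expanded = loss2_residual(
    v=v(y0), v_y=math.cos(y0),
    k_y=k_y(y0), k_yy=6.0 * y0,
    v_t=v_t(y0), v_t_y=y0, sigma_k=SIGMA_K,
)
reference = fd_derivative(flux, y0)
print(abs(expanded - reference) < 1e-6)
```

If `required_ujs_phys` accepts repeated indices for higher-order derivatives (my assumption, e.g. `(1,(1,))` for $v_y$ and `(2,(1,1))` for $k_{yy}$), the expanded form could be evaluated inside the physics loss from those entries, but I am not sure this is the intended approach.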
This kind of composite gradient is quite common. Do you have any good suggestions to solve this problem?
Thank you for reading!