Thank you for sharing your work; it's very interesting! The new JAX version is indeed much faster, but I'm not very familiar with JAX (I mostly use PyTorch). Recently, while solving a PDE, I ran into the following problem.
When constructing the physics loss, I use `required_ujs_phys` to request the gradients I need, as in the following code skeleton:
```python
def sample_constraints(all_params, domain, key, sampler, batch_shapes):
    # physics loss
    y_batch_phys = domain.sample_interior(all_params, key, sampler, batch_shapes[0])
    required_ujs_phys = (
        (0,()),   # u
        (1,()),   # v
        (2,()),   # k
        (2,(1,)), # k_y
    )
    return [[y_batch_phys, required_ujs_phys]]
```
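For readers less familiar with this format: each entry in `required_ujs_phys` appears to be read as `(output index, tuple of input dimensions to differentiate over)`, so `(2,(1,))` is dk/dy and a repeated dimension such as `(0,(1,1))` would be a second derivative. The sketch below reproduces those quantities with plain nested `jax.grad`, assuming that reading of the convention. `net` and `eval_ujs` are hypothetical stand-ins for illustration, not FBPINNs functions.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the network (not FBPINNs code):
# input xy = (x, y), outputs (u, v, k).
def net(xy):
    x, y = xy[0], xy[1]
    u = jnp.sin(x) * y**2
    v = jnp.cos(y)
    k = 1.0 + x * y**2
    return jnp.array([u, v, k])

def eval_ujs(xy, required):
    """Evaluate (output_index, input_dims) tuples by nesting jax.grad."""
    values = []
    for i, dims in required:
        f = lambda z, i=i: net(z)[i]  # scalar function for output i
        for d in dims:
            # Replace f by its partial derivative w.r.t. input dimension d;
            # a repeated d, e.g. (1, 1), gives a higher-order derivative.
            f = lambda z, f=f, d=d: jax.grad(f)(z)[d]
        values.append(f(xy))
    return values

required_ujs_phys = ((0, ()), (1, ()), (2, ()), (2, (1,)))  # u, v, k, k_y
u, v, k, k_y = eval_ujs(jnp.array([0.5, 2.0]), required_ujs_phys)
# here k = 1 + x*y**2 = 3.0 and k_y = 2*x*y = 2.0
```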
This causes a problem: I can't then compute a gradient of the quantities requested in `required_ujs_phys`.
This kind of composite gradient is quite common. Do you have any suggestions for solving this problem?
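One common workaround when a framework only returns per-output derivatives is to expand the composite derivative with the product rule and request each factor separately. The sketch below assumes the composite of interest has the variable-coefficient form d(k * u_y)/dy; the issue does not say which expression is actually needed. That expansion, k_y * u_y + k * u_yy, needs u_y, u_yy, k and k_y, which under the `(output index, input dims)` reading would be the tuples `(0,(1,))`, `(0,(1,1))`, `(2,())` and `(2,(1,))`. The names `net`, `flux_y_direct` and `flux_y_from_ujs` are hypothetical; the code checks the expansion against `jax.grad` applied to the composite directly.

```python
import jax
import jax.numpy as jnp

# Hypothetical network stand-in (not FBPINNs code): xy = (x, y) -> (u, v, k).
def net(xy):
    x, y = xy[0], xy[1]
    return jnp.array([jnp.sin(x) * y**2, jnp.cos(y), 1.0 + x * y**2])

def flux_y_direct(xy):
    """Reference: differentiate the composite k * u_y w.r.t. y with jax.grad."""
    u_y = lambda z: jax.grad(lambda w: net(w)[0])(z)[1]
    flux = lambda z: net(z)[2] * u_y(z)
    return jax.grad(flux)(xy)[1]

def flux_y_from_ujs(u_y, u_yy, k, k_y):
    """Product rule using only separately requested derivatives:
    d(k * u_y)/dy = k_y * u_y + k * u_yy."""
    return k_y * u_y + k * u_yy

xy = jnp.array([0.5, 2.0])
out = net(xy)                  # out[i] = output i
jac = jax.jacfwd(net)(xy)      # jac[i, a] = d out_i / d xy_a
hess = jax.hessian(net)(xy)    # hess[i, a, b] = d^2 out_i / d xy_a d xy_b

a = flux_y_direct(xy)
b = flux_y_from_ujs(jac[0, 1], hess[0, 1, 1], out[2], jac[2, 1])
# for this net both equal 14 * sin(0.5)
```

Here `jax.jacfwd` and `jax.hessian` only supply reference values for the individual factors; inside a physics loss the factors would instead come from whatever the framework returns for the requested tuples, and the composite is then assembled from them.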
Thank you for reading!