Separable-PINN in DeepXDE #1776
There are too many modifications. We can start with a mathematically equivalent (but slower) implementation by repeating the n inputs to n**2 points. This is similar to DeepONet (deepxde/deepxde/nn/tensorflow/deeponet.py, Line 17 in ad6399b) vs DeepONetCartesianProd (deepxde/deepxde/nn/tensorflow/deeponet.py, Line 153 in ad6399b).
Then the code change would be very minimal.
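To make the suggested equivalence concrete, here is a minimal sketch, not taken from the PR. It uses NumPy as a stand-in for JAX, and the tiny per-axis "body" networks, the rank `r`, and all function names are hypothetical. It compares a separable evaluation on n points per axis with a pointwise evaluation on the n**2 repeated inputs.

```python
import numpy as np

rng = np.random.default_rng(0)
r = 4  # hypothetical rank of the separable representation
Wx = rng.normal(size=(1, r))  # toy weights for the x-axis body network
Wy = rng.normal(size=(1, r))  # toy weights for the y-axis body network


def body_x(x):
    """Toy per-axis network: (n, 1) -> (n, r)."""
    return np.tanh(x @ Wx)


def body_y(y):
    """Toy per-axis network: (n, 1) -> (n, r)."""
    return np.tanh(y @ Wy)


def separable_forward(x, y):
    """Separable evaluation: n points per axis -> n**2 outputs."""
    # u[i, j] = sum_k body_x(x_i)[k] * body_y(y_j)[k]
    return np.einsum("ik,jk->ij", body_x(x), body_y(y)).reshape(-1, 1)


def repeat_to_grid(x, y):
    """Repeat the n per-axis inputs into n**2 (x, y) pairs."""
    X, Y = np.meshgrid(x.ravel(), y.ravel(), indexing="ij")
    return np.stack([X.ravel(), Y.ravel()], axis=1)


def pointwise_forward(xy):
    """Same function evaluated point by point: (m, 2) -> (m, 1)."""
    return np.sum(body_x(xy[:, :1]) * body_y(xy[:, 1:]), axis=1, keepdims=True)


n = 5
x = np.linspace(0.0, 1.0, n).reshape(-1, 1)
y = np.linspace(0.0, 1.0, n).reshape(-1, 1)

fast = separable_forward(x, y)                  # shape (n**2, 1)
slow = pointwise_forward(repeat_to_grid(x, y))  # shape (n**2, 1)
print(np.allclose(fast, slow))
```

The pointwise variant evaluates each body network n**2 times instead of n times, which is where the speed difference comes from, but it keeps the usual one-output-per-input convention.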
OK, I'll try that when I have more time, hopefully soon.
This is an implementation of the SPINN model: https://jwcho5576.github.io/spinn.github.io/
The code for the network architecture (snn.py) is directly adapted from the original paper (https://github.com/stnamjef/SPINN).
I've achieved really fast convergence with this implementation of SPINN compared to PINN (similar to the paper's claim), for both forward and inverse quantification on the linear elastic plate problem.
Forward comparison:
https://github.com/lululxvi/deepxde/assets/53513604/499a961d-748c-458f-be99-56156b516ace
Inverse with PINN:
https://github.com/lululxvi/deepxde/assets/53513604/ab89554d-b82b-406a-8d73-05b3f72a3961
Inverse with SPINN:
https://github.com/lululxvi/deepxde/assets/53513604/07171442-ea03-48b4-87a6-8b5094f6809c
The implementation was more complicated than expected for the following reasons:
Due to its architecture, SPINN takes an input of size n and outputs an array of size n**dim (it takes the Cartesian product of the coordinates):
(n,2) --> SPINN --> n**2
This complicates how inputs are handled in data.pde. There, all inputs (PDE and BC points) are concatenated and passed through the net simultaneously.
So if we have n_PDE PDE points and n_BC BC points, we end up with (n_PDE+n_BC)**2 points instead of n_PDE**2+n_BC**2.
I tried to find a workaround with minimal changes to model.py, and came up with the following:
Adding a list_handler decorator to the outputs function in JAX so that it can handle list inputs by applying the function to each input and then concatenating the results.
I then modified the pde.py file by adding an is_SPINN argument; if true, PDE and BC inputs are put together in a list instead of stacked. bcs_start also has to be modified, as the output sizes no longer equal the input sizes.
I understand that this brings a lot of changes to data.pde, so another possibility is to create a separate data subclass dedicated to SPINN so that the data.pde class isn't overly complicated.
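The workaround described above can be sketched roughly as follows. This is not the code from the PR: it uses NumPy instead of JAX, the `outputs` function is a toy stand-in for a separable network, and the names `offsets`, `pde_points`, and `bc_points` are hypothetical. Only the idea (apply the function to each list element, then concatenate) comes from the description.

```python
import functools

import numpy as np


def list_handler(func):
    """Let func accept a list of input arrays: apply it to each, then concatenate."""

    @functools.wraps(func)
    def wrapper(inputs, *args, **kwargs):
        if isinstance(inputs, (list, tuple)):
            parts = [func(x, *args, **kwargs) for x in inputs]
            return np.concatenate(parts, axis=0)
        return func(inputs, *args, **kwargs)

    return wrapper


@list_handler
def outputs(inputs):
    """Toy stand-in for a separable net: (n, 2) -> (n**2, 1)."""
    x, y = inputs[:, 0], inputs[:, 1]
    return np.outer(np.sin(x), np.cos(y)).reshape(-1, 1)


rng = np.random.default_rng(0)
n_pde, n_bc = 8, 3
pde_points = rng.random((n_pde, 2))
bc_points = rng.random((n_bc, 2))

# Stacked inputs: the Cartesian product mixes PDE and BC coordinates.
stacked = outputs(np.vstack([pde_points, bc_points]))
print(stacked.shape)  # ((n_pde + n_bc)**2, 1)

# List inputs: each group gets its own Cartesian product.
separate = outputs([pde_points, bc_points])
print(separate.shape)  # (n_pde**2 + n_bc**2, 1)

# Output offsets now follow the squared group sizes, not the input sizes
# (the ordering of the groups here is an assumption for illustration).
offsets = np.cumsum([0, n_pde**2, n_bc**2])
print(offsets)
```

With 8 PDE points and 3 BC points, stacking yields 121 outputs while the list form yields 64 + 9 = 73, which is why any index bookkeeping such as bcs_start has to be derived from the output sizes.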