Skip to content

Conversation

@bonneted
Copy link
Contributor

This is an implementation of the SPINN model: https://jwcho5576.github.io/spinn.github.io/

The code for the network architecture (snn.py) is directly adapted from the original paper (https://github.com/stnamjef/SPINN)

I've achieved really fast convergence with this implementation of SPINN compared to PINN (similar to the paper claim), for both forward and inverse quantification on the linear elastic plate problem.
Forward comparison :
https://github.com/lululxvi/deepxde/assets/53513604/499a961d-748c-458f-be99-56156b516ace

Inverse with PINN :
https://github.com/lululxvi/deepxde/assets/53513604/ab89554d-b82b-406a-8d73-05b3f72a3961

Inverse with SPINN :
https://github.com/lululxvi/deepxde/assets/53513604/07171442-ea03-48b4-87a6-8b5094f6809c

The implementation was more complicated than expected for the following reasons:
due to its architecture, SPINN takes an input of size n and outputs an array of size n**dim (it does the cartesian product of each coordinate) :
(n,2) --> SPINN --> n**2

This brings some difficulty with how inputs are handled in data.pde.
Indeed, all inputs are concatenated (PDE and BCS points) and throw the net simultaneously.
So if we have n_PDE PDE points and n_BC BC points we will end up with (n_PDE+n_BC)**2 points instead of n_PDE**2+n_BC**2

I tried to find a workaround with minimal changes to model.py, and came up with the following:
adding a list_handler decorator to the outputs function in JAX so that it can handle list inputs by applying the function to each input and then concatenates.

I then modified the pde.py file by adding a is_SPINN argument, if true, PDE and BC inputs are put together in a list instead of stacked. The bcs_start should also be modified as the outputs sizes no longer equal the inputs.

I understand that this brings a lot of changes to data.pde, so another possibility is to create a separate data subclass dedicated to SPINN so that the data.pde class isn't overly complicated.

@lululxvi
Copy link
Owner

lululxvi commented Jun 17, 2024

There are too many modifications. We can start with a mathematically equivalent (but slow speed) implementation by repeating the n inputs to n**2. This is similar to DeepONet

class DeepONet(NN):

vs DeepONetCartesianProd

class DeepONetCartesianProd(NN):

Then the code change would be very minimal.

@bonneted
Copy link
Contributor Author

bonneted commented Jul 2, 2024

OK, I'll try that when I have more time, hopefully soon.

@bonneted bonneted marked this pull request as draft July 2, 2024 09:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants