Description
Hello Author,
Thank you for creating such a useful tool. I have a question regarding the integration of JAX-FEM with neural networks.
Question: Can JAX-FEM be incorporated as a module within a neural network to serve as part of the training process?
Background: I am currently using FEniCS, and when training an MLP (multi-layer perceptron) with the FEniCS solve in the loop, I encounter the following error:
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32
This error occurs because FEniCS is not compatible with JAX's tracing-based automatic differentiation: during execution it tries to convert JAX traced arrays into concrete NumPy arrays, which JAX does not allow inside a transformed function such as jax.grad or jax.jit.
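A minimal reproduction of this failure mode, independent of FEniCS: `non_jax_solver` below is a hypothetical stand-in for any external solver that calls `np.asarray` on its input. Under `jax.grad` the input is a tracer, so the conversion raises the same `TracerArrayConversionError`, while the identical computation written with `jax.numpy` differentiates normally. This is only a sketch of the mechanism, not of FEniCS itself.

```python
import jax
import jax.numpy as jnp
import numpy as np

def non_jax_solver(k):
    # Hypothetical stand-in for an external (non-JAX) solver such as a FEniCS call:
    # it needs a concrete NumPy array, which a JAX tracer cannot provide.
    k_np = np.asarray(k)
    return float(np.sum(k_np ** 2))

def jax_solver(k):
    # The same computation written with jax.numpy stays traceable end to end.
    return jnp.sum(k ** 2)

k = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)

caught_error = None
try:
    jax.grad(non_jax_solver)(k)
except jax.errors.TracerArrayConversionError as err:
    caught_error = err  # same error class as in the FEniCS traceback

grad_k = jax.grad(jax_solver)(k)  # d/dk sum(k**2) = 2 * k
print(type(caught_error).__name__, grad_k)
```

The usual way around this with a non-JAX solver is to wrap it in `jax.custom_vjp` and supply the adjoint by hand; a solver written natively in JAX avoids that step.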
My Questions:
Can JAX-FEM be directly used as part of a neural network's computational graph?
Does JAX-FEM support end-to-end differentiation through the FEM solver?
Do you have any related work or examples demonstrating this integration?
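To make the second question concrete, here is a toy sketch of the kind of pipeline I have in mind. It is written in plain JAX and does not use JAX-FEM's actual API; all names (`mlp`, `fem_solve`, `loss_fn`, `N_ELEM`) are hypothetical. A small MLP predicts a positive conductivity per element, a 1D linear-element Poisson problem, -(k u')' = 1 on [0, 1] with u(0) = u(1) = 0, is assembled and solved with `jax.numpy`, and `jax.value_and_grad` propagates the loss through `jnp.linalg.solve` back to the MLP weights.

```python
import jax
import jax.numpy as jnp

N_ELEM = 16  # number of linear elements on [0, 1] (illustrative choice)
H = 1.0 / N_ELEM
NODES = jnp.linspace(0.0, 1.0, N_ELEM + 1)
MIDPOINTS = 0.5 * (NODES[:-1] + NODES[1:])

def init_mlp(key, sizes):
    # List of (weight, bias) pairs; a list of tuples is a valid JAX pytree.
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        params.append((0.5 * jax.random.normal(sub, (n_in, n_out)), jnp.zeros(n_out)))
    return params

def mlp(params, x):
    for W, b in params[:-1]:
        x = jnp.tanh(x @ W + b)
    W, b = params[-1]
    return x @ W + b

def fem_solve(k_elem):
    """Solve -(k u')' = 1 on [0, 1], u(0) = u(1) = 0, with linear elements."""
    idx = jnp.arange(N_ELEM)
    c = k_elem / H  # element stiffness scale k_e / h
    K = jnp.zeros((N_ELEM + 1, N_ELEM + 1))
    K = K.at[idx, idx].add(c).at[idx + 1, idx + 1].add(c)
    K = K.at[idx, idx + 1].add(-c).at[idx + 1, idx].add(-c)
    # Consistent load for f = 1: each element adds h/2 to both of its nodes.
    F = jnp.zeros(N_ELEM + 1).at[idx].add(H / 2).at[idx + 1].add(H / 2)
    # Dirichlet BCs: solve only for the interior nodes (differentiable solve).
    u_inner = jnp.linalg.solve(K[1:-1, 1:-1], F[1:-1])
    return jnp.concatenate([jnp.zeros(1), u_inner, jnp.zeros(1)])

def loss_fn(params, u_target):
    # softplus + 0.1 keeps the predicted conductivity strictly positive.
    k_elem = jax.nn.softplus(mlp(params, MIDPOINTS[:, None])[:, 0]) + 0.1
    u = fem_solve(k_elem)
    return jnp.mean((u - u_target) ** 2)

u_target = 0.5 * NODES * (1.0 - NODES)  # exact solution for k = 1
params = init_mlp(jax.random.PRNGKey(0), [1, 16, 1])
loss, grads = jax.jit(jax.value_and_grad(loss_fn))(params, u_target)
print(loss)
```

`grads` has the same pytree structure as `params`, so it can be fed to any JAX optimizer; the question is whether JAX-FEM's solver can play the role of `fem_solve` here.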
I would greatly appreciate any guidance or references you could provide.
Thank you very much!