
JAX-FEM with MLP #80

@leedz0210


Hello Author,

Thank you for creating such a useful tool. I have a question regarding the integration of JAX-FEM with neural networks.

Question: Can JAX-FEM be incorporated as a module within a neural network to serve as part of the training process?

Background: I am currently using FEniCS, and when training an MLP (multi-layer perceptron) together with it, I get the following error:

`jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32`

As far as I understand, this error occurs because FEniCS is not compatible with JAX's automatic differentiation: during execution it converts JAX traced arrays to NumPy arrays, and JAX cannot trace through that conversion.
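A minimal reproduction of this failure mode is below. `external_solver` is a hypothetical stand-in for the FEniCS call, assuming only that it converts its input with `np.asarray`; it is not actual FEniCS code. The same computation written with `jax.numpy` differentiates without problems.

```python
import numpy as np
import jax
import jax.numpy as jnp


def external_solver(k):
    # Hypothetical stand-in for a non-JAX library call (e.g. a FEniCS step):
    # it converts its input to a NumPy array, which JAX tracers refuse.
    return np.asarray(k) * 2.0


def loss_external(k):
    return jnp.sum(external_solver(k))


def loss_pure_jax(k):
    # Same computation written with jax.numpy only, so it stays traceable.
    return jnp.sum(k * 2.0)


k0 = jnp.ones(3, dtype=jnp.float32)

try:
    jax.grad(loss_external)(k0)
    raised = False
except jax.errors.TracerArrayConversionError:
    raised = True

grad_ok = jax.grad(loss_pure_jax)(k0)  # [2. 2. 2.]
print(raised, grad_ok)
```
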

My Questions:

- Can JAX-FEM be directly used as part of a neural network's computational graph?
- Does JAX-FEM support end-to-end differentiation through the FEM solver?
- Do you have any related work or examples demonstrating this integration?

I would greatly appreciate any guidance or references you could provide.
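For context, this is roughly the kind of pipeline I have in mind, sketched in plain JAX rather than JAX-FEM (I am not assuming anything about JAX-FEM's API). A small MLP predicts a per-element conductivity, a hand-written 1D linear-element solver for -(k u')' = f on [0, 1] with u(0) = u(1) = 0 computes the displacement field, and `jax.grad` differentiates the loss through the solve back to the network weights. The names `init_mlp`, `mlp`, `solve_poisson`, and the problem setup are all hypothetical illustrations.

```python
import jax
import jax.numpy as jnp

N_ELEM = 16          # number of linear elements on [0, 1] (hypothetical setup)
H = 1.0 / N_ELEM     # uniform element size


def init_mlp(key, sizes=(1, 16, 1)):
    """Initialise a tiny MLP as a list of (W, b) pairs."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) * 0.5, jnp.zeros(n))
            for k, (m, n) in zip(keys, zip(sizes[:-1], sizes[1:]))]


def mlp(params, x):
    for W, b in params[:-1]:
        x = jnp.tanh(x @ W + b)
    W, b = params[-1]
    return x @ W + b


def solve_poisson(k_elem, f=1.0):
    """P1 FEM for -(k u')' = f with homogeneous Dirichlet BCs.
    Only jax.numpy ops are used, so the whole solve is differentiable."""
    n_nodes = N_ELEM + 1
    idx = jnp.arange(N_ELEM)
    # Local stiffness entries in the order (i,i), (i,i+1), (i+1,i), (i+1,i+1).
    ke_ref = jnp.array([1.0, -1.0, -1.0, 1.0]) / H
    rows = jnp.stack([idx, idx, idx + 1, idx + 1], axis=1).ravel()
    cols = jnp.stack([idx, idx + 1, idx, idx + 1], axis=1).ravel()
    vals = (k_elem[:, None] * ke_ref[None, :]).ravel()
    # Scatter-add assembles the global stiffness matrix (duplicates accumulate).
    K = jnp.zeros((n_nodes, n_nodes)).at[rows, cols].add(vals)
    F = jnp.full(n_nodes, f * H)  # load vector for constant f (interior nodes)
    # Dirichlet BCs: solve only for the interior unknowns.
    u_int = jnp.linalg.solve(K[1:-1, 1:-1], F[1:-1])
    return jnp.concatenate([jnp.zeros(1), u_int, jnp.zeros(1)])


def loss(params, u_target):
    x_mid = ((jnp.arange(N_ELEM) + 0.5) * H)[:, None]   # element midpoints
    k_elem = jax.nn.softplus(mlp(params, x_mid))[:, 0] + 1e-3  # keep k > 0
    u = solve_poisson(k_elem)
    return jnp.mean((u - u_target) ** 2)


# Synthetic target: the solution for k = 1, i.e. u(x) = x (1 - x) / 2 at the nodes.
u_target = solve_poisson(jnp.ones(N_ELEM))
params = init_mlp(jax.random.PRNGKey(0))

# Gradient of the loss w.r.t. the MLP weights, taken through the FEM solve.
loss_val, grads = jax.jit(jax.value_and_grad(loss))(params, u_target)
print(loss_val)
```

The point of the sketch is that every step from network output to loss is a JAX operation, so no tracer ever reaches NumPy. Whether JAX-FEM exposes its solver in a way that composes like this is exactly what I am asking about.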

Thank you very much!
