Skip to content

BUG: ML Operator with Mixed Space input and Single Space output #4125

Open
@tlroy

Description

@tlroy

Describe the bug
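Currently broken: an ML operator that takes a function in a mixed space as input and returns a function in a single (non-mixed) space, i.e. W -> V, fails. It seems that only W -> W and V -> V are supported. The error is a boundary-condition mismatch with the trial function space (`TypeError: bc space does not match the trial function space`), raised while assembling the Jacobian.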

Steps to Reproduce
Steps to reproduce the behavior:

from firedrake import *
import firedrake.ml.pytorch as fd_ml
import torch
import torch.nn as nn
mesh = UnitIntervalMesh(4)  # 1D unit-interval mesh
x = SpatialCoordinate(mesh)  # NOTE(review): unused below; kept so line numbers match the traceback
V = FunctionSpace(mesh,"CG",1)  # single (non-mixed) P1 space: the ML operator's OUTPUT space
W = V*V  # mixed space W = V x V: the ML operator's INPUT space
u = Function(W)  # mixed unknown solved for below; also the ML operator's operand
u1, u2 = split(u)  # symbolic components of u for use in the form
v1, v2 = TestFunctions(W)  # one test function per component of W
model = nn.Linear(W.dim(), V.dim())  # dense map from a W-sized vector to a V-sized vector
I = torch.eye(V.dim(), V.dim())
model.weight.data = torch.cat([I, I], dim=1)  # weight = [I | I]: output = first half + second half of the input (assumes W's dof vector is [u1 dofs, u2 dofs] -- confirm)
model.bias.data = torch.zeros(V.dim())  # zero bias so the model is exactly the [I | I] sum
model.eval()
ml_op = fd_ml.ml_operator(model,function_space=V,inputs_format=1)  # output lives in V while the input will be in W: the failing W -> V case
i_surrogate = ml_op(u)  # operand is the unknown u itself (W-valued), so the operator presumably enters the Jacobian
# Passing an independent Function(W) instead of the unknown u runs without error:
# i_surrogate = ml_op(Function(W))
a = inner(grad(u1), grad(v1)) * dx  # nonlinear residual F(u; v) (named `a`, but solved as a == 0)
a += i_surrogate * v1 * ds(1) # equivalent to (u_1 + u_2) * v1 * ds(1), given the [I | I] weights above
a += inner(grad(u2), grad(v2)) * dx
bcs = [DirichletBC(W.sub(0), Constant(1), 2),  # Dirichlet data on each W component at boundary id 2
        DirichletBC(W.sub(1), Constant(1), 2)]
solve(a==0, u, bcs=bcs)  # fails in Jacobian assembly: TypeError "bc space does not match the trial function space"

Error message

/home/firedrake/firedrake/lib/python3.12/site-packages/pytools/persistent_dict.py:52: RecommendedHashNotFoundWarning: Unable to import recommended hash 'siphash24.siphash13', falling back to 'hashlib.sha256'. Run 'python3 -m pip install siphash24' to install the recommended hash.
  warn("Unable to import recommended hash 'siphash24.siphash13', "
Traceback (most recent call last):
  File "petsc4py/PETSc/PETSc.pyx", line 348, in petsc4py.PETSc.PetscPythonErrorHandler
  File "petsc4py/PETSc/PETSc.pyx", line 348, in petsc4py.PETSc.PetscPythonErrorHandler
  File "petsc4py/PETSc/PETSc.pyx", line 348, in petsc4py.PETSc.PetscPythonErrorHandler
  File "petsc4py/PETSc/petscsnes.pxi", line 367, in petsc4py.PETSc.SNES_Jacobian
  File "/home/firedrake/firedrake/src/firedrake/firedrake/solving_utils.py", line 461, in form_jacobian
    ctx._assemble_jac(ctx._jac)
  File "/home/firedrake/firedrake/src/firedrake/firedrake/assemble.py", line 404, in assemble
    result = BaseFormAssembler.base_form_postorder_traversal(self._form, visitor, visited)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/firedrake/firedrake/src/firedrake/firedrake/assemble.py", line 638, in base_form_postorder_traversal
    visited[e] = visitor(e, *(visited[arg] for arg in operands))
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/firedrake/firedrake/src/firedrake/firedrake/assemble.py", line 400, in visitor
    return self.base_form_assembly_visitor(e, t, *operands)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/firedrake/firedrake/src/firedrake/firedrake/assemble.py", line 445, in base_form_assembly_visitor
    return assembler.assemble(tensor=tensor)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/firedrake/firedrake/src/firedrake/firedrake/assemble.py", line 1022, in assemble
    self._apply_bc(tensor, bc, u=current_state)
  File "/home/firedrake/firedrake/src/firedrake/firedrake/assemble.py", line 1471, in _apply_bc
    raise TypeError("bc space does not match the trial function space")
TypeError: bc space does not match the trial function space
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/home/firedrake/shared/co2rue-model-scaleup/1Dexample/minimal2.py", line 26, in <module>
    solve(a==0, u, bcs=bcs)
  File "petsc4py/PETSc/Log.pyx", line 188, in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
  File "petsc4py/PETSc/Log.pyx", line 189, in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
  File "/home/firedrake/firedrake/src/firedrake/firedrake/adjoint_utils/solving.py", line 57, in wrapper
    output = solve(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/firedrake/firedrake/src/firedrake/firedrake/solving.py", line 144, in solve
    _solve_varproblem(*args, **kwargs)
  File "/home/firedrake/firedrake/src/firedrake/firedrake/solving.py", line 194, in _solve_varproblem
    solver.solve()
  File "petsc4py/PETSc/Log.pyx", line 188, in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
  File "petsc4py/PETSc/Log.pyx", line 189, in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
  File "/home/firedrake/firedrake/src/firedrake/firedrake/adjoint_utils/variational_solver.py", line 104, in wrapper
    out = solve(self, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/home/firedrake/firedrake/src/firedrake/firedrake/variational_solver.py", line 330, in solve
    self.snes.solve(None, work)
  File "petsc4py/PETSc/SNES.pyx", line 1724, in petsc4py.PETSc.SNES.solve
petsc4py.PETSc.Error: error code -1
[0] SNESSolve() at /home/firedrake/petsc/src/snes/interface/snes.c:4839
[0] SNESSolve_NEWTONLS() at /home/firedrake/petsc/src/snes/impls/ls/ls.c:218
[0] SNESComputeJacobian() at /home/firedrake/petsc/src/snes/interface/snes.c:2967

Metadata

Metadata

Assignees

Labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions