-
Notifications
You must be signed in to change notification settings - Fork 3
Description
Hi @a-latyshev,
I've tried to use external operators for a relatively simple thermoelastic problem at finite deformations. Using mixed elements, I want to solve for the displacements and the temperature monolithically. Thus, these two fields act as the operands of my stress quantity, which is an external operator. Here is the example I computed:
import numpy as np
import jax
import jax.numpy as jnp
from functools import partial
from mpi4py import MPI
import ufl
import basix
from dolfinx import fem, io, mesh
from dolfinx_external_operator import (
FEMExternalOperator,
evaluate_external_operators,
evaluate_operands,
replace_external_operators,
)
from dolfinx.fem.petsc import NonlinearProblem
from utilities import assemble_residual_with_callback
# ============================================ JAX options ============================================
# Run JAX on the CPU and enable 64-bit (double-precision) arrays for the
# constitutive kernels below.
for _jax_option, _jax_value in (("jax_platforms", "cpu"), ("jax_enable_x64", True)):
    jax.config.update(_jax_option, _jax_value)
del _jax_option, _jax_value
# ====================================== Auxiliary functions ==========================================
def jaxt2n(A, x):
    """Map a 3x3 tensor to a 6-component (Voigt/Nye) vector using JAX.

    Ordering: [A00, A11, A22, x*A01, x*A02, x*A12]. Only the diagonal and the
    upper triangle are read, so ``A`` is assumed symmetric; the off-diagonal
    entries are scaled by ``x``.
    """
    diagonal = [A[k, k] for k in range(3)]
    shear = [x * A[i, j] for i, j in ((0, 1), (0, 2), (1, 2))]
    return jnp.array(diagonal + shear)
def jaxn2t(v, x):
    """Rebuild a symmetric 3x3 tensor from a 6-component (Voigt/Nye) vector using JAX.

    Inverse of ``jaxt2n``: v[0..2] go on the diagonal, and v[3], v[4], v[5]
    are divided by ``x`` and placed symmetrically at (0, 1), (0, 2), (1, 2).
    """
    a01 = v[3] / x
    a02 = v[4] / x
    a12 = v[5] / x
    rows = [
        [v[0], a01, a02],
        [a01, v[1], a12],
        [a02, a12, v[2]],
    ]
    return jnp.array(rows)
def uflt2n(A, x):
    """UFL counterpart of ``jaxt2n``: 3x3 tensor -> 6-component (Voigt/Nye) vector.

    Ordering: [A00, A11, A22, x*A01, x*A02, x*A12]; only the upper triangle
    is read, so ``A`` is assumed symmetric.
    """
    diagonal = [A[k, k] for k in range(3)]
    shear = [x * A[i, j] for i, j in ((0, 1), (0, 2), (1, 2))]
    return ufl.as_vector(diagonal + shear)
def ufln2t(v, x):
    """UFL counterpart of ``jaxn2t``: 6-component (Voigt/Nye) vector -> symmetric 3x3 tensor.

    v[3], v[4], v[5] are divided by ``x`` and placed at (0, 1), (0, 2), (1, 2)
    and their transposed positions.
    """
    a01 = v[3] / x
    a02 = v[4] / x
    a12 = v[5] / x
    rows = [
        [v[0], a01, a02],
        [a01, v[1], a12],
        [a02, a12, v[2]],
    ]
    return ufl.as_tensor(rows)
# =========================================== Mesh ====================================================
# Box [0, length] x [0, width] x [0, height] split into hexahedra:
# 5*N cells along x and N cells along y and z.
length = 5.0
width = 1.0
height = 1.0
N = 4
domain = mesh.create_box(
    MPI.COMM_WORLD,
    [[0.0, 0.0, 0.0], [length, width, height]],
    [5 * N, N, N],
    mesh.CellType.hexahedron,
)
# Note: ``gdim`` holds the topological dimension of the mesh.
gdim = domain.topology.dim
fdim = gdim - 1
# Build facet-to-cell connectivity.
domain.topology.create_connectivity(fdim, gdim)
# Locator functions: each receives a coordinate array ``x`` whose rows are the
# x, y and z coordinates of a batch of points, and returns a boolean mask.
def x_left(x):
    """Mark points on the plane x = 0."""
    return np.isclose(x[0], 0.0)
def y_front(x):
    """Mark points on the plane y = 0."""
    return np.isclose(x[1], 0.0)
def z_bottom(x):
    """Mark points on the plane z = 0."""
    return np.isclose(x[2], 0.0)
def x_right(x):
    """Mark points on the plane x = length (the far end of the box)."""
    return np.isclose(x[0], length)
def y_back(x):
    """Mark points on the plane y = width."""
    return np.isclose(x[1], width)
def z_top(x):
    """Mark points on the plane z = height."""
    return np.isclose(x[2], height)
# Tag boundary facets: marker k belongs to boundaries[k]
# (0: x=0, 1: y=0, 2: z=0, 3: x=length, 4: y=width, 5: z=height).
boundaries = [x_left, y_front, z_bottom, x_right, y_back, z_top]
located_facets = [mesh.locate_entities_boundary(domain, fdim, locator) for locator in boundaries]
facet_indices = np.hstack(located_facets).astype(np.int32)
facet_markers = np.hstack(
    [np.full_like(facets, marker) for marker, facets in enumerate(located_facets)]
).astype(np.int32)
# Order indices (and their markers) by facet index before building the MeshTags.
sorted_facets = np.argsort(facet_indices)
facet_tags = mesh.meshtags(domain, fdim, facet_indices[sorted_facets], facet_markers[sorted_facets])
# ============================ Element formulations and function spaces ===============================
# Mixed space: first-order Lagrange vector displacement (3 components)
# combined with a first-order Lagrange scalar temperature.
L3e = basix.ufl.element("Lagrange", domain.basix_cell(), 1, shape=(3,))
L1e = basix.ufl.element("Lagrange", domain.basix_cell(), 1)
M4e = basix.ufl.mixed_element([L3e, L1e])
M4 = fem.functionspace(domain, M4e)
# Collapsed sub-spaces (displacement components and temperature), used for
# the Dirichlet BC functions and DOF lookups.
M4ux, M4uy, M4uz = (M4.sub(0).sub(i).collapse()[0] for i in range(3))
M4t, _ = M4.sub(1).collapse()
# Quadrature space holding the 6 Voigt components of the PK2 stress;
# degree 2, the same as the 'quadrature_degree' of the volume measure dx.
G6e = basix.ufl.quadrature_element(domain.basix_cell(), degree=2, value_shape=(6,))
G6 = fem.functionspace(domain, G6e)
# ============================= Solution fields and initialization ===================================
w = fem.Function(M4, name="Solution field")
w_n = fem.Function(M4, name="Solution field (previous time-step)")
# Initial temperature: uniform 293.15 (the same value as T_0 in the material model).
w.sub(1).interpolate(lambda x: 293.15 * np.ones(x.shape[1]))
# The previous-step state starts as a copy of the initial state.
w_n.x.array[:] = w.x.array
# UFL views of the displacement and temperature parts of the mixed functions.
u, T = ufl.split(w)
u_n, T_n = ufl.split(w_n)
# ========================================== Time ====================================================
t = fem.Constant(domain, 0.0)      # current time (used as the output label)
dt = fem.Constant(domain, 1.0e-02)  # time-step size
# ==================================== Kinematics ====================================================
Id = ufl.Identity(3)
F = Id + ufl.grad(u)        # deformation gradient
RCG = F.T * F               # right Cauchy-Green tensor C = F^T F
GLS = 0.5 * (RCG - Id)      # Green-Lagrange strain E = (C - I) / 2
GLS_nye = uflt2n(GLS, 2)    # E as a 6-vector with doubled shear components
# ================================ External operator formulation =====================================
# PK2 stress (6 Voigt components) as an external operator of the strain
# GLS_nye and the temperature T, stored in the quadrature space G6.
PK2_nye = FEMExternalOperator(GLS_nye, T, function_space=G6)
# Model parameters
E = 200e3           # Young's modulus
NU = 0.3            # Poisson's ratio
CT = 910e-6         # coefficient of the temperature-rate term in R_T
ALPHA = 2.31e-5     # thermal expansion coefficient (enters psi)
COND = 237e-05      # thermal conductivity (enters the heat flux q_0)
T_0 = 293.15        # reference temperature of the thermal coupling term
LAMBDA = E * NU / (1 - 2 * NU) / (1 + NU)   # first Lame parameter
MU = E / 2 / (1 + NU)                        # shear modulus
KBULK = LAMBDA + 2 / 3 * MU                  # bulk modulus
def psi(GLS_nye, T_):
    """Free-energy density at a single quadrature point.

    Args:
        GLS_nye: Green-Lagrange strain as a 6-vector with doubled shear
            components (the layout produced by ``uflt2n(GLS, 2)``).
        T_: length-1 array holding the temperature.

    Returns:
        Scalar energy. This is a compressible neo-Hookean part written in the
        invariants I1 = tr(C) and I3 = det(C), minus the thermal coupling term
        3 * KBULK * ALPHA * (T - T_0) * I3.

    NOTE(review): the coupling term scales with I3 = det(C) = J**2; confirm
    this is the intended thermoelastic coupling (forms in J or ln J are also common).
    """
    temperature = T_[0]
    C = 2.0 * jaxn2t(GLS_nye, 2) + jnp.identity(3)  # right Cauchy-Green tensor
    I1 = jnp.trace(C)
    I3 = jnp.linalg.det(C)
    elastic = MU / 2 * (I1 - 3 - jnp.log(I3)) + LAMBDA / 4 * (I3 - 1 - jnp.log(I3))
    thermal = 3 * KBULK * ALPHA * (temperature - T_0) * I3
    return elastic - thermal
# 2nd PK stress and material tangents by automatic differentiation of psi.
# Differentiating w.r.t. the 6-vector with doubled shear strains yields the
# stress 6-vector [S11, S22, S33, S12, S13, S23] (no factor on the shear part).
PK2 = jax.grad(psi, argnums=0)
dPK2dGLS = jax.jacfwd(PK2, argnums=0)  # (6, 6) per point
dPK2dT = jax.jacfwd(PK2, argnums=1)    # (6, 1) per point, since T_ has length 1
# Vectorize each kernel over the quadrature points (leading axis of both
# arguments), then JIT-compile it.
PK2_vec, dPK2dGLS_vec, dPK2dT_vec = (
    jax.jit(jax.vmap(kernel, in_axes=(0, 0))) for kernel in (PK2, dPK2dGLS, dPK2dT)
)
def map_over_gauss_points(fn, GLS_glo, T_glo):
    """Apply a vectorized per-point kernel to the evaluated operand arrays.

    Args:
        fn: callable taking an (n, 6) strain array and an (n, 1) temperature
            array, e.g. ``PK2_vec``.
        GLS_glo: GLS_nye values with 6 components per quadrature point (any
            leading layout, e.g. (cells, points, 6)).
        T_glo: temperature values, one per quadrature point.

    Returns:
        The kernel output flattened to 1D.
    """
    return fn(GLS_glo.reshape((-1, 6)), T_glo.reshape((-1, 1))).reshape(-1)


def stress(GLS_glo, T_glo):
    """PK2 stress (6 components per point), flattened."""
    return map_over_gauss_points(PK2_vec, GLS_glo, T_glo)


def tang01(GLS_glo, T_glo):
    """Tangent dPK2/dGLS_nye (6x6 per point), flattened."""
    return map_over_gauss_points(dPK2dGLS_vec, GLS_glo, T_glo)


def tang02(GLS_glo, T_glo):
    """Tangent dPK2/dT (6x1 per point), flattened."""
    return map_over_gauss_points(dPK2dT_vec, GLS_glo, T_glo)
def PK2_external(derivatives):
    """Return the kernel for the requested derivative of the PK2 external operator.

    ``derivatives`` is the multi-index over the operands (GLS_nye, T):
    (0, 0) -> stress, (1, 0) -> dPK2/dGLS_nye, (0, 1) -> dPK2/dT.

    Raises:
        NotImplementedError: for any other multi-index.
    """
    if derivatives == (1, 0):
        return tang01
    if derivatives == (0, 1):
        return tang02
    if derivatives != (0, 0):
        raise NotImplementedError(f'There is no external function for the derivative {derivatives}.')
    return stress
# Attach the JAX kernels to the external operator.
PK2_nye.external_function = PK2_external
# Referential heat flux q_0 = -COND * C^{-1} grad(T).
q_0 = - COND * ufl.dot(ufl.inv(RCG), ufl.grad(T))
# ======================================== Weak form ================================================
# Volume measure; quadrature degree 2, the same degree as the quadrature space G6.
dx = ufl.Measure('dx', domain=domain, metadata={'quadrature_degree': 2, 'quadrature_rule': 'default'})
vu, vT = ufl.TestFunctions(M4)
# Variation of the Green-Lagrange strain, sym(F^T grad(vu)), in 6-vector form.
virGLS = ufl.sym(F.T * ufl.grad(vu))
virGLS_nye = uflt2n(virGLS, 2)
# Mechanical residual: PK2 contracted with the strain variation.
R_u = ufl.inner(PK2_nye, virGLS_nye) * dx
# Thermal residual: implicit-Euler rate term minus the conduction term.
rate_term = CT * (T - T_n) / dt * vT * dx
conduction_term = ufl.inner(q_0, ufl.grad(vT)) * dx
R_T = rate_term - conduction_term
R = R_u + R_T
dw = ufl.TrialFunction(M4)
J = ufl.derivative(R, w, dw)
# Expand the Gateaux derivative so that the derivative external operators
# (w.r.t. GLS_nye and T) appear explicitly in the Jacobian form.
J_expanded = ufl.algorithms.expand_derivatives(J)
R_replaced, R_external_operators = replace_external_operators(R)
J_replaced, J_external_operators = replace_external_operators(J_expanded)
# Pick out the Jacobian's external operators by derivative multi-index. As
# before, if several share a multi-index, the last one wins.
J_operators_by_derivative = {ex_op.derivatives: [ex_op] for ex_op in J_external_operators}
_missing_derivatives = [
    d for d in ((0, 0), (1, 0), (0, 1)) if d not in J_operators_by_derivative
]
if _missing_derivatives:
    # Fail here with a clear message instead of a NameError later in constitutive_update().
    raise RuntimeError(
        f"The expanded Jacobian has no external operator with derivatives {_missing_derivatives}; "
        f"found {sorted(J_operators_by_derivative)}."
    )
PK2_external_operators = J_operators_by_derivative[(0, 0)]
pdPK2dGLS_external_operators = J_operators_by_derivative[(1, 0)]
pdPK2dT_external_operators = J_operators_by_derivative[(0, 1)]
# ================================== Boundary conditions ============================================
# Displacement DOFs: u_x on x = 0 (tag 0), u_y on y = 0 (tag 1), u_z on z = 0 (tag 2).
ux_left_dofs = fem.locate_dofs_topological((M4.sub(0).sub(0), M4ux), facet_tags.dim, facet_tags.find(0))
uy_front_dofs = fem.locate_dofs_topological((M4.sub(0).sub(1), M4uy), facet_tags.dim, facet_tags.find(1))
uz_bottom_dofs = fem.locate_dofs_topological((M4.sub(0).sub(2), M4uz), facet_tags.dim, facet_tags.find(2))
# Temperature DOFs on x = 0 (tag 0) and x = length (tag 3).
T_left_dofs = fem.locate_dofs_topological((M4.sub(1), M4t), facet_tags.dim, facet_tags.find(0))
T_right_dofs = fem.locate_dofs_topological((M4.sub(1), M4t), facet_tags.dim, facet_tags.find(3))
# Prescribed values: displacement functions are left at their initial values;
# temperatures are set directly as floats rather than through a fem.Constant
# object, which relied on implicit conversion by NumPy.
ux_left = fem.Function(M4ux)
uy_front = fem.Function(M4uy)
uz_bottom = fem.Function(M4uz)
T_left = fem.Function(M4t)
T_left.x.array[:] = 393.15
T_right = fem.Function(M4t)
T_right.x.array[:] = 493.15
bcs_1 = fem.dirichletbc(ux_left, ux_left_dofs, M4.sub(0).sub(0))
bcs_2 = fem.dirichletbc(uy_front, uy_front_dofs, M4.sub(0).sub(1))
bcs_3 = fem.dirichletbc(uz_bottom, uz_bottom_dofs, M4.sub(0).sub(2))
bcs_4 = fem.dirichletbc(T_left, T_left_dofs, M4.sub(1))
bcs_5 = fem.dirichletbc(T_right, T_right_dofs, M4.sub(1))
bcs = [bcs_1, bcs_2, bcs_3, bcs_4, bcs_5]
# ==================================== Solver setup =================================================
def constitutive_update():
    """Re-evaluate the PK2 stress and both tangent external operators at the current state.

    Passed as the callback to ``assemble_residual_with_callback`` in the
    solver setup; presumably invoked before each residual assembly (confirm
    in utilities).
    """
    # Operand values (GLS_nye and T) at the quadrature points. They are
    # evaluated once and reused for all three operators, which share these operands.
    operand_values = evaluate_operands(PK2_external_operators)
    (stress_values,) = evaluate_external_operators(PK2_external_operators, operand_values)
    # NOTE(review): the two tangent results are not used below, yet the solver
    # reportedly fails without these calls. evaluate_external_operators presumably
    # also writes each result into the operator's ref_coefficient, which the
    # replaced Jacobian J_replaced reads; confirm against dolfinx_external_operator.
    (_dPK2dGLS_values,) = evaluate_external_operators(pdPK2dGLS_external_operators, operand_values)
    (_dPK2dT_values,) = evaluate_external_operators(pdPK2dT_external_operators, operand_values)
    # Copy the stress into the coefficient of the original operator PK2_nye
    # (presumably the one R_replaced refers to).
    PK2_nye.ref_coefficient.x.array[:] = stress_values
# SNES Newton with backtracking line search; each linear solve is a direct
# LU factorization through MUMPS.
petsc_options = {
    "snes_type": "newtonls",
    "snes_linesearch_type": "bt",
    "ksp_type": "preonly",
    "pc_type": "lu",
    "pc_factor_mat_solver_type": "mumps",
    "snes_atol": 1.0e-8,
    "snes_rtol": 1.0e-8,
    "snes_max_it": 100,
    "snes_monitor": "",
}
problem = NonlinearProblem(
    R_replaced,
    w,
    J=J_replaced,
    bcs=bcs,
    petsc_options_prefix="coupledexo_",
    petsc_options=petsc_options,
)
# Replace the SNES residual function with one that also runs constitutive_update.
# NOTE(review): this relies on the private attributes problem._F / problem._J,
# which may change between dolfinx versions.
assemble_residual_with_callback_ = partial(
    assemble_residual_with_callback, problem.u, problem._F, problem._J, bcs, constitutive_update
)
problem.solver.setFunction(assemble_residual_with_callback_, problem.b)
# ==================================== Visualisation setup =================================================
# ParaView collection file; the mesh is written once, before the time loop.
results_file = "output/results.pvd"
pvd = io.VTKFile(domain.comm, results_file, "w")
pvd.write_mesh(domain)
# ======================================== Loading loop ====================================================
TIME_MAX = 1.0
# Count steps with an integer instead of accumulating t += dt in floating point.
# Repeated addition accumulates rounding error, which makes whether the final
# step passes a `t <= TIME_MAX` test a matter of chance.
num_steps = int(round(TIME_MAX / dt.value))
for step in range(1, num_steps + 1):
    # The solution computed in this step belongs to the end of the increment.
    t.value = step * dt.value
    problem.solve()
    converged_reason = problem.solver.getConvergedReason()
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if converged_reason <= 0:
        raise RuntimeError(
            f"SNES did not converge at t = {float(t.value)} (converged reason {converged_reason})."
        )
    n_iter = problem.solver.getIterationNumber()
    w.x.scatter_forward()
    # The converged state becomes the previous state for the rate term in R_T.
    w_n.x.array[:] = w.x.array[:]
    if MPI.COMM_WORLD.rank == 0:
        print(f"Time {float(t.value)}, Number of iterations {n_iter}")
    u_out = w.sub(0).collapse()
    u_out.name = "Displacements"
    T_out = w.sub(1).collapse()
    T_out.name = "Temperature"
    pvd.write_function([u_out, T_out], t=float(t.value))
pvd.close()

It works fine, and the results are identical to a pure UFL implementation. However, I'm still confused by some aspects of it:
- In `constitutive_update()`, I had to evaluate the external operators for the material tangents as well; otherwise, the solver fails. From my understanding, updating the stress values should be enough. Am I missing something here?
- I realized that, for some reason, `J_external_operators` does not show the strain as an operand; instead, it shows the temperature twice. Here is how I checked it:
def print_operand_values(label, external_operators):
    """Evaluate the operands of ``external_operators`` and print each value, prefixed by ``label``.

    NOTE(review): only the values of the dict returned by evaluate_operands are
    printed. Printing the keys as well would show which operand each value
    belongs to, which helps investigate why the JACOBIAN listing shows the
    temperature twice.
    """
    evaluated = evaluate_operands(external_operators)
    for operand_number, operand_value in enumerate(evaluated.values()):
        print(f'{label} operand number {operand_number}, Operand value = {operand_value}')


print_operand_values('RESIDUUM', R_external_operators)
print_operand_values('JACOBIAN', J_external_operators)
# Group the Jacobian's external operators by derivative multi-index.
for ex_op in J_external_operators:
    if ex_op.derivatives == (0, 0):
        PK2_external_operators = [ex_op]
    elif ex_op.derivatives == (1, 0):
        pdPK2dGLS_external_operators = [ex_op]
    elif ex_op.derivatives == (0, 1):
        pdPK2dT_external_operators = [ex_op]
print_operand_values('PK2', PK2_external_operators)
print_operand_values('pdPK2dGLS', pdPK2dGLS_external_operators)
print_operand_values('pdPK2dT', pdPK2dT_external_operators)
# This is the output I get for a single element:
RESIDUUM operand number 0, Operand value = [[[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]]]
RESIDUUM operand number 1, Operand value = [[293.15 293.15 293.15 293.15 293.15 293.15 293.15 293.15]]
JACOBIAN operand number 0, Operand value = [[293.15 293.15 293.15 293.15 293.15 293.15 293.15 293.15]]
JACOBIAN operand number 1, Operand value = [[293.15 293.15 293.15 293.15 293.15 293.15 293.15 293.15]]
PK2 operand number 0, Operand value = [[[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]]]
PK2 operand number 1, Operand value = [[293.15 293.15 293.15 293.15 293.15 293.15 293.15 293.15]]
pdPK2dGLS operand number 0, Operand value = [[[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]]]
pdPK2dGLS operand number 1, Operand value = [[293.15 293.15 293.15 293.15 293.15 293.15 293.15 293.15]]
pdPK2dT operand number 0, Operand value = [[[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]]]
pdPK2dT operand number 1, Operand value = [[293.15 293.15 293.15 293.15 293.15 293.15 293.15 293.15]]

As you can see, the operands are correct for the residual and for each of the Jacobian's external operators individually, but not when the operands of all the Jacobian's external operators are evaluated together.
Currently, I'm trying to extend the example to a fully coupled one by introducing heating due to mechanical deformation. I'm encountering convergence issues, and I'm wondering whether the above "issues" could be the cause. From our previous exchange, I know that mixed elements are currently not your focus; still, I wanted to ask whether there is anything important to consider when using external operators the way I did.
Is there anything that might compromise using external operators in combination with mixed elements?
Best regards and thanks in advance
Nadir