Convergence issue for 1D time domain Maxwell equation in the x-t domain. #2018

@liuspace

Dear developers and users,

I have recently been exploring DeepXDE for solving Maxwell's equations. I started with standing waves, and the PINN works well for those. I then moved on to a traveling-wave problem.
The x-t domain is [-1, 1] x [0, 1], and at x = 0 the E field is driven by a Gaussian pulse in time, E = exp(-(t - mean)^2 / (2 * std_dev^2)).

The Gaussian pulse travels in both the +x and -x directions. The network input is (x, t) and the output is (E, H).

At t = 0, E = 0 and H = 0 (the source E field at t = 0 can be made negligibly small by choosing mean and std_dev suitably). At x = -1 and x = 1, E = 0 and H = 0, because the boundaries are far enough from the source that the wave does not reach them within the time window.
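As a quick check of the claim that the source is negligible at t = 0, here is a minimal sketch that evaluates the pulse there. It assumes the values used later in this post (mean = 0.5, std_dev = 0.05, amplitude 1); the name `gaussian_pulse` is only illustrative.

```python
import math

def gaussian_pulse(t, mean=0.5, std_dev=0.05, amplitude=1.0):
    """Source E field at x = 0 as a function of normalized time t."""
    return amplitude * math.exp(-((t - mean) ** 2) / (2 * std_dev ** 2))

# At t = 0 the pulse centre is mean / std_dev = 10 standard deviations away,
# so the mismatch with the initial condition E = 0 is about exp(-50).
e_at_t0 = gaussian_pulse(0.0)
print(f"E(x=0, t=0) = {e_at_t0:.3e}")
```

With these parameters the source and the zero initial condition are consistent far below single-precision round-off, so the two constraints should not conflict at the corner (x, t) = (0, 0).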

For the excitation (the Gaussian pulse), two arrays Xs and ys are created. Xs holds discrete (x, t) points on the source line x = 0, e.g. (0, 0), (0, 0.1), (0, 0.2), ..., (0, 1.0).
ys holds the E field of the Gaussian pulse at these Xs points. The excitation is then imposed with ic_s = dde.icbc.PointSetBC(Xs, ys, component=0). Does this make sense? Do you have any other suggestions?
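A minimal NumPy-only sketch of how the two arrays could be assembled, assuming 128 time samples and the pulse parameters mean = 0.5 and std_dev = 0.05 (all taken from the code further down). It uses `np.linspace` so that both end times are sampled exactly; the resulting arrays are what would be handed to `dde.icbc.PointSetBC`.

```python
import numpy as np

nt = 128                        # number of source samples (assumed)
t_lower, t_upper = 0.0, 1.0     # normalized time window
mean, std_dev = 0.5, 0.05       # pulse parameters (assumed)

# linspace includes both endpoints and always returns exactly nt samples.
time_values = np.linspace(t_lower, t_upper, nt)

# Each row of Xs is one (x, t) point on the source line x = 0.
Xs = np.column_stack((np.zeros_like(time_values), time_values))

# One target E value per row, kept as a column vector of shape (nt, 1).
ys = np.exp(-((Xs[:, 1:2] - mean) ** 2) / (2 * std_dev ** 2))

print(Xs.shape, ys.shape)
```

Slicing with `1:2` rather than `1` keeps the second dimension, so ys already has the (N, 1) layout that matches a single output component.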

The training points X_star are generated manually so that none of them falls on x = 0 (the source location). Finally, all points are passed to dde.data.TimePDE through the anchors argument.
Convergence stalls at roughly 1e-3 for the non-boundary loss.

Image

The final E and H fields are shown below:

Image

In the expected solution, both E and H should be zero before t = 0.45. The mismatch may be related to the convergence stall. Do you have any suggestions?
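For comparison, a reference field can be written in closed form. The sketch below assumes the normalized equations from the code in this post (E_t + H_x = 0, H_t + E_x = 0), a purely outgoing wave on each side of the source, and the pulse parameters mean = 0.5, std_dev = 0.05. Under those assumptions E(x, t) = g(t - |x|) and H(x, t) = sign(x) g(t - |x|), where g is the Gaussian pulse. The function names are illustrative and not part of DeepXDE.

```python
import numpy as np

def gaussian_pulse(t, mean=0.5, std_dev=0.05):
    """Gaussian source signal g(t) applied to E at x = 0."""
    return np.exp(-((t - mean) ** 2) / (2 * std_dev ** 2))

def reference_fields(x, t):
    """Outgoing-wave reference for E_t + H_x = 0, H_t + E_x = 0."""
    retarded = t - np.abs(x)       # time at which the signal left x = 0
    e_ref = gaussian_pulse(retarded)
    h_ref = np.sign(x) * e_ref     # H changes sign between the two directions
    return e_ref, h_ref

# Evaluate on a 128 x 128 grid; with an even number of x samples on [-1, 1]
# no grid point lies exactly on the source line x = 0.
x = np.linspace(-1.0, 1.0, 128)
t = np.linspace(0.0, 1.0, 128)
X, T = np.meshgrid(x, t)
E_ref, H_ref = reference_fields(X, T)
```

Under these assumptions the fields are negligible wherever t - |x| lies several std_dev below mean, and H jumps across x = 0 while E stays continuous. Comparing a prediction with `E_ref` and `H_ref` (for example through a relative L2 error) would show where in the x-t plane the network deviates.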

The code is given as follows,

import os
os.environ['DDE_BACKEND'] = 'pytorch'
import deepxde as dde
print(f"Current DeepXDE backend: {dde.backend.backend_name}")

import numpy as np

# For plotting
import matplotlib.pyplot as plt
from scipy.interpolate import griddata

# --- Domain and Constants ---

# Use a normalized domain
x_lower = -1.0
x_upper = 1.0
t_lower = 0.0
t_upper = 1

# Characteristic scales for normalization

L0 = 1.0 # Characteristic length (meters)
c0 = 299792458.0 # Speed of light in vacuum (m/s)
T0 = L0 / c0 # Characteristic time

print(f"Normalized domain: x in [{x_lower}, {x_upper}], t in [{t_lower}, {t_upper}]")
print(f"Characteristic length (L0): {L0} m")
print(f"Characteristic time (T0): {T0} s")

# Creation of the 2D domain for plotting and sampling
nx = 128
deltx = (x_upper - x_lower) / (nx - 1)
x = np.linspace(x_lower, x_upper, nx)
nt = 128
deltt = (t_upper - t_lower) / (nt - 1)
t = np.linspace(t_lower, t_upper, nt)
X_plot, T_plot = np.meshgrid(x, t)

# The manually sampled training points
X_star = np.hstack((X_plot.flatten()[:, None], T_plot.flatten()[:, None]))

# Space and time domains/geometry for the deepxde model
space_domain = dde.geometry.Interval(x_lower, x_upper)
time_domain = dde.geometry.TimeDomain(t_lower, t_upper)
geomtime = dde.geometry.GeometryXTime(space_domain, time_domain)

# --- The Physics-Informed Part of the Loss ---

def pde(x, y):
    """
    Defines the normalized 1D Maxwell's equations.

    The normalized equations are:
        du/dt + dv/dx = 0
        dv/dt + du/dx = 0

    INPUTS:
        x: x[:,0] is the non-dimensional x-coordinate
           x[:,1] is the non-dimensional t-coordinate
        y: Network output, in this case:
           y[:,0] is u(x,t), the E field
           y[:,1] is v(x,t), the H field
    OUTPUT:
        The pde in standard form i.e. something that must be zero.
    """
    u = y[:, 0:1]
    v = y[:, 1:2]

    # Compute first-order derivatives
    u_t = dde.grad.jacobian(y, x, i=0, j=1)
    v_t = dde.grad.jacobian(y, x, i=1, j=1)

    u_x = dde.grad.jacobian(y, x, i=0, j=0)
    v_x = dde.grad.jacobian(y, x, i=1, j=0)

    # Normalized PDE residual terms
    f_u = u_t + v_x  # Note the signs might depend on the specific field definitions
    f_v = v_t + u_x

    return [f_u, f_v]

# Spatial boundaries (x=-1 and x=1)
def boundary_l(x, on_boundary):
    return on_boundary and np.isclose(x[0], x_lower)

def boundary_r(x, on_boundary):
    return on_boundary and np.isclose(x[0], x_upper)

bc_u_l = dde.icbc.DirichletBC(geomtime, lambda x: 0, boundary_l, component=0)
bc_u_r = dde.icbc.DirichletBC(geomtime, lambda x: 0, boundary_r, component=0)
bc_v_l = dde.icbc.DirichletBC(geomtime, lambda x: 0, boundary_l, component=1)
bc_v_r = dde.icbc.DirichletBC(geomtime, lambda x: 0, boundary_r, component=1)

# --- Source excitation and Initial Conditions ---

# Sample the source line x = 0 at nt evenly spaced times.
# np.linspace includes both endpoints exactly, whereas np.arange with a
# floating-point step can return one sample more or less than intended.
time_values = np.linspace(t_lower, t_upper, nt)
first_column = np.zeros_like(time_values)
Xs = np.column_stack((first_column, time_values))

# Define the function that will generate E values based on t
def func(t):
    """
    Generates E values as a Gaussian pulse based on t.
    Adjust amplitude, mean, and standard deviation as needed.
    """
    amplitude = 1.0  # Maximum height of the pulse
    mean = 0.5  # Center of the pulse along the t-axis
    std_dev = 0.05  # Controls the width of the pulse (smaller value = narrower pulse)
    return amplitude * np.exp(-((t - mean) ** 2 / (2 * std_dev ** 2)))

# Generate the ys array by applying the function to the t-coordinates.
# Reshape to ensure it's a column vector (N, 1).
ys = func(Xs[:, 1:2]).reshape(-1, 1)

ic_s = dde.icbc.PointSetBC(Xs, ys, component=0)

# Initial conditions (t=0)
def init_cond_u(x):
    return 0.0 * x[:, 0:1]

def init_cond_v(x):
    return 0.0 * x[:, 0:1]

ic_u = dde.icbc.IC(geomtime, init_cond_u, lambda _, on_initial: on_initial, component=0)
ic_v = dde.icbc.IC(geomtime, init_cond_v, lambda _, on_initial: on_initial, component=1)

# Build the whole set of sampling points
X_input = np.vstack((X_star, Xs))

# Combine all conditions and the PDE
data = dde.data.TimePDE(
    geomtime,
    pde,
    [bc_u_l, bc_u_r, bc_v_l, bc_v_r, ic_u, ic_v, ic_s],
    num_domain=0,
    num_boundary=0,
    num_initial=0,
    anchors=X_input,
)

# --- Network Architecture and Training ---

net = dde.nn.FNN([2] + [20] * 4 + [2], "tanh", "Glorot uniform")

model = dde.Model(data, net)

# Compile with Adam optimizer first
# (also tried: loss_weights=[1e4, 1e4, 1, 1, 1e0, 1e0, 1e4, 1e4, 1e4])
model.compile("adam", lr=5e-3, loss="MSE")
losshistory, train_state = model.train(iterations=60000, display_every=1000)
model.compile("L-BFGS")
model.train()

# --- Make Prediction and Plot ---

prediction = model.predict(X_star, operator=None)

# X_star comes from np.meshgrid(x, t), whose arrays have shape (nt, nx),
# so the flattened predictions are reshaped back to (nt, nx).
u_pred = prediction[:, 0].reshape(nt, nx)
v_pred = prediction[:, 1].reshape(nt, nx)

# Convert back to physical units for visualization (optional)
x_phys = X_plot * 1.0
t_phys = T_plot * 1.0

# Plot predictions
fig, ax = plt.subplots(2, 1, figsize=(10, 8))

# Plot E (transpose so that rows run over x and columns over t)
im_u = ax[0].imshow(
    u_pred.T,
    interpolation="nearest",
    cmap="coolwarm",
    extent=[t_phys.min(), t_phys.max(), x_phys.min(), x_phys.max()],
    origin="lower",
    aspect="auto",
)
# contour = ax[0].contour(u_pred.T, levels=10, colors='black', linewidths=1.5)
ax[0].set_title("Normalized Solution (E)")
ax[0].set_xlabel("Time [s]")
ax[0].set_ylabel("Position [m]")
fig.colorbar(im_u, ax=ax[0])

# Plot H
im_v = ax[1].imshow(
    v_pred.T,
    interpolation="nearest",
    cmap="coolwarm",
    extent=[t_phys.min(), t_phys.max(), x_phys.min(), x_phys.max()],
    origin="lower",
    aspect="auto",
)
# contour = ax[1].contour(v_pred.T, levels=10, colors='black', linewidths=1.5)
ax[1].set_title("Normalized Solution (H)")
ax[1].set_xlabel("Time [s]")
ax[1].set_ylabel("Position [m]")
fig.colorbar(im_v, ax=ax[1])

plt.tight_layout()
plt.show()
