132 changes: 115 additions & 17 deletions deepxde/callbacks.py
@@ -568,38 +568,136 @@ def on_train_end(self):


class PDEPointResampler(Callback):
"""Resample the training points for PDE and/or BC losses every given period.
"""
Resample the training points for PDE and/or BC losses every given period.

Args:
period: How often to resample the training points (default is 100 iterations).
pde_points: If True, resample the training points for PDE losses (default is
True).
bc_points: If True, resample the training points for BC losses (default is
False; only supported by PyTorch and PaddlePaddle backend currently).
period (int): How often to resample the training points (default is 100 iterations).

name (str): Choose one of the following options:
- 'static': Resample using the same method that was used initially.
- 'RAR-G': Residual-based adaptive refinement with greed (only supported by PyTorch).
- 'RAD': Residual-based adaptive distribution (only supported by PyTorch).
- 'RAR-D': Residual-based adaptive refinement distribution (only supported by PyTorch).

sampler (dict): Specify the sampler to resample the training points (default is {'pde_points': True,
'bc_points': False,
'k': 1.0,
'c': 0.0,
'number_of_points': 100}).

Only for non-residual-based methods:

- 'pde_points' (bool): Whether to resample the PDE points (default is True).
- 'bc_points' (bool): Whether to resample the BC points (default is False; always False for residual-based methods).

Only for distribution-based adaptive refinement methods:

- 'k' (float): The exponent for the residuals (default is 1.0).
- 'c' (float): A constant determining the 'strength' of the PDF compared to randomness (default is 0.0).

Only for residual-based adaptive refinement methods:

- 'number_of_points' (int): The number of points sampled from the PDF and added to the original PDE training points.

save (bool): Whether to save the new training points in the model's data (default is False).

For more information on the residual-based adaptive refinement methods, see the paper: https://www.sciencedirect.com/science/article/pii/S0045782522006260
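
    A minimal usage sketch (illustrative values, assuming a ``dde.Model`` that has
    already been compiled with the pytorch backend)::

        resampler = dde.callbacks.PDEPointResampler(
            period=500, name='RAD', sampler_config={'k': 2.0, 'c': 0.5}
        )
        model.train(iterations=15000, callbacks=[resampler])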
"""

    def __init__(self, period: int = 100, name: str = 'static', sampler_config: dict = None):
        super().__init__()

        default_config = {'pde_points': True, 'bc_points': False, 'k': 1.0, 'c': 0.0, 'number_of_points': 100}
        default_config.update(sampler_config or {})

        self.period = period
        self.name = name
        self.num_bcs_initial = None
        self.epochs_since_last_resample = 0
        self.sampler_config = default_config

    def on_train_begin(self):
        self.num_bcs_initial = self.model.data.num_bcs

    def generate_dense_training_set(self, num_domain: int = None):
        """Generate a dense set of candidate training points inside the domain."""
        if self.model.data.train_distribution == "uniform":
            X = self.model.data.geom.uniform_points(num_domain, boundary=False)
        else:
            X = self.model.data.geom.random_points(
                num_domain, random=self.model.data.train_distribution
            )
        return X

    def generate_pdf(self, residual: np.ndarray = None, k: float = 1.0, c: float = 0.0):
        """Generate the probability density function (PDF) based on the PDE residuals."""
        eps = np.nan_to_num(residual, nan=0.0)
        eps_k = np.power(np.abs(eps), k)
        pdf = eps_k / np.sum(eps_k) + c
        if c != 0:
            # Renormalize so the weights sum to 1, as required by np.random.choice(..., p=pdf).
            pdf /= np.sum(pdf)
        return pdf

    def get_residual(self, X: np.ndarray = None):
        """Compute the PDE residual at the given input points."""
        if backend_name == 'pytorch':
            inputs = torch.as_tensor(X)
            inputs.requires_grad_()
            outputs = self.model.net(inputs)
            residual = self.model.data.pde(inputs, outputs)
            return residual.detach().cpu().numpy().flatten()
        else:
            raise ValueError("Only the pytorch backend supports residual-based resampling.")

    def on_epoch_end(self):
        self.epochs_since_last_resample += 1
        if self.epochs_since_last_resample < self.period:
            return
        self.epochs_since_last_resample = 0

        if self.name == 'static':
            self.model.data.resample_train_points(
                self.sampler_config['pde_points'], self.sampler_config['bc_points']
            )

            if not np.array_equal(self.num_bcs_initial, self.model.data.num_bcs):
                print("Initial value of self.num_bcs:", self.num_bcs_initial)
                print("self.model.data.num_bcs:", self.model.data.num_bcs)
                raise ValueError(
                    "`num_bcs` changed! Please update the loss function by `model.compile`."
                )

        elif self.name == 'RAR-G':
            if not np.array_equal(self.num_bcs_initial, self.model.data.num_bcs):
                print("Initial value of self.num_bcs:", self.num_bcs_initial)
                print("self.model.data.num_bcs:", self.model.data.num_bcs)
                raise ValueError(
                    "`num_bcs` changed! Please update the loss function by `model.compile`."
                )

            inputs = self.generate_dense_training_set(self.model.data.num_domain)
            residual = self.get_residual(inputs)

            # Greedily add the points with the largest residual magnitude.
            num_new = self.sampler_config['number_of_points']
            indices = np.argpartition(np.abs(residual), -num_new)[-num_new:]
            self.model.data.add_anchors(inputs[indices, :])

        elif self.name == 'RAD':
            inputs = self.generate_dense_training_set(2 * self.model.data.num_domain)
            residual = self.get_residual(inputs)
            pdf = self.generate_pdf(residual, self.sampler_config['k'], self.sampler_config['c'])

            # Replace all PDE training points with points drawn from the residual-based PDF.
            indices = np.random.choice(pdf.size, size=self.model.data.num_domain, p=pdf)
            self.model.data.replace_with_anchors(inputs[indices, :])

        elif self.name == 'RAR-D':
            inputs = self.generate_dense_training_set(self.model.data.num_domain)
            residual = self.get_residual(inputs)
            pdf = self.generate_pdf(residual, self.sampler_config['k'], self.sampler_config['c'])

            # Add points drawn from the residual-based PDF to the existing training points.
            indices = np.random.choice(pdf.size, size=self.sampler_config['number_of_points'], p=pdf)
            self.model.data.add_anchors(inputs[indices, :])

        else:
            raise ValueError("Unsupported sampling strategy.")
6 changes: 6 additions & 0 deletions examples/pinn_forward/Burgers.py
@@ -36,7 +36,13 @@ def pde(x, y):
model = dde.Model(data, net)

model.compile("adam", lr=1e-3)

model.train(iterations=15000)
# Uncomment one of the following lines (instead of the plain model.train call above)
# to train with a residual-based resampling strategy:
# model.train(iterations=15000, callbacks=[dde.callbacks.PDEPointResampler(period=500, name='RAR-D', sampler_config={'k': 2, 'c': 0.5})])
# model.train(iterations=15000, callbacks=[dde.callbacks.PDEPointResampler(period=500, name='RAD', sampler_config={'number_of_points': 1000, 'k': 2})])
# model.train(iterations=15000, callbacks=[dde.callbacks.PDEPointResampler(period=500, name='RAR-G', sampler_config={})])

model.compile("L-BFGS")
losshistory, train_state = model.train()
# """Backend supported: pytorch"""