diff --git a/deepxde/callbacks.py b/deepxde/callbacks.py index 4854187e8..97826aab0 100644 --- a/deepxde/callbacks.py +++ b/deepxde/callbacks.py @@ -568,38 +568,136 @@ def on_train_end(self): class PDEPointResampler(Callback): - """Resample the training points for PDE and/or BC losses every given period. + """ + Resample the training points for PDE and/or BC losses every given period. Args: - period: How often to resample the training points (default is 100 iterations). - pde_points: If True, resample the training points for PDE losses (default is - True). - bc_points: If True, resample the training points for BC losses (default is - False; only supported by PyTorch and PaddlePaddle backend currently). + period (int): How often to resample the training points (default is 100 iterations). + + name (str): Choose one of the following options: + - 'static': Resample using the same method that was used initially. + - 'RAR-G': Residual-based adaptive refinement with greed (only supported by PyTorch). + - 'RAD': Residual-based adaptive distribution (only supported by PyTorch). + - 'RAR-D': Residual-based adaptive refinement distribution (only supported by PyTorch). + + sampler (dict): Specify the sampler to resample the training points (default is {'pde_points': True, + 'bc_points': False, + 'k': 1.0, + 'c': 0.0, + 'number_of_points': 100}). + + Only for non-residual-based methods: + + - 'pde_points' (bool): Whether to resample the PDE points (default is True). + - 'bc_points' (bool): Whether to resample the BC points (default is False; always False for residual-based methods). + + Only for distribution-based adaptive refinement methods: + + - 'k' (float): The exponent for the residuals (default is 1.0). + - 'c' (float): A constant determining the 'strength' of the PDF compared to randomness (default is 0.0). + + Only for residual-based adaptive refinement methods: + + - 'number_of_points' (int): The number of points sampled from the PDF and added to the original PDE training points. + + save (bool): Whether to save the new training points in the model's data (default is False). + + For more information on the residual-based adaptive refinement methods, see the paper: https://www.sciencedirect.com/science/article/pii/S0045782522006260 """ - def __init__(self, period=100, pde_points=True, bc_points=False): + + def __init__(self, period:int=100, name:str='static', sampler_config:dict={}): super().__init__() + + default_config = {'pde_points': True, 'bc_points': False, 'k': 1.0, 'c': 0.0, 'number_of_points': 100} + default_config.update(sampler_config) + self.period = period - self.pde_points = pde_points - self.bc_points = bc_points - + self.name = name self.num_bcs_initial = None self.epochs_since_last_resample = 0 + self.sampler_config = default_config def on_train_begin(self): self.num_bcs_initial = self.model.data.num_bcs + + def generate_dense_training_set(self, num_domain:int=None): + """Generating a training set""" + X = np.empty((0, self.model.data.geom.dim), dtype=config.real(np)) + + if self.model.data.train_distribution == "uniform": + X = self.model.data.geom.uniform_points(num_domain, boundary=False) + else: + X = self.model.data.geom.random_points(num_domain, random=self.model.data.train_distribution) + + return X + + def generate_pdf(self, residual:np.array=None, k:float=1.0, c:float=0.0): + """Generating the probability density function (PDF) based on the PDE residuals.""" + + eps = np.nan_to_num(residual, nan=0.0) + eps_k = np.abs(np.pow(eps, k)) + pdf = (eps_k / np.sum(eps_k)) + c + if c != 0: pdf /= np.sum(pdf) # need to renormalize the PDF if c != 0 in order for np.random.choice(..., p=pdf) to work + + return pdf + + def get_residual(self, X:np.array=None): + """Calculating the residual of the PDE for the given inputs.""" + + if backend_name == 'pytorch': + inputs = torch.as_tensor(X) + inputs.requires_grad_() + outputs = self.model.net(inputs) + residual = self.model.data.pde(inputs, outputs) + return residual.detach().cpu().numpy().flatten() + else: + raise ValueError("Unsupported backend.") def on_epoch_end(self): self.epochs_since_last_resample += 1 if self.epochs_since_last_resample < self.period: return self.epochs_since_last_resample = 0 - self.model.data.resample_train_points(self.pde_points, self.bc_points) + + if self.name == 'static': + + self.model.data.resample_train_points(self.sampler_config['pde_points'], self.sampler_config['bc_points']) + + if not np.array_equal(self.num_bcs_initial, self.model.data.num_bcs): + print("Initial value of self.num_bcs:", self.num_bcs_initial) + print("self.model.data.num_bcs:", self.model.data.num_bcs) + raise ValueError( + "`num_bcs` changed! Please update the loss function by `model.compile`." + ) + + elif self.name == 'RAR-G': - if not np.array_equal(self.num_bcs_initial, self.model.data.num_bcs): - print("Initial value of self.num_bcs:", self.num_bcs_initial) - print("self.model.data.num_bcs:", self.model.data.num_bcs) - raise ValueError( - "`num_bcs` changed! Please update the loss function by `model.compile`." - ) + inputs = self.generate_dense_training_set(self.model.data.num_domain) + residual = self.get_residual(inputs) + + indices = np.argpartition(residual,-self.sampler_config['number_of_points'])[-self.sampler_config['number_of_points']:] + self.model.data.add_anchors(inputs[indices, :]) + + elif self.name == 'RAD': + + inputs = self.generate_dense_training_set(2*self.model.data.num_domain) + residual = self.get_residual(inputs) + + pdf = self.generate_pdf(residual, self.sampler_config['k'], self.sampler_config['c']) + + indices = np.random.choice(pdf.size, size=self.model.data.num_domain, p=pdf) + self.model.data.replace_with_anchors(inputs[indices, :]) + + elif self.name == 'RAR-D': + + inputs = self.generate_dense_training_set(self.model.data.num_domain) + residual = self.get_residual(inputs) + + pdf = self.generate_pdf(residual, self.sampler_config['k'], self.sampler_config['c']) + + indices = np.random.choice(pdf.size, size=self.sampler_config['number_of_points'], p=pdf) + self.model.data.add_anchors(inputs[indices, :]) + + else: + raise ValueError("Unsupported sampling strategy.") \ No newline at end of file diff --git a/examples/pinn_forward/Burgers.py b/examples/pinn_forward/Burgers.py index 7b6883ed1..9f0086c16 100644 --- a/examples/pinn_forward/Burgers.py +++ b/examples/pinn_forward/Burgers.py @@ -36,7 +36,13 @@ def pde(x, y): model = dde.Model(data, net) model.compile("adam", lr=1e-3) + model.train(iterations=15000) +# Uncomment the following lines to use different resampling strategies +#model.train(iterations=15000, callbacks=[dde.callbacks.PDEPointResampler(period=500, name='RAR-D', sampler_config={'k':2, 'c':0.5})]) +#model.train(iterations=15000, callbacks=[dde.callbacks.PDEPointResampler(period=500, name='RAD', sampler_config={'number_of_points': 1000, 'k':2})]) +#model.train(iterations=15000, callbacks=[dde.callbacks.PDEPointResampler(period=500, name='RAR-G', sampler_config={})]) + model.compile("L-BFGS") losshistory, train_state = model.train() # """Backend supported: pytorch"""