35 commits
55663e3  Add JAX backend support for VariableValue callback (bonneted, Mar 27, 2024)
3b141cc  fix copy paste (bonneted, Mar 27, 2024)
ab626f7  Merge branch 'lululxvi:master' into jax-variableValue (bonneted, Apr 16, 2024)
cb562bc  create array wrapper for variable value (bonneted, Apr 17, 2024)
c51eac6  restore blank line (bonneted, Apr 17, 2024)
ccbb394  refactor SPINN with list handler (bonneted, Apr 18, 2024)
e49df54  Add SPINN module to jax/__init__.py (bonneted, Apr 18, 2024)
c177d86  Add concat and stack functions to jax/tensor.py (bonneted, Apr 18, 2024)
c639977  create full class for spinn (bonneted, Apr 18, 2024)
e25523a  fix dimension handling (bonneted, Apr 18, 2024)
7a16482  Refactor input and output transform handling in NN class (bonneted, Apr 18, 2024)
51c78de  handle list input for testing (bonneted, Apr 18, 2024)
d27dcf8  refactor decorator (bonneted, Apr 18, 2024)
1906448  implement helmot spinn (bonneted, Apr 18, 2024)
2646d35  Add support for JAX backend in Helmoltz_Dirichlet_2d.py (bonneted, Apr 19, 2024)
f86641a  Refactor input and output transform handling (bonneted, Apr 19, 2024)
59a4cc7  reshape instead of squezze to keep one dim (bonneted, Apr 22, 2024)
a48ba25  dimension handling (bonneted, Apr 22, 2024)
d62f604  add spinn sampling for hypercube (bonneted, Apr 30, 2024)
149d7b9  refactor variable wrapper (bonneted, Apr 30, 2024)
8340f25  remove empty line (bonneted, May 2, 2024)
3d64359  remove empty line (bonneted, May 3, 2024)
63de873  fix line (bonneted, May 3, 2024)
de5c8d9  L-BFGS not implemented in jax (bonneted, May 6, 2024)
68ab04f  Merge branch 'jax-variableValue' into SPINN-only (bonneted, May 6, 2024)
32b760f  fix number of sampling points (bonneted, May 23, 2024)
facfcbf  scale factor for trainable variables (bonneted, May 23, 2024)
d6f29a9  uniform spinn sampling (bonneted, Jun 14, 2024)
b74eb4d  undo scale factor (bonneted, Jun 17, 2024)
9871395  remove save model (bonneted, Jun 17, 2024)
d9ffe22  Merge branch 'master' into SPINN-only (bonneted, Jun 17, 2024)
98d414c  update Helmotz example (bonneted, Jun 17, 2024)
37110d1  snn in a different file (bonneted, Jun 17, 2024)
719a77b  spinn credit (bonneted, Jun 17, 2024)
8bc071d  fix spinn sampling (bonneted, Oct 22, 2024)
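
For context on what these commits implement: a separable PINN (SPINN) evaluates one subnetwork per coordinate axis on small sets of 1D points and combines the per-axis features by outer product, so an nx x ny grid costs nx + ny network evaluations instead of nx * ny. The sketch below shows only that combination step, with assumed shapes and a hypothetical helper name; the PR's real forward pass lives in snn.py, which is not part of this diff.

```python
import jax.numpy as jnp

def separable_eval(gx, hy):
    """Combine per-axis features into predictions on the full grid.

    gx: (nx, r) features from the x-axis subnetwork on nx 1D points
    hy: (ny, r) features from the y-axis subnetwork on ny 1D points
    Returns (nx * ny,) predictions, one per point of the induced 2D grid.
    """
    u = jnp.einsum("xr,yr->xy", gx, hy)  # sum over the rank-r feature axis
    return u.reshape(-1)

gx = jnp.ones((10, 8))
hy = jnp.ones((20, 8))
print(separable_eval(gx, hy).shape)  # (200,) -> values on the 10 x 20 grid
```

This is also why the data pipeline below passes lists of 1D arrays instead of a single (n, dim) array.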
100 changes: 71 additions & 29 deletions deepxde/data/pde.py
@@ -4,7 +4,7 @@
 from .. import backend as bkd
 from .. import config
 from ..backend import backend_name
-from ..utils import get_num_args, run_if_all_none, mpi_scatter_from_rank0
+from ..utils import get_num_args, run_if_all_none, mpi_scatter_from_rank0, list_handler


 class PDE(Data):
@@ -84,6 +84,7 @@ def __init__(
         solution=None,
         num_test=None,
         auxiliary_var_function=None,
+        is_SPINN=False,
     ):
         self.geom = geometry
         self.pde = pde
@@ -108,10 +109,18 @@ def __init__(
         self.anchors = None if anchors is None else anchors.astype(config.real(np))
         self.exclusions = exclusions

-        self.soln = solution
+        if solution is not None:
+            @list_handler
+            def solution_handling_list(inputs):
+                return solution(inputs)
+            self.soln = solution_handling_list
+        else:
+            self.soln = solution

         self.num_test = num_test

         self.auxiliary_var_fn = auxiliary_var_function
+        self.is_SPINN = is_SPINN

         # TODO: train_x_all is used for PDE losses. It is better to add train_x_pde
         # explicitly.
@@ -128,29 +137,41 @@ def __init__(
         self.test()

     def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
+
+        bcs_start = np.cumsum([0] + self.num_bcs)
+        bcs_start = list(map(int, bcs_start))
+
+        if self.is_SPINN:
+            num_bcs_output = [num_bc**2 for num_bc in self.num_bcs]
+            bcs_start_output = np.cumsum([0] + num_bcs_output)
+            bcs_start_output = list(map(int, bcs_start_output))
+        else:
+            bcs_start_output = bcs_start
+
         if backend_name in ["tensorflow.compat.v1", "paddle"]:
-            outputs_pde = outputs
+            outputs_pde = outputs[bcs_start_output[-1] :]
         elif backend_name in ["tensorflow", "pytorch"]:
             if config.autodiff == "reverse":
-                outputs_pde = outputs
+                outputs_pde = outputs[bcs_start_output[-1] :]
             elif config.autodiff == "forward":
                 # forward-mode AD requires functions
-                outputs_pde = (outputs, aux[0])
+                outputs_pde = (outputs[bcs_start_output[-1] :], aux[0])
         elif backend_name == "jax":
             # JAX requires pure functions
-            outputs_pde = (outputs, aux[0])
+            outputs_pde = (outputs[bcs_start_output[-1] :], aux[0])

+        inputs_pde = inputs[-1] if isinstance(inputs, (list, tuple)) else inputs[bcs_start[-1] :]
         f = []
         if self.pde is not None:
             if get_num_args(self.pde) == 2:
-                f = self.pde(inputs, outputs_pde)
+                f = self.pde(inputs_pde, outputs_pde)
             elif get_num_args(self.pde) == 3:
                 if self.auxiliary_var_fn is None:
                     if aux is None or len(aux) == 1:
                         raise ValueError("Auxiliary variable function not defined.")
-                    f = self.pde(inputs, outputs_pde, unknowns=aux[1])
+                    f = self.pde(inputs_pde, outputs_pde, unknowns=aux[1])
                 else:
-                    f = self.pde(inputs, outputs_pde, model.net.auxiliary_vars)
+                    f = self.pde(inputs_pde, outputs_pde, model.net.auxiliary_vars)
         if not isinstance(f, (list, tuple)):
             f = [f]
@@ -163,16 +184,19 @@ def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
                 )
             )

-        bcs_start = np.cumsum([0] + self.num_bcs)
-        bcs_start = list(map(int, bcs_start))
-        error_f = [fi[bcs_start[-1] :] for fi in f]

+        error_f = f
         losses = [
             loss_fn[i](bkd.zeros_like(error), error) for i, error in enumerate(error_f)
         ]
         for i, bc in enumerate(self.bcs):
-            beg, end = bcs_start[i], bcs_start[i + 1]
-            # The same BC points are used for training and testing.
-            error = bc.error(self.train_x, inputs, outputs, beg, end)
+            if isinstance(inputs, (list, tuple)):
+                beg, end = bcs_start_output[i], bcs_start_output[i + 1]
+                error = bc.error(self.train_x, inputs[i], outputs[beg:end,:], 0, end-beg)
+            else:
+                beg, end = bcs_start_output[i], bcs_start_output[i + 1]
+                # The same BC points are used for training and testing.
+                error = bc.error(self.train_x, inputs, outputs, beg, end)
             losses.append(loss_fn[len(error_f) + i](bkd.zeros_like(error), error))
         return losses

@@ -194,7 +218,10 @@ def train_next_batch(self, batch_size=None):
         if config.parallel_scaling == "strong":
             self.train_x_all = mpi_scatter_from_rank0(self.train_x_all)
         if self.pde is not None:
-            self.train_x = np.vstack((self.train_x, self.train_x_all))
+            if self.is_SPINN and len(self.train_x) > 0:
+                self.train_x = self.train_x + [self.train_x_all]
+            else:
+                self.train_x = np.vstack((self.train_x, self.train_x_all))
         self.train_y = self.soln(self.train_x) if self.soln else None
         if self.auxiliary_var_fn is not None:
             self.train_aux_vars = self.auxiliary_var_fn(self.train_x).astype(
@@ -247,7 +274,11 @@ def replace_with_anchors(self, anchors):
         self.train_x_all = self.anchors
         self.train_x = self.bc_points()
         if self.pde is not None:
-            self.train_x = np.vstack((self.train_x, self.train_x_all))
+            if self.is_SPINN and len(self.train_x) > 0:
+                self.train_x = self.train_x + [self.train_x_all]
+            else:
+                self.train_x = np.vstack((self.train_x, self.train_x_all))
+
         self.train_y = self.soln(self.train_x) if self.soln else None
         if self.auxiliary_var_fn is not None:
             self.train_aux_vars = self.auxiliary_var_fn(self.train_x).astype(
@@ -258,12 +289,15 @@ def replace_with_anchors(self, anchors):
     def train_points(self):
         X = np.empty((0, self.geom.dim), dtype=config.real(np))
         if self.num_domain > 0:
-            if self.train_distribution == "uniform":
-                X = self.geom.uniform_points(self.num_domain, boundary=False)
+            if self.is_SPINN:
+                X = self.geom.uniform_spinn_points(self.num_test, boundary=False)
             else:
-                X = self.geom.random_points(
-                    self.num_domain, random=self.train_distribution
-                )
+                if self.train_distribution == "uniform":
+                    X = self.geom.uniform_points(self.num_domain, boundary=False)
+                else:
+                    X = self.geom.random_points(
+                        self.num_domain, random=self.train_distribution
+                    )
         if self.num_boundary > 0:
             if self.train_distribution == "uniform":
                 tmp = self.geom.uniform_boundary_points(self.num_boundary)
@@ -287,17 +321,25 @@ def is_not_excluded(x):
     def bc_points(self):
         x_bcs = [bc.collocation_points(self.train_x_all) for bc in self.bcs]
         self.num_bcs = list(map(len, x_bcs))
-        self.train_x_bc = (
-            np.vstack(x_bcs)
-            if x_bcs
-            else np.empty([0, self.train_x_all.shape[-1]], dtype=config.real(np))
-        )
+        if self.is_SPINN:
+            self.train_x_bc = x_bcs if x_bcs else np.empty([0, self.train_x_all.shape[-1]], dtype=config.real(np))
+        else:
+            self.train_x_bc = (
+                np.vstack(x_bcs)
+                if x_bcs
+                else np.empty([0, self.train_x_all.shape[-1]], dtype=config.real(np))
+            )
         return self.train_x_bc

     def test_points(self):
         # TODO: Use different BC points from self.train_x_bc
-        x = self.geom.uniform_points(self.num_test, boundary=False)
-        x = np.vstack((self.train_x_bc, x))
+        if self.is_SPINN:
+            x = self.geom.uniform_spinn_points(self.num_test, boundary=False)
+            if len(self.train_x_bc) > 0:
+                x = self.train_x_bc + [x]
+        else:
+            x = self.geom.uniform_points(self.num_test, boundary=False)
+            x = np.vstack((self.train_x_bc, x))
         return x


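A note on the offset bookkeeping in `losses` above: with separable inputs, each boundary condition is sampled with `num_bc` one-dimensional points per axis, but the flattened network output holds `num_bc**2` rows for it (the 2D grid those points induce), hence the separate `bcs_start_output`. A minimal sketch of the arithmetic, with made-up point counts:

```python
import numpy as np

num_bcs = [10, 20]                        # 1D points per BC (per axis), hypothetical
num_bcs_output = [n**2 for n in num_bcs]  # rows each BC occupies in the 2D output
bcs_start = np.cumsum([0] + num_bcs).tolist()
bcs_start_output = np.cumsum([0] + num_bcs_output).tolist()

print(bcs_start)         # [0, 10, 30]   -> offsets into the lists of 1D inputs
print(bcs_start_output)  # [0, 100, 500] -> offsets into the stacked outputs
# PDE residual rows therefore start at outputs[bcs_start_output[-1]:], row 500.
```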
29 changes: 29 additions & 0 deletions deepxde/geometry/geometry_nd.py
@@ -80,6 +80,35 @@ def uniform_points(self, n, boundary=True):
"Warning: {} points required, but {} points sampled.".format(n, len(x))
)
return x

def uniform_spinn_points(self, n, boundary=True):
dx = (self.volume / n) ** (1 / self.dim)
xi = []
for i in range(self.dim):
ni = int(np.ceil(self.side_length[i] / dx))
if boundary:
xi.append(
np.linspace(
self.xmin[i], self.xmax[i], num=ni, dtype=config.real(np)
)
)
else:
xi.append(
np.linspace(
self.xmin[i],
self.xmax[i],
num=ni + 1,
endpoint=False,
dtype=config.real(np),
)[1:]
)
x = np.array(xi).T
if n != np.prod([len(x) for x in xi]):
print(
"Warning: {} points required, but {} points sampled.".format(n, np.prod([len(x) for x in xi]))
)
return x


def random_points(self, n, random="pseudo"):
x = sample(n, self.dim, random)
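To see what `uniform_spinn_points` produces: instead of a dense (n, dim) array, it builds one 1D grid per axis, with the spacing chosen so the Cartesian product has roughly n points. A standalone illustration for a unit square (assumed geometry attributes, mirroring the method above):

```python
import numpy as np

n, dim = 100, 2
xmin, xmax = np.array([0.0, 0.0]), np.array([1.0, 1.0])
side_length = xmax - xmin
volume = float(np.prod(side_length))

dx = (volume / n) ** (1 / dim)  # target spacing per axis
xi = [
    np.linspace(xmin[i], xmax[i], num=int(np.ceil(side_length[i] / dx)))
    for i in range(dim)
]
print([len(a) for a in xi])  # [10, 10] -> a 10 x 10 = 100-point induced grid
# A SPINN consumes the per-axis 1D arrays directly; the dense grid is their
# Cartesian product and is never materialized as an (n, dim) array.
```

Note that the method warns rather than fails when per-axis rounding makes the induced grid size differ from the requested n.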
12 changes: 9 additions & 3 deletions deepxde/model.py
@@ -374,7 +374,12 @@ def _compile_jax(self, lr, loss_fn, decay):
         # Initialize the network's parameters
         if self.params is None:
             key = jax.random.PRNGKey(config.jax_random_seed)
-            self.net.params = self.net.init(key, self.data.test()[0])
+            X_test = (
+                self.data.test()[0][-1]
+                if isinstance(self.data.test()[0], (list, tuple))
+                else self.data.test()[0]
+            )
+            self.net.params = self.net.init(key, X_test)
         external_trainable_variables_arr = [
             var.value for var in self.external_trainable_variables
         ]
@@ -384,6 +389,7 @@
         self.opt_state = self.opt.init(self.params)

         @jax.jit
+        @utils.list_handler
         def outputs(params, training, inputs):
             return self.net.apply(params, inputs, training=training)

@@ -392,9 +398,9 @@ def outputs_losses(params, training, inputs, targets, losses_fn):

             # TODO: Add auxiliary vars
             def outputs_fn(inputs):
-                return self.net.apply(nn_params, inputs, training=training)
+                return outputs(nn_params, training, inputs)

-            outputs_ = self.net.apply(nn_params, inputs, training=training)
+            outputs_ = outputs(nn_params, training, inputs)
             # Data losses
             # We use aux so that self.data.losses is a pure function.
             aux = [outputs_fn, ext_params] if ext_params else [outputs_fn]
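`utils.list_handler` is used above but its body is not part of this diff. A plausible minimal version, written only to illustrate the calling convention suggested by how `outputs` is sliced in `data/pde.py` (an assumption, not the PR's actual implementation): when the last argument is a list of point groups, apply the wrapped function to each group and stack the resulting rows, so downstream code can keep slicing one flat output array.

```python
import functools
import jax.numpy as jnp

def list_handler(func):
    # Assumed sketch: the real deepxde.utils.list_handler may differ.
    @functools.wraps(func)
    def wrapper(*args):
        *head, inputs = args
        if isinstance(inputs, (list, tuple)):
            # One call per input group; rows are concatenated in group order,
            # matching the bcs_start_output offsets computed in data/pde.py.
            return jnp.concatenate([func(*head, x) for x in inputs], axis=0)
        return func(*head, inputs)
    return wrapper
```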
3 changes: 2 additions & 1 deletion deepxde/nn/jax/__init__.py
@@ -1,6 +1,7 @@
"""Package for jax NN modules."""

__all__ = ["FNN", "NN", "PFNN"]
__all__ = ["FNN", "NN", "PFNN", "SPINN"]

from .snn import SPINN
from .fnn import FNN, PFNN
from .nn import NN
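
A quick way to check that the new export is wired up (assumes this branch is installed and the jax backend is selected, e.g. via DDE_BACKEND=jax):

```python
from deepxde.nn.jax import SPINN

print(SPINN.__module__)  # deepxde.nn.jax.snn
```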