
Commit 8592029

Author: Kuangdai Leng
Add ZCS (#1629)
1 parent 268ceca commit 8592029

6 files changed: +468 −0 lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -20,3 +20,6 @@ docs/_build/
 
 # setuptools_scm
 deepxde/_version.py
+
+# PyCharm
+.idea/

deepxde/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@
     "utils",
     "Model",
     "Variable",
+    "zcs",
 ]
 
 try:
@@ -26,6 +27,7 @@
 from . import icbc
 from . import nn
 from . import utils
+from . import zcs
 
 from .backend import Variable
 from .model import Model

deepxde/zcs/__init__.py

Lines changed: 14 additions & 0 deletions
"""Enhancing the performance of DeepONets using Zero Coordinate Shift.

Reference: https://arxiv.org/abs/2311.00860
"""

__all__ = [
    "LazyGrad",
    "Model",
    "PDEOperatorCartesianProd",
]

from .gradient import LazyGrad
from .model import Model
from .operator import PDEOperatorCartesianProd
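
Model and PDEOperatorCartesianProd are re-exported here from zcs/model.py and zcs/operator.py, which are not part of the files shown in this excerpt, so their exact interfaces are not visible. The following is a hedged sketch of how the new subpackage might slot into a physics-informed DeepONet script as a drop-in replacement, assuming (a) the ZCS classes mirror the constructor signatures of the standard dde.data.PDEOperatorCartesianProd and dde.Model, and (b) the ZCS operator passes its ZCS parameters as the first argument of the user PDE, matching what LazyGrad (defined in gradient.py below) expects; the diffusion-reaction problem itself is purely illustrative.

import numpy as np
import deepxde as dde  # e.g. run with DDE_BACKEND=pytorch


def pde(x, u, v):
    # Assumption: under ZCS, `x` is the dict of ZCS parameters consumed by
    # LazyGrad; `u` is the DeepONet output and `v` the sampled source term.
    D, k = 0.01, 0.01
    grad_u = dde.zcs.LazyGrad(x, u)
    du_t = grad_u.compute((0, 1))
    du_xx = grad_u.compute((2, 0))
    return du_t - D * du_xx + k * u**2 - v


geom = dde.geometry.Interval(0, 1)
timedomain = dde.geometry.TimeDomain(0, 1)
geomtime = dde.geometry.GeometryXTime(geom, timedomain)
bc = dde.icbc.DirichletBC(
    geomtime, lambda x: np.zeros((len(x), 1)), lambda _, on_boundary: on_boundary
)
ic = dde.icbc.IC(
    geomtime, lambda x: np.zeros((len(x), 1)), lambda _, on_initial: on_initial
)
pde_problem = dde.data.TimePDE(
    geomtime, pde, [bc, ic], num_domain=200, num_boundary=40, num_initial=20
)

func_space = dde.data.GRF(length_scale=0.2)
eval_pts = np.linspace(0, 1, 50)[:, None]

# ZCS drop-ins for dde.data.PDEOperatorCartesianProd and dde.Model
# (constructor signatures assumed identical to the standard classes).
data = dde.zcs.PDEOperatorCartesianProd(
    pde_problem, func_space, eval_pts, num_function=100, function_variables=[0]
)
net = dde.nn.DeepONetCartesianProd(
    [50, 128, 128], [2, 128, 128], "tanh", "Glorot normal"
)
model = dde.zcs.Model(data, net)
model.compile("adam", lr=1e-3)
model.train(iterations=10000)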

deepxde/zcs/gradient.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
"""Gradients for ZCS"""
2+
3+
from typing import Tuple
4+
5+
import numpy as np
6+
7+
from ..backend import backend_name, tf, torch, paddle # noqa
8+
9+
10+
class LazyGrad:
11+
"""Gradients for ZCS with lazy evaluation."""
12+
13+
def __init__(self, zcs_parameters, u):
14+
self.zcs_parameters = zcs_parameters
15+
self.n_dims = len(zcs_parameters["leaves"])
16+
17+
# create tensor $a_{ij}$
18+
if backend_name == "tensorflow":
19+
self.a = tf.Variable(tf.ones_like(u), trainable=True)
20+
elif backend_name == "pytorch":
21+
self.a = torch.ones_like(u).requires_grad_()
22+
elif backend_name == "paddle":
23+
self.a = paddle.ones_like(u) # noqa
24+
self.a.stop_gradient = False
25+
else:
26+
raise NotImplementedError(
27+
f"ZCS is not implemented for backend {backend_name}"
28+
)
29+
30+
# omega
31+
if backend_name == "tensorflow":
32+
self.a_tape = tf.GradientTape(
33+
persistent=True, watch_accessed_variables=False
34+
)
35+
with self.a_tape: # z_tape is already watching
36+
self.a_tape.watch(self.a)
37+
omega = tf.math.reduce_sum(u * self.a)
38+
else:
39+
omega = (u * self.a).sum()
40+
41+
# cached lower-order derivatives of omega
42+
self.cached_omega_grads = {
43+
# the only initial element is omega itself, with all orders being zero
44+
(0,)
45+
* self.n_dims: omega
46+
}
47+
48+
def grad_wrt_z(self, y, z):
49+
if backend_name == "tensorflow":
50+
with self.a_tape: # z_tape is already watching
51+
return self.zcs_parameters["tape"].gradient(y, z)
52+
if backend_name == "pytorch":
53+
return torch.autograd.grad(y, z, create_graph=True)[0]
54+
if backend_name == "paddle":
55+
return paddle.grad(y, z, create_graph=True)[0] # noqa
56+
raise NotImplementedError(
57+
f"ZCS is not implemented for backend {backend_name}"
58+
)
59+
60+
def grad_wrt_a(self, y):
61+
if backend_name == "tensorflow":
62+
# no need to watch here because we don't need higher-orders w.r.t. a
63+
return self.a_tape.gradient(y, self.a)
64+
if backend_name == "pytorch":
65+
return torch.autograd.grad(y, self.a, create_graph=True)[0]
66+
if backend_name == "paddle":
67+
return paddle.grad(y, self.a, create_graph=True)[0] # noqa
68+
raise NotImplementedError(
69+
f"ZCS is not implemented for backend {backend_name}"
70+
)
71+
72+
def compute(self, required_orders: Tuple[int, ...]):
73+
if required_orders in self.cached_omega_grads.keys():
74+
# derivative w.r.t. a
75+
return self.grad_wrt_a(self.cached_omega_grads[required_orders])
76+
77+
# find the start
78+
orders = np.array(required_orders)
79+
exists = np.array(list(self.cached_omega_grads.keys()))
80+
diffs = orders[None, :] - exists
81+
# existing orders no greater than target element-wise
82+
avail_indices = np.where(diffs.min(axis=1) >= 0)[0]
83+
# start from the closet
84+
start_index = np.argmin(diffs[avail_indices].sum(axis=1))
85+
start_orders = exists[avail_indices][start_index]
86+
87+
# dim loop
88+
for i, zi in enumerate(self.zcs_parameters["leaves"]):
89+
# order loop
90+
while start_orders[i] != required_orders[i]:
91+
omega_grad = self.grad_wrt_z(
92+
self.cached_omega_grads[tuple(start_orders)], zi
93+
)
94+
start_orders[i] += 1
95+
self.cached_omega_grads[tuple(start_orders)] = omega_grad
96+
97+
# derivative w.r.t. a
98+
return self.grad_wrt_a(self.cached_omega_grads[required_orders])
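
As a standalone illustration of the class above: with the PyTorch backend the "tape" entry of zcs_parameters is never read, so a dict containing only "leaves" is enough to exercise LazyGrad directly. The sketch below is a minimal, hedged example; the toy field u and the way the ZCS scalars are attached to the coordinates are illustrative and not part of this commit, and it assumes deepxde is imported with DDE_BACKEND=pytorch.

import torch
import deepxde as dde  # run with DDE_BACKEND=pytorch

# One scalar ZCS "leaf" per coordinate; adding the scalar shift to each
# coordinate lets derivatives w.r.t. the leaf stand in for derivatives
# w.r.t. that coordinate.
z_x = torch.tensor(0.0, requires_grad=True)
z_t = torch.tensor(0.0, requires_grad=True)
zcs_parameters = {"leaves": [z_x, z_t]}

# Toy stand-in for a network output u(x, t) on an 8 x 8 grid.
x = torch.linspace(0, 1, 8) + z_x
t = torch.linspace(0, 1, 8) + z_t
u = torch.sin(x)[None, :] * torch.exp(-t)[:, None]

grad_u = dde.zcs.LazyGrad(zcs_parameters, u)
u_xx = grad_u.compute((2, 0))  # d^2 u / dx^2, same shape as u
u_t = grad_u.compute((0, 1))   # du/dt, built lazily from the cached omega
print(u_xx.shape, u_t.shape)   # torch.Size([8, 8]) torch.Size([8, 8])

The second compute call reuses the cached zeroth-order entry rather than rebuilding omega; higher mixed orders would likewise restart from the closest cached derivative.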
