Commit 2e2ec24

Update AC-inference with comments
1 parent 328a78e commit 2e2ec24

File tree

1 file changed: +18 -17 lines changed


examples/AC-inference.py

Lines changed: 18 additions & 17 deletions
@@ -5,6 +5,14 @@
 from tensordiffeq.models import DiscoveryModel
 from tensordiffeq.utils import tensor
 
+#####################
+## Discovery Model ##
+#####################
+
+
+# Put params into a list
+params = [tf.Variable(0.0, dtype=tf.float32), tf.Variable(0.0, dtype=tf.float32)]
+
 def f_model(u_model, vars, x, t):
     u = u_model(tf.concat([x,t],1))
     u_x = tf.gradients(u, x)
@@ -15,11 +23,6 @@ def f_model(u_model, vars, x, t):
     f_u = u_t - c1*u_xx + c2*u*u*u - c2*u
     return f_u
 
-
-
-lb = np.array([-1.0])
-ub = np.array([1.0])
-
 # Import data, same data as Raissi et al
 
 data = scipy.io.loadmat('AC.mat')
@@ -29,33 +32,31 @@ def f_model(u_model, vars, x, t):
 Exact = data['uu']
 Exact_u = np.real(Exact)
 
-
+# define MLP depth and layer width
 layer_sizes = [2, 128, 128, 128, 128, 1]
-model = DiscoveryModel()
 
+# generate all combinations of x and t
 X, T = np.meshgrid(x,t)
 
 X_star = np.hstack((X.flatten()[:,None], T.flatten()[:,None]))
 u_star = Exact_u.T.flatten()[:,None]
-N = X_star.shape[0]
-T = t.shape[0]
 
 x = X_star[:, 0:1]
 t = X_star[:, 1:2]
 
-X_star = tensor(X_star)
-
-
+# append to a list for input to model.fit
 X = [x, t]
-print(np.shape(x), np.shape(t), np.shape(X_star))
-
-vars = [tf.Variable(0.0, dtype = tf.float32), tf.Variable(0.0, dtype = tf.float32)]
 
+#define col_weights for SA discovery model
 col_weights = tf.Variable(tf.random.uniform([np.shape(x)[0], 1]))
 
-model.compile(layer_sizes, f_model, X, u_star, vars)
+# initialize, compile, train model
+model = DiscoveryModel()
+model.compile(layer_sizes, f_model, X, u_star, vars, col_weights=col_weights) # baseline approach can be done by simply removing the col_weights arg
+model.tf_optimizer_weights = tf.keras.optimizers.Adam(lr=0.005, beta_1=.95) # an example as to how one could modify an optimizer, in this case the col_weights optimizer
 
-#train loop
+# train loop
 model.fit(tf_iter = 10000)
 
+# doesnt work quite yet
 tdq.plotting.plot_weights(model, scale = 10.0)
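
For context, the residual defined in f_model above is the Allen-Cahn equation, f_u = u_t - c1*u_xx + c2*u*u*u - c2*u, where c1 and c2 are presumably drawn from the two trainable tf.Variables that DiscoveryModel fits against the AC.mat data (that assignment sits in the lines elided from this diff). Passing col_weights to compile switches on the self-adaptive (SA) variant, in which every collocation point also carries its own trainable weight. The snippet below is a minimal, self-contained sketch of that idea in plain TensorFlow, not tensordiffeq's actual implementation; the toy network, the 256 points, and the optimizer settings are illustrative assumptions.

# Minimal sketch of self-adaptive residual weighting (assumed setup, not tensordiffeq internals)
import tensorflow as tf

net = tf.keras.Sequential([tf.keras.layers.Dense(20, activation="tanh"),
                           tf.keras.layers.Dense(1)])            # toy stand-in for u_model
xt = tf.random.uniform([256, 2])                                 # toy collocation points (x, t)
col_weights = tf.Variable(tf.random.uniform([256, 1]))           # one trainable weight per point

opt_net = tf.keras.optimizers.Adam(learning_rate=0.005, beta_1=0.95)  # network optimizer
opt_w = tf.keras.optimizers.Adam(learning_rate=0.005, beta_1=0.95)    # weights optimizer (cf. tf_optimizer_weights)

for step in range(100):
    with tf.GradientTape(persistent=True) as tape:
        residual = net(xt)                                       # placeholder for the PDE residual f_u
        loss = tf.reduce_mean(tf.square(col_weights * residual)) # per-point weights scale the residual loss
    grads_net = tape.gradient(loss, net.trainable_variables)
    grad_w = tape.gradient(loss, col_weights)
    del tape
    opt_net.apply_gradients(zip(grads_net, net.trainable_variables))  # descend on the network
    opt_w.apply_gradients([(-grad_w, col_weights)])                   # ascend on the weights, so hard points get emphasized

In the example itself, model.fit(tf_iter = 10000) handles training internally; per the comment in the diff, model.tf_optimizer_weights is the optimizer tensordiffeq applies to the col_weights update, so reassigning it only swaps that optimizer.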
