Skip to content

Commit f31fc3d

Browse files
committed
fix discovery example, make len more stable in DiscoveryModel
1 parent 79b37e4 commit f31fc3d

File tree

5 files changed

+10
-7
lines changed

5 files changed

+10
-7
lines changed

examples/AC-inference.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def f_model(u_model, vars, x, t):
4646
x = X_star[:, 0:1]
4747
t = X_star[:, 1:2]
4848

49+
print(np.shape(x))
4950
# append to a list for input to model.fit
5051
X = [x, t]
5152

@@ -54,7 +55,7 @@ def f_model(u_model, vars, x, t):
5455

5556
# initialize, compile, train model
5657
model = DiscoveryModel()
57-
model.compile(layer_sizes, f_model, X, u_star, vars, col_weights=col_weights) # baseline approach can be done by simply removing the col_weights arg
58+
model.compile(layer_sizes, f_model, X, u_star, params, col_weights=col_weights) # baseline approach can be done by simply removing the col_weights arg
5859
model.tf_optimizer_weights = tf.keras.optimizers.Adam(lr=0.005, beta_1=.95) # an example as to how one could modify an optimizer, in this case the col_weights optimizer
5960

6061
# train loop

tensordiffeq.egg-info/PKG-INFO

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
Metadata-Version: 2.1
22
Name: tensordiffeq
3-
Version: 0.1.6
3+
Version: 0.1.6.2
44
Summary: Distributed PDE Solver in Tensorflow
55
Home-page: https://github.com/tensordiffeq/tensordiffeq
66
Author: Levi McClenny
77
Author-email: [email protected]
88
License: UNKNOWN
9-
Download-URL: https://github.com/tensordiffeq/tensordiffeq/tarball/v0.1.6
9+
Download-URL: https://github.com/tensordiffeq/tensordiffeq/tarball/v0.1.6.2
1010
Description:
1111
![TensorDiffEq logo](tdq-banner.png)
1212

tensordiffeq.egg-info/SOURCES.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ tensordiffeq/models.py
99
tensordiffeq/networks.py
1010
tensordiffeq/optimizers.py
1111
tensordiffeq/plotting.py
12+
tensordiffeq/sampling.py
1213
tensordiffeq/utils.py
1314
tensordiffeq.egg-info/PKG-INFO
1415
tensordiffeq.egg-info/SOURCES.txt

tensordiffeq.egg-info/requires.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ numpy
33
scipy
44
tensorflow
55
tensorflow_probability
6-
smt
6+
pyDOE2

tensordiffeq/models.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,13 @@ def predict(self, X_star):
129129
# WIP
130130
# TODO DiscoveryModel
131131
class DiscoveryModel():
132-
def compile(self, layer_sizes, f_model, X, u, vars, col_weights=None):
132+
def compile(self, layer_sizes, f_model, X, u, var, col_weights=None):
133133
self.layer_sizes = layer_sizes
134134
self.f_model = get_tf_model(f_model)
135135
self.X = X
136136
self.u = u
137-
self.vars = vars
137+
self.vars = var
138+
self.len_ = len(var)
138139
self.u_model = neural_net(self.layer_sizes)
139140
self.tf_optimizer = tf.keras.optimizers.Adam(lr=0.005, beta_1=.99)
140141
self.tf_optimizer_vars = tf.keras.optimizers.Adam(lr=0.005, beta_1=.99)
@@ -167,7 +168,7 @@ def grad(self):
167168
@tf.function
168169
def train_op(self):
169170
self.variables = self.u_model.trainable_variables
170-
len_ = len(self.vars)
171+
len_ = self.len_
171172
if self.col_weights is not None:
172173

173174
self.variables.extend([self.col_weights])

0 commit comments

Comments
 (0)