
Commit 4e5ddbe

add 0 newton iter case
1 parent 8d94885

File tree

4 files changed (+21 -20 lines)


examples/burgers-new.py

Lines changed: 2 additions & 2 deletions
@@ -38,9 +38,9 @@ def f_model(u_model, x, t):
 
 model = CollocationSolverND()
 model.compile(layer_sizes, f_model, Domain, BCs)
-model.fit(tf_iter=301, newton_iter=101)
+model.fit(newton_iter=301)
 
-model.fit(tf_iter=301, newton_iter=101)
+model.fit(tf_iter=301)
 
 
 #######################################################
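
With both iteration counts now defaulting to zero, the updated example exercises single-optimizer training: the first call runs L-BFGS alone, the second Adam alone. A sketch of the call patterns the new defaults allow (illustrative only, reusing the compiled model from the example above):

    model.fit(tf_iter=301, newton_iter=101)  # Adam then L-BFGS (previous behavior)
    model.fit(tf_iter=301)                   # Adam only; newton_iter defaults to 0
    model.fit(newton_iter=301)               # L-BFGS only; tf_iter defaults to 0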

tensordiffeq.egg-info/PKG-INFO

Lines changed: 2 additions & 2 deletions
@@ -1,12 +1,12 @@
 Metadata-Version: 2.1
 Name: tensordiffeq
-Version: 0.1.6.2
+Version: 0.1.6.3
 Summary: Distributed PDE Solver in Tensorflow
 Home-page: https://github.com/tensordiffeq/tensordiffeq
 Author: Levi McClenny
 Author-email: [email protected]
 License: UNKNOWN
-Download-URL: https://github.com/tensordiffeq/tensordiffeq/tarball/v0.1.6.2
+Download-URL: https://github.com/tensordiffeq/tensordiffeq/tarball/v0.1.6.3
 Description:
 ![TensorDiffEq logo](tdq-banner.png)

tensordiffeq/fit.py

Lines changed: 16 additions & 15 deletions
@@ -11,7 +11,7 @@
 os.environ["TF_GPU_THREAD_MODE"] = "gpu_private"
 
 
-def fit(obj, tf_iter, newton_iter, batch_sz=None, newton_eager=True):
+def fit(obj, tf_iter=0, newton_iter=0, batch_sz=None, newton_eager=True):
     # obj.u_model = neural_net(obj.layer_sizes)
     # obj.build_loss()
     # Can adjust batch size for collocation points, here we set it to N_f
@@ -34,7 +34,6 @@ def fit(obj, tf_iter, newton_iter, batch_sz=None, newton_eager=True):
     print("starting Adam training")
     # tf.profiler.experimental.start('../cache/tblogdir1')
     train_op_fn = train_op_inner(obj)
-    print(obj.tf_optimizer)
     for epoch in range(tf_iter):
         loss_value = train_op_fn(n_batches, obj)
 
@@ -46,22 +45,24 @@ def fit(obj, tf_iter, newton_iter, batch_sz=None, newton_eager=True):
     start_time = time.time()
     # tf.profiler.experimental.stop()
 
-    print("Starting L-BFGS training")
-    # tf.profiler.experimental.start('../cache/tblogdir1')
 
-    if newton_eager:
-        print("Executing eager-mode L-BFGS")
-        loss_and_flat_grad = obj.get_loss_and_flat_grad()
-        eager_lbfgs(loss_and_flat_grad,
-                    get_weights(obj.u_model),
-                    Struct(), maxIter=newton_iter, learningRate=0.8)
+    # tf.profiler.experimental.start('../cache/tblogdir1')
+    if newton_iter > 0:
+        print("Starting L-BFGS training")
+        if newton_eager:
+            print("Executing eager-mode L-BFGS")
+            loss_and_flat_grad = obj.get_loss_and_flat_grad()
+            eager_lbfgs(loss_and_flat_grad,
+                        get_weights(obj.u_model),
+                        Struct(), maxIter=newton_iter, learningRate=0.8)
 
-    else:
-        print("Executing graph-mode L-BFGS\n Building graph...")
-        print("Warning: Typically eager-mode L-BFGS is faster. If the computational graph takes a long time to build, "
-              "or the computation is slow, try eager-mode L-BFGS")
+        else:
+            print("Executing graph-mode L-BFGS\n Building graph...")
+            print("Warning: Depending on your CPU/GPU setup, eager-mode L-BFGS may prove faster. If the computational "
+                  "graph takes a long time to build, or the computation is slow, try eager-mode L-BFGS (enabled by "
+                  "default)")
 
-        lbfgs_train(obj, newton_iter)
+            lbfgs_train(obj, newton_iter)
 
     # tf.profiler.experimental.stop()
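
Note on the control flow above: the Adam stage needs no explicit guard, since for epoch in range(tf_iter) runs zero times when tf_iter=0, while the L-BFGS stage prints status and builds a loss-and-gradient closure before optimizing, so it is wrapped in if newton_iter > 0. A self-contained sketch of the pattern; adam_step and lbfgs_minimize are hypothetical stand-ins for train_op_fn and eager_lbfgs/lbfgs_train:

    def adam_step():
        pass  # hypothetical stand-in for train_op_fn(n_batches, obj)

    def lbfgs_minimize(max_iter):
        pass  # hypothetical stand-in for eager_lbfgs / lbfgs_train

    def fit_sketch(tf_iter=0, newton_iter=0):
        for _ in range(tf_iter):         # range(0) is empty, so tf_iter=0 skips Adam
            adam_step()
        if newton_iter > 0:              # explicit guard: the L-BFGS setup would
            lbfgs_minimize(newton_iter)  # otherwise run its side effects at 0 iterations

    fit_sketch(newton_iter=301)  # L-BFGS only, mirroring the updated example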

tensordiffeq/models.py

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ def grad(self):
         grads = tape.gradient(loss_value, self.variables)
         return loss_value, grads
 
-    def fit(self, tf_iter, newton_iter, batch_sz=None, newton_eager=True):
+    def fit(self, tf_iter=0, newton_iter=0, batch_sz=None, newton_eager=True):
         if self.isAdaptive and (batch_sz is not None):
             raise Exception("Currently we dont support minibatching for adaptive PINNs")
         if self.dist:
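
One side effect of the new defaults, noted here as an observation rather than part of the diff: model.fit() with no arguments is now a valid call that trains nothing, since tf_iter=0 yields zero Adam epochs and newton_iter=0 skips the L-BFGS branch entirely. At least one count should be nonzero:

    model.fit()             # valid, but performs no training at all
    model.fit(tf_iter=301)  # Adam-only training, as in the updated example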
