Skip to content

Commit a68d7a3

Browse files
authored
Merge pull request #18 from levimcclenny/main
Fix output of discovery model
2 parents d6810a8 + fef833b commit a68d7a3

File tree

3 files changed: +31 −33 lines changed

tensordiffeq/fit.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -42,19 +42,14 @@ def fit(obj, tf_iter=0, newton_iter=0, batch_sz=None, newton_eager=True):
4242
for epoch in t:
4343
loss_value = train_op_fn(n_batches, obj)
4444
# Description will be displayed on the left
45-
t.set_description('Adam epoch %i' % (epoch+1))
45+
t.set_description('Adam epoch %i' % (epoch + 1))
4646
# Postfix will be displayed on the right,
4747
# formatted automatically based on argument's datatype
4848
if epoch % 10 == 0:
49-
elapsed = time.time() - start_time
5049
t.set_postfix(loss=loss_value.numpy())
51-
#
52-
# print('It: %d, Time: %.2f, loss: %.6f' % (epoch, elapsed, loss_value.numpy()))
53-
# tf.print(f"total loss: {loss_value}")
54-
start_time = time.time()
5550

56-
# tf.profiler.experimental.stop()
5751

52+
# tf.profiler.experimental.stop()
5853

5954
# tf.profiler.experimental.start('../cache/tblogdir1')
6055
if newton_iter > 0:
@@ -97,7 +92,7 @@ def lbfgs_op(func, init_params, newton_iter):
9792

9893

9994
def train_op_inner(obj):
100-
@tf.function
95+
@tf.function()
10196
def apply_grads(n_batches, obj=obj):
10297
for _ in range(n_batches):
10398
# unstack = tf.unstack(obj.u_model.trainable_variables, axis = 2)
@@ -126,7 +121,6 @@ def apply_grads(n_batches, obj=obj):
126121
return apply_grads
127122

128123

129-
130124
def fit_dist(obj, tf_iter, newton_iter, batch_sz=None, newton_eager=True):
131125
def train_epoch(dataset, STEPS):
132126
total_loss = 0.0
@@ -180,17 +174,17 @@ def dist_loop(obj, STEPS):
180174
return train_loss
181175

182176
def train_loop(obj, tf_iter, STEPS):
183-
177+
print_screen(obj)
184178
start_time = time.time()
185-
for epoch in range(tf_iter):
186-
loss = dist_loop(obj, STEPS)
187-
188-
if epoch % 100 == 0:
189-
elapsed = time.time() - start_time
190-
template = ("Epoch {}, Time: {}, Loss: {}")
191-
print(template.format(epoch, elapsed, loss))
192-
# print('It: %d, Time: %.2f, loss: %.9f' % (epoch, elapsed, tf.get_static_value(loss)))
193-
start_time = time.time()
179+
with trange(tf_iter) as t:
180+
for epoch in t:
181+
loss = dist_loop(obj, STEPS)
182+
t.set_description('Adam epoch %i' % (epoch + 1))
183+
if epoch % 10 == 0:
184+
elapsed = time.time() - start_time
185+
t.set_postfix(loss=loss.numpy())
186+
# print('It: %d, Time: %.2f, loss: %.9f' % (epoch, elapsed, tf.get_static_value(loss)))
187+
start_time = time.time()
194188

195189
print("starting Adam training")
196190
STEPS = np.max((obj.n_batches // obj.strategy.num_replicas_in_sync, 1))

tensordiffeq/models.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from .networks import *
66
from .plotting import *
77
from .fit import *
8+
from tqdm.auto import tqdm, trange
9+
from .output import print_screen
810

911

1012
class CollocationSolverND:
@@ -232,10 +234,7 @@ def compile(self, layer_sizes, f_model, X, u, var, col_weights=None):
232234
@tf.function
233235
def loss(self):
234236
u_pred = self.u_model(tf.concat(self.X, 1))
235-
print(self.vars)
236237
f_u_pred = self.f_model(self.u_model, self.vars, *self.X_in)
237-
print(self.vars)
238-
239238
if self.col_weights is not None:
240239
return MSE(u_pred, self.u) + g_MSE(f_u_pred, constant(0.0), self.col_weights ** 2)
241240
else:
@@ -273,15 +272,18 @@ def train_op(self):
273272
def fit(self, tf_iter):
274273
self.train_loop(tf_iter)
275274

276-
# @tf.function
277275
def train_loop(self, tf_iter): # sourcery skip: move-assign
278276
start_time = time.time()
279-
for i in range(tf_iter):
280-
loss_value = self.train_op()
281-
if i % 100 == 0:
282-
elapsed = time.time() - start_time
283-
print('It: %d, Time: %.2f' % (i, elapsed))
284-
tf.print(f"loss_value: {loss_value}")
285-
var = [var.numpy() for var in self.vars]
286-
tf.print(f"vars estimate(s): {var}")
287-
start_time = time.time()
277+
print_screen(self, discovery_model=True)
278+
with trange(tf_iter) as t:
279+
for i in t:
280+
loss_value = self.train_op()
281+
if i % 10 == 0:
282+
# elapsed = time.time() - start_time
283+
# print('It: %d, Time: %.2f' % (i, elapsed))
284+
# tf.print(f"loss_value: {loss_value}")
285+
var = [var.numpy() for var in self.vars]
286+
t.set_postfix(loss=loss_value.numpy())
287+
t.set_postfix(vars=var)
288+
# tf.print(f"vars estimate(s): {var}")
289+
# start_time = time.time()

tensordiffeq/output.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
from os import system, name
33
import sys
44

5-
def print_screen(model):
5+
def print_screen(model, discovery_model=False):
66
f = Figlet(font='slant')
77
print(f.renderText('TensorDiffEq'))
8+
if discovery_model:
9+
print("Running Discovery Model for Parameter Estimation\n\n")
810
print("Neural Network Model Summary\n")
911
print(model.u_model.summary())

Comments (0)