
Commit ab63f4b

decouple model.compile and model.fit to allow repeated calling of fit
1 parent 760992a commit ab63f4b

File tree: 4 files changed, +35, -23 lines

.gitignore

Lines changed: 2 additions & 1 deletion

@@ -2,4 +2,5 @@ cache/*
 tensordiffeq/__pycache__/*
 venv/*
 .idea/*
-dist/*
+dist/*
+examples/*.ipynb

examples/burgers-new.py

Lines changed: 3 additions & 1 deletion

@@ -38,7 +38,9 @@ def f_model(u_model, x, t):
 
 model = CollocationSolverND()
 model.compile(layer_sizes, f_model, Domain, BCs)
-model.fit(tf_iter=1000, newton_iter=1000)
+model.fit(tf_iter=301, newton_iter=101)
+
+model.fit(tf_iter=301, newton_iter=101)
 
 
 #######################################################
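Since compile now builds the network and fit only trains it, the example can train in stages. A minimal sketch of that workflow, using only the calls shown in this diff (the iteration counts are illustrative):

# compile once: fixes the problem setup and builds u_model and the optimizers
model = CollocationSolverND()
model.compile(layer_sizes, f_model, Domain, BCs)

# train in stages; each fit call resumes from the weights and optimizer state
# left by the previous one instead of re-initializing the network
model.fit(tf_iter=301, newton_iter=101)
model.fit(tf_iter=301, newton_iter=101)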

tensordiffeq/fit.py

Lines changed: 26 additions & 21 deletions

@@ -12,7 +12,7 @@
 
 
 def fit(obj, tf_iter, newton_iter, batch_sz=None, newton_eager=True):
-    obj.u_model = neural_net(obj.layer_sizes)
+    # obj.u_model = neural_net(obj.layer_sizes)
     # obj.build_loss()
     # Can adjust batch size for collocation points, here we set it to N_f
     if batch_sz is not None:
@@ -25,19 +25,18 @@ def fit(obj, tf_iter, newton_iter, batch_sz=None, newton_eager=True):
     # N_f = len(obj.x_f)
     n_batches = int(N_f // obj.batch_sz)
     start_time = time.time()
-    obj.tf_optimizer = tf.keras.optimizers.Adam(lr=0.005, beta_1=.99)
-    obj.tf_optimizer_weights = tf.keras.optimizers.Adam(lr=0.005, beta_1=.99)
+    # obj.tf_optimizer = tf.keras.optimizers.Adam(lr=0.005, beta_1=.99)
+    # obj.tf_optimizer_weights = tf.keras.optimizers.Adam(lr=0.005, beta_1=.99)
 
     # these cant be tf.functions on initialization since the distributed strategy requires its own
     # graph using grad and adaptgrad, so they cant be compiled as tf.functions until we know dist/non-dist
     obj.grad = tf.function(obj.grad)
     print("starting Adam training")
     # tf.profiler.experimental.start('../cache/tblogdir1')
-    print(n_batches)
-    print(tf_iter)
+    train_op_fn = train_op_inner(obj)
+    print(obj.tf_optimizer)
     for epoch in range(tf_iter):
-
-        loss_value = train_op(obj, n_batches)
+        loss_value = train_op_fn(n_batches, obj)
 
         if epoch % 100 == 0:
             elapsed = time.time() - start_time
@@ -85,23 +84,29 @@ def lbfgs_op(func, init_params, newton_iter):
         tolerance=1e-20,
     )
 
-@tf.function()
-def train_op(obj, n_batches):
-    for _ in range(n_batches):
-        # unstack = tf.unstack(obj.u_model.trainable_variables, axis = 2)
-        obj.variables = obj.u_model.trainable_variables
-        if obj.isAdaptive:
-            obj.variables.extend([obj.u_weights, obj.col_weights])
-            loss_value, grads = obj.grad()
-            obj.tf_optimizer.apply_gradients(zip(grads[:-2], obj.u_model.trainable_variables))
-            obj.tf_optimizer_weights.apply_gradients(zip([-grads[-2], -grads[-1]], [obj.u_weights, obj.col_weights]))
-        else:
-            loss_value, grads = obj.grad()
-            obj.tf_optimizer.apply_gradients(zip(grads, obj.u_model.trainable_variables))
-    return loss_value
+
+def train_op_inner(obj):
+    @tf.function
+    def apply_grads(n_batches, obj=obj):
+        for _ in range(n_batches):
+            # unstack = tf.unstack(obj.u_model.trainable_variables, axis = 2)
+            obj.variables = obj.u_model.trainable_variables
+            if obj.isAdaptive:
+                obj.variables.extend([obj.u_weights, obj.col_weights])
+                loss_value, grads = obj.grad()
+                obj.tf_optimizer.apply_gradients(zip(grads[:-2], obj.u_model.trainable_variables))
+                obj.tf_optimizer_weights.apply_gradients(
+                    zip([-grads[-2], -grads[-1]], [obj.u_weights, obj.col_weights]))
+            else:
+                loss_value, grads = obj.grad()
+                obj.tf_optimizer.apply_gradients(zip(grads, obj.u_model.trainable_variables))
        return loss_value
+
+    return apply_grads
 
 
 # TODO Distributed training re-integration
+# TODO decouple u_model from being overwritten by calling model.fit
 
 def fit_dist(obj, tf_iter, newton_iter, batch_sz=None, newton_eager=True):
     BUFFER_SIZE = len(obj.x_f)
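The central change in fit.py is that the Adam step is no longer a module-level @tf.function; train_op_inner returns a freshly decorated closure each time fit runs, so every fit call traces its own graph against the current model and optimizer objects. A stripped-down sketch of that factory pattern outside tensordiffeq (all names below are illustrative, not the library's API):

import tensorflow as tf

def make_train_step(model, optimizer, loss_fn):
    # Each call returns a newly traced tf.function bound to the current model
    # and optimizer, so a second training round builds its own graph instead
    # of reusing one traced against stale state.
    @tf.function
    def train_step(x, y):
        with tf.GradientTape() as tape:
            loss = loss_fn(y, model(x, training=True))
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss

    return train_step


# usage sketch: two training rounds, each with its own traced step function
model = tf.keras.Sequential([tf.keras.layers.Dense(20, activation="tanh"),
                             tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam(learning_rate=0.005, beta_1=0.99)
loss_fn = tf.keras.losses.MeanSquaredError()

x = tf.random.normal((32, 1))
y = tf.sin(x)

for round_idx in range(2):                 # analogous to calling model.fit twice
    train_step = make_train_step(model, optimizer, loss_fn)
    for _ in range(100):
        loss = train_step(x, y)
    print("round", round_idx, "loss", float(loss))

Re-creating the tf.function per round avoids invoking a graph that was traced against objects from an earlier fit call, which is what makes back-to-back fit calls safe here.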

tensordiffeq/models.py

Lines changed: 4 additions & 0 deletions

@@ -13,6 +13,8 @@ def __init__(self, assimilate=False):
 
     def compile(self, layer_sizes, f_model, domain, bcs, isAdaptive=False,
                 col_weights=None, u_weights=None, g=None, dist=False):
+        self.tf_optimizer = tf.keras.optimizers.Adam(lr=0.005, beta_1=.99)
+        self.tf_optimizer_weights = tf.keras.optimizers.Adam(lr=0.005, beta_1=.99)
         self.layer_sizes = layer_sizes
         self.sizes_w, self.sizes_b = get_sizes(layer_sizes)
         self.bcs = bcs
@@ -27,6 +29,7 @@ def compile(self, layer_sizes, f_model, domain, bcs, isAdaptive=False,
         self.X_f_len = tf.slice(self.X_f_dims, [0], [1]).numpy()
         tmp = [np.reshape(vec, (-1,1)) for i, vec in enumerate(self.domain.X_f.T)]
         self.X_f_in = np.asarray(tmp)
+        self.u_model = neural_net(self.layer_sizes)
 
 
 
@@ -87,6 +90,7 @@ def update_loss(self):
         loss_tmp = tf.math.add(loss_tmp, mse_f_u)
         return loss_tmp
 
+    #@tf.function
     def grad(self):
         with tf.GradientTape() as tape:
             loss_value = self.update_loss()
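In models.py, the optimizers and u_model now live on the solver from compile onward, so fit consumes state rather than rebuilding it (the previous fit re-created u_model on every call, silently resetting training). A toy sketch of that compile/fit split, assuming a plain Keras MLP in place of tensordiffeq's neural_net helper:

import tensorflow as tf

class ToySolver:
    def compile(self, layer_sizes):
        # long-lived state is created exactly once, at compile time
        self.tf_optimizer = tf.keras.optimizers.Adam(learning_rate=0.005, beta_1=0.99)
        self.u_model = tf.keras.Sequential(
            [tf.keras.layers.Dense(n, activation="tanh") for n in layer_sizes[1:-1]]
            + [tf.keras.layers.Dense(layer_sizes[-1])]
        )

    def fit(self, x, y, iters):
        # fit only consumes the state built by compile, so calling it again
        # keeps training the same weights and optimizer moment estimates
        for _ in range(iters):
            with tf.GradientTape() as tape:
                loss = tf.reduce_mean(tf.square(self.u_model(x) - y))
            grads = tape.gradient(loss, self.u_model.trainable_variables)
            self.tf_optimizer.apply_gradients(zip(grads, self.u_model.trainable_variables))
        return loss


solver = ToySolver()
solver.compile([1, 20, 20, 1])
x = tf.random.normal((64, 1))
y = tf.sin(x)
solver.fit(x, y, iters=200)   # first round
solver.fit(x, y, iters=200)   # second round resumes, mirroring the burgers example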
