Skip to content

Commit 1faae59

Browse files
committed
make u_weights and col_weights not interdependent
1 parent 06a82a5 commit 1faae59

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

tensordiffeq/fit.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,22 @@ def apply_grads(n_batches, obj=obj):
9292
for _ in range(n_batches):
9393
# unstack = tf.unstack(obj.u_model.trainable_variables, axis = 2)
9494
obj.variables = obj.u_model.trainable_variables
95-
if obj.isAdaptive:
95+
if obj.isAdaptive and obj.u_weights is not None and obj.col_weights is not None:
9696
obj.variables.extend([obj.u_weights, obj.col_weights])
9797
loss_value, grads = obj.grad()
9898
obj.tf_optimizer.apply_gradients(zip(grads[:-2], obj.u_model.trainable_variables))
9999
obj.tf_optimizer_weights.apply_gradients(
100100
zip([-grads[-2], -grads[-1]], [obj.u_weights, obj.col_weights]))
101+
elif obj.isAdaptive and obj.u_weights is None and obj.col_weights is not None:
102+
obj.variables.extend([obj.col_weights])
103+
loss_value, grads = obj.grad()
104+
obj.tf_optimizer.apply_gradients(zip(grads[:-1], obj.u_model.trainable_variables))
105+
obj.tf_optimizer_weights.apply_gradients(zip([-grads[-1]], [obj.col_weights]))
106+
elif obj.isAdaptive and obj.u_weights is not None and obj.col_weights is None:
107+
obj.variables.extend([obj.u_weights])
108+
loss_value, grads = obj.grad()
109+
obj.tf_optimizer.apply_gradients(zip(grads[:-1], obj.u_model.trainable_variables))
110+
obj.tf_optimizer_weights.apply_gradients(zip([-grads[-1]], [obj.u_weights]))
101111
else:
102112
loss_value, grads = obj.grad()
103113
obj.tf_optimizer.apply_gradients(zip(grads, obj.u_model.trainable_variables))

0 commit comments

Comments
 (0)