 def fit(obj, tf_iter, newton_iter, batch_sz=None, newton_eager=True):
-    obj.u_model = neural_net(obj.layer_sizes)
+    # obj.u_model = neural_net(obj.layer_sizes)
     # obj.build_loss()
     # Can adjust batch size for collocation points; here we set it to N_f
     if batch_sz is not None:
@@ -25,19 +25,18 @@ def fit(obj, tf_iter, newton_iter, batch_sz=None, newton_eager=True):
     # N_f = len(obj.x_f)
     n_batches = int(N_f // obj.batch_sz)
     start_time = time.time()
-    obj.tf_optimizer = tf.keras.optimizers.Adam(lr=0.005, beta_1=.99)
-    obj.tf_optimizer_weights = tf.keras.optimizers.Adam(lr=0.005, beta_1=.99)
+    # obj.tf_optimizer = tf.keras.optimizers.Adam(lr=0.005, beta_1=.99)
+    # obj.tf_optimizer_weights = tf.keras.optimizers.Adam(lr=0.005, beta_1=.99)

     # these can't be tf.functions on initialization since the distributed strategy requires its own
     # graph using grad and adaptgrad, so they can't be compiled as tf.functions until we know dist/non-dist
     obj.grad = tf.function(obj.grad)
     print("starting Adam training")
     # tf.profiler.experimental.start('../cache/tblogdir1')
-    print(n_batches)
-    print(tf_iter)
+    train_op_fn = train_op_inner(obj)
+    print(obj.tf_optimizer)
     for epoch in range(tf_iter):
-
-        loss_value = train_op(obj, n_batches)
+        loss_value = train_op_fn(n_batches, obj)

         if epoch % 100 == 0:
             elapsed = time.time() - start_time
@@ -85,23 +84,29 @@ def lbfgs_op(func, init_params, newton_iter):
         tolerance=1e-20,
     )

-@tf.function()
-def train_op(obj, n_batches):
-    for _ in range(n_batches):
-        # unstack = tf.unstack(obj.u_model.trainable_variables, axis = 2)
-        obj.variables = obj.u_model.trainable_variables
-        if obj.isAdaptive:
-            obj.variables.extend([obj.u_weights, obj.col_weights])
-            loss_value, grads = obj.grad()
-            obj.tf_optimizer.apply_gradients(zip(grads[:-2], obj.u_model.trainable_variables))
-            obj.tf_optimizer_weights.apply_gradients(zip([-grads[-2], -grads[-1]], [obj.u_weights, obj.col_weights]))
-        else:
-            loss_value, grads = obj.grad()
-            obj.tf_optimizer.apply_gradients(zip(grads, obj.u_model.trainable_variables))
-    return loss_value
+
+def train_op_inner(obj):
+    @tf.function
+    def apply_grads(n_batches, obj=obj):
+        for _ in range(n_batches):
+            # unstack = tf.unstack(obj.u_model.trainable_variables, axis = 2)
+            obj.variables = obj.u_model.trainable_variables
+            if obj.isAdaptive:
+                obj.variables.extend([obj.u_weights, obj.col_weights])
+                loss_value, grads = obj.grad()
+                obj.tf_optimizer.apply_gradients(zip(grads[:-2], obj.u_model.trainable_variables))
+                obj.tf_optimizer_weights.apply_gradients(
+                    zip([-grads[-2], -grads[-1]], [obj.u_weights, obj.col_weights]))
+            else:
+                loss_value, grads = obj.grad()
+                obj.tf_optimizer.apply_gradients(zip(grads, obj.u_model.trainable_variables))
+        return loss_value
+
+    return apply_grads


 # TODO Distributed training re-integration
+# TODO decouple u_model from being overwritten by calling model.fit

 def fit_dist(obj, tf_iter, newton_iter, batch_sz=None, newton_eager=True):
     BUFFER_SIZE = len(obj.x_f)
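The key change in this commit is that the Adam step is no longer a module-level @tf.function: train_op_inner(obj) builds and compiles the step only after obj exists, so a distributed run can later trace its own graph, as the in-code comment about dist/non-dist notes. The sketch below illustrates that closure pattern in isolation, assuming TensorFlow 2.x; ToyTrainer, loss_fn, and make_train_step are hypothetical stand-ins for the library's obj, obj.grad, and train_op_inner, not part of this codebase.

import tensorflow as tf


class ToyTrainer:
    """Hypothetical stand-in for the PINN object ('obj') used in fit()."""

    def __init__(self):
        self.model = tf.keras.Sequential([tf.keras.layers.Dense(8, activation="tanh"),
                                          tf.keras.layers.Dense(1)])
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=0.005)
        self.x = tf.random.uniform((32, 1))
        self.y = tf.sin(self.x)
        self.model(self.x)  # build the weights eagerly, before the step is traced

    def loss_fn(self):
        return tf.reduce_mean(tf.square(self.model(self.x) - self.y))


def make_train_step(trainer):
    # Compile the step here, once the trainer (and any tf.distribute strategy)
    # is known, mirroring train_op_inner(obj) in the diff above.
    @tf.function
    def train_step(n_batches, trainer=trainer):
        for _ in range(n_batches):
            with tf.GradientTape() as tape:
                loss_value = trainer.loss_fn()
            grads = tape.gradient(loss_value, trainer.model.trainable_variables)
            trainer.optimizer.apply_gradients(zip(grads, trainer.model.trainable_variables))
        return loss_value

    return train_step


trainer = ToyTrainer()
train_step = make_train_step(trainer)  # analogous to train_op_fn = train_op_inner(obj)
for epoch in range(500):
    loss_value = train_step(4)
    if epoch % 100 == 0:
        print(epoch, float(loss_value))

Deferring compilation this way is what would let a distributed variant such as fit_dist trace the step against its own strategy-managed variables instead of a graph frozen at import time.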