@@ -92,12 +92,22 @@ def apply_grads(n_batches, obj=obj):
     for _ in range(n_batches):
         # unstack = tf.unstack(obj.u_model.trainable_variables, axis=2)
         obj.variables = obj.u_model.trainable_variables
-        if obj.isAdaptive:
+        if obj.isAdaptive and obj.u_weights is not None and obj.col_weights is not None:
             obj.variables.extend([obj.u_weights, obj.col_weights])
             loss_value, grads = obj.grad()
             obj.tf_optimizer.apply_gradients(zip(grads[:-2], obj.u_model.trainable_variables))
             obj.tf_optimizer_weights.apply_gradients(
                 zip([-grads[-2], -grads[-1]], [obj.u_weights, obj.col_weights]))
+        elif obj.isAdaptive and obj.u_weights is None and obj.col_weights is not None:
+            obj.variables.extend([obj.col_weights])
+            loss_value, grads = obj.grad()
+            obj.tf_optimizer.apply_gradients(zip(grads[:-1], obj.u_model.trainable_variables))
+            obj.tf_optimizer_weights.apply_gradients(zip([-grads[-1]], [obj.col_weights]))
+        elif obj.isAdaptive and obj.u_weights is not None and obj.col_weights is None:
+            obj.variables.extend([obj.u_weights])
+            loss_value, grads = obj.grad()
+            obj.tf_optimizer.apply_gradients(zip(grads[:-1], obj.u_model.trainable_variables))
+            obj.tf_optimizer_weights.apply_gradients(zip([-grads[-1]], [obj.u_weights]))
         else:
             loss_value, grads = obj.grad()
             obj.tf_optimizer.apply_gradients(zip(grads, obj.u_model.trainable_variables))