@@ -217,8 +217,7 @@ def train(self,
               batch_size=2**6,
               epochs=100,
               learning_rate=0.001,
-              adaptive_weights=False,
-              log_adaptive_weights=None,
+              adaptive_weights=None,
               log_loss_gradients=None,
               shuffle=True,
               callbacks=None,
@@ -228,7 +227,8 @@ def train(self,
               stop_after=None,
               stop_loss_value=1e-8,
               log_parameters=None,
-              log_parameters_freq=None,
+              log_functionals=None,
+              log_loss_landscape=None,
               save_weights_to=None,
               save_weights_freq=0,
               default_zero_weight=0.0,
@@ -262,12 +262,13 @@ def train(self,
                 learning_rate = ([0, 100, 1000], [0.001, 0.0005, 0.00001])
             shuffle: Boolean (whether to shuffle the training data).
                 Default value is True.
-            adaptive_weights: Defaulted to False (no updates - evaluated once in the beginning).
-                Used if the model is compiled with adaptive_weights.
-            log_adaptive_weights: Logging the weights and gradients of adaptive_weight.
-                Defaulted to adaptive_weights.
-            log_loss_gradients: Frequency for logging the norm2 of gradients of each target.
-                Defaulted to None.
+            adaptive_weights: Pass a Dict with the following keys:
+                . freq: Frequency of updating the weights.
+                . log_freq: Frequency of logging the weights and gradients in the history object.
+                . beta: The beta parameter from the Gradient Pathology paper.
+            log_loss_gradients: Pass a Dict with the following keys:
+                . freq: Frequency of logging. Defaulted to 100.
+                . path: Path to log the gradients.
             callbacks: List of `keras.callbacks.Callback` instances.
             reduce_lr_after: Patience to reduce learning rate or stop after certain missed epochs.
                 Defaulted to max(10, epochs/10).
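A minimal usage sketch of the two dict-style arguments documented in the hunk above. The model `m`, the arrays `x_data`/`y_data`, and the specific key values are illustrative assumptions, not taken from this commit; the keys follow the docstring.

# hypothetical SciANN-style model `m`; only the new keyword arguments matter here
h = m.train(
    x_data, y_data,
    epochs=1000,
    # update freq, logging freq, and the Gradient Pathology beta
    adaptive_weights={'freq': 100, 'log_freq': 100, 'beta': 0.1},
    # logging freq and output path for the per-target loss gradients
    log_loss_gradients={'freq': 100, 'path': 'logs/'},
)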
@@ -278,7 +279,21 @@ def train(self,
                 This value affects the number of failed attempts to trigger reduce learning rate based on reduce_lr_after.
             stop_after: To stop after certain missed epochs. Defaulted to total number of epochs.
             stop_loss_value: The minimum value of the total loss that stops the training automatically.
-                Defaulted to 1e-8.
+                Defaulted to 1e-8.
+            log_parameters: Dict object expecting the following keys:
+                . parameters: List of parameters to log.
+                . freq: Frequency of logging the parameters.
+            log_functionals: Dict object expecting the following keys:
+                . functionals: List of functionals to log their training history.
+                . inputs: The input grid on which to evaluate each functional.
+                    Should be of the same size as the inputs passed to model.train.
+                . path: Path to the location where the csv files will be logged.
+                . freq: Frequency of logging the functionals.
+            log_loss_landscape: Dict object expecting the following keys:
+                . norm: Defaulted to 2.
+                . resolution: Defaulted to 10.
+                . path: Path to the location where the csv files will be logged.
+                . freq: Frequency of logging the loss landscape.
             save_weights_to: (file_path) If you want to save the state of the model (at the end of the training).
             save_weights_freq: (Integer) Save weights every N epochs.
                 Defaulted to 0.
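A similar sketch for the three logging arguments documented in the hunk above. The functional `f`, the parameter `c1`, the grid `x_data`, and the paths are placeholders; the keys follow the docstring.

# `f` is assumed to be a functional and `c1` a trainable parameter of the model
h = m.train(
    x_data, y_data,
    epochs=1000,
    log_parameters={'parameters': [c1], 'freq': 10},
    log_functionals={'functionals': [f], 'inputs': x_data,
                     'path': 'logs/', 'freq': 100},
    log_loss_landscape={'norm': 2, 'resolution': 10,
                        'path': 'logs/', 'freq': 500},
)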
@@ -416,30 +431,46 @@ def train(self,
         opt_fit_func = self._model.fit
 
         if adaptive_weights:
+            if not isinstance(adaptive_weights, dict):
+                adaptive_weights = GradientPathologyLossWeight.prepare_inputs(adaptive_weights)
             callbacks.append(
                 GradientPathologyLossWeight(
-                    self.model, x_true, y_star, sample_weights,
-                    beta=0.1, freq=adaptive_weights, log_freq=log_adaptive_weights,
-                    types=[type(v).__name__ for v in self.constraints]
+                    self.model, x_true, y_star, sample_weights,
+                    types=[type(v).__name__ for v in self.constraints],
+                    **adaptive_weights
                 ),
-                # k.callbacks.CSVLogger('GP.log')
             )
-        elif log_loss_gradients:
+
+        if log_loss_gradients:
+            if not isinstance(log_loss_gradients, dict):
+                log_loss_gradients = LossGradientHistory.prepare_inputs(log_loss_gradients)
             callbacks.append(
                 LossGradientHistory(
                     self.model, x_true, y_star, sample_weights,
-                    freq=log_loss_gradients
+                    **log_loss_gradients
                 )
             )
-
         if log_parameters:
+            if not isinstance(log_parameters, dict):
+                log_parameters = ParameterHistory.prepare_inputs(log_parameters)
+            callbacks.append(
+                ParameterHistory(**log_parameters)
+            )
+        if log_functionals:
+            if not isinstance(log_functionals, dict):
+                log_functionals = FunctionalHistory.prepare_inputs(log_functionals)
+            callbacks.append(
+                FunctionalHistory(self, **log_functionals)
+            )
+        if log_loss_landscape:
+            if not isinstance(log_loss_landscape, dict):
+                log_loss_landscape = LossLandscapeHistory.prepare_inputs(log_loss_landscape)
             callbacks.append(
-                ParameterHistory(
-                    list(log_parameters),
-                    freq=1 if log_parameters_freq is None else log_parameters_freq
+                LossLandscapeHistory(
+                    self.model, x_true, y_star, sample_weights,
+                    **log_loss_landscape
                 )
             )
-
         # training the models.
         history = opt_fit_func(
             x_true, y_star,
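The hunk above normalizes non-dict arguments through a `prepare_inputs` classmethod on each callback before unpacking them as keyword arguments. A hypothetical sketch of that contract, under the assumption that a bare value is treated as a frequency; this is not the library's actual implementation.

def prepare_inputs_sketch(value, default_freq=100):
    """Normalize a shorthand value into the keyword dict a callback expects."""
    if isinstance(value, dict):
        return value                      # already in the expected form
    if value is True:
        return {'freq': default_freq}     # assumed default frequency
    if isinstance(value, int):
        return {'freq': value}            # interpret a bare int as a frequency
    raise ValueError('expected a dict, an int frequency, or True')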