@@ -112,9 +112,10 @@ def compile(
112112 weighted by the `loss_weights` coefficients.
113113 external_trainable_variables: A trainable ``dde.Variable`` object or a list
114114 of trainable ``dde.Variable`` objects. The unknown parameters in the
115- physics systems that need to be recovered. If the backend is
116- tensorflow.compat.v1, `external_trainable_variables` is ignored, and all
117- trainable ``dde.Variable`` objects are automatically collected.
115+ physics systems that need to be recovered. Regularization will not be
116+ applied to these variables. If the backend is tensorflow.compat.v1,
117+ `external_trainable_variables` is ignored, and all trainable ``dde.Variable``
118+ objects are automatically collected.
118119 verbose (Integer): Controls the verbosity of the compile process.
119120 """
120121 if verbose > 0 and config .rank == 0 :
@@ -330,30 +331,40 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):
330331 False , inputs , targets , auxiliary_vars , self .data .losses_test
331332 )
332333
333- # Another way is using per-parameter options
334- # https://pytorch.org/docs/stable/optim.html#per-parameter-options,
335- # but not all optimizers (such as L-BFGS) support this.
336- trainable_variables = (
337- list (self .net .parameters ()) + self .external_trainable_variables
338- )
339- if self .net .regularizer is None :
340- self .opt , self .lr_scheduler = optimizers .get (
341- trainable_variables , self .opt_name , learning_rate = lr , decay = decay
342- )
343- else :
344- if self .net .regularizer [0 ] == "l2" :
345- self .opt , self .lr_scheduler = optimizers .get (
346- trainable_variables ,
347- self .opt_name ,
348- learning_rate = lr ,
349- decay = decay ,
350- weight_decay = self .net .regularizer [1 ],
351- )
352- else :
334+ weight_decay = 0
335+ if self .net .regularizer is not None :
336+ if self .net .regularizer [0 ] != "l2" :
353337 raise NotImplementedError (
354338 f"{ self .net .regularizer [0 ]} regularization to be implemented for "
355- "backend pytorch. "
339+ "backend pytorch"
356340 )
341+ weight_decay = self .net .regularizer [1 ]
342+
343+ optimizer_params = self .net .parameters ()
344+ if self .external_trainable_variables :
345+ # L-BFGS doesn't support per-parameter options.
346+ if self .opt_name in ["L-BFGS" , "L-BFGS-B" ]:
347+ optimizer_params = (
348+ list (optimizer_params ) + self .external_trainable_variables
349+ )
350+ if weight_decay > 0 :
351+ print (
352+ "Warning: L2 regularization will also be applied to external_trainable_variables. "
353+ "Ensure this is intended behavior."
354+ )
355+ else :
356+ optimizer_params = [
357+ {"params" : optimizer_params },
358+ {"params" : self .external_trainable_variables , "weight_decay" : 0 },
359+ ]
360+
361+ self .opt , self .lr_scheduler = optimizers .get (
362+ optimizer_params ,
363+ self .opt_name ,
364+ learning_rate = lr ,
365+ decay = decay ,
366+ weight_decay = weight_decay ,
367+ )
357368
358369 def train_step (inputs , targets , auxiliary_vars ):
359370 def closure ():
0 commit comments