@@ -149,6 +149,8 @@ class SupervisedExperiment(abc.ABC):
149149 computation of the loss of the model for the estimator.
150150 eval_splits: Evaluation splits of the evaluation dataset loader.
151151 batch_size: An instance of `ExperimentBatchSizes`.
152+ model_func_for_power_iteration: A function that allows a different
153+ computation of the loss of the model for power iteration.
152154 init_parameters_func: A function that initializes the parameters and
153155 optionally the state of the model if it has one.
154156 params_init: A function that initializes the model parameters.
@@ -173,6 +175,9 @@ def __init__(
173175 model_func_for_estimator : kfac_jax .optimizer .ValueFunc | None = None ,
174176 eval_splits : tuple [str , ...] = ("train" , "test" ),
175177 batch_size_calculator_ctor : BatchSizeCalculatorCtor = BatchSizeCalculator ,
178+ model_func_for_power_iteration : (
179+ kfac_jax .optimizer .ValueFunc | None
180+ ) = None ,
176181 ):
177182 """Initializes experiment.
178183
@@ -193,6 +198,8 @@ def __init__(
193198 eval_splits: Evaluation splits of the evaluation dataset loader.
194199 batch_size_calculator_ctor: A constructor function to create a batch size
195200 calculator.
201+ model_func_for_power_iteration: A function that allows a different
202+ computation of the loss of the model for power iteration.
196203 """
197204 self .mode = mode
198205 self .init_rng , seed_rng = jax .random .split (init_rng )
@@ -215,6 +222,7 @@ def __init__(
215222 self .params_init = jax .pmap (init_parameters_func , axis_name = "kfac_axis" )
216223 self .model_loss_func = model_loss_func
217224 self .model_func_for_estimator = model_func_for_estimator
225+ self .model_func_for_power_iteration = model_func_for_power_iteration
218226
219227 self .train_model_func = functools .partial (
220228 self .model_loss_func , is_training = True
@@ -467,6 +475,9 @@ def create_optimizer(
467475 total_steps = self .config .training .steps ,
468476 total_epochs = self .config .training .epochs ,
469477 schedule_free_config = self ._schedule_free_config ,
478+ model_func_for_power_iteration = functools .partial (
479+ self .model_func_for_power_iteration , is_training = True
480+ ) if self .model_func_for_power_iteration is not None else None ,
470481 )
471482
472483 def maybe_initialize_state (self ):
0 commit comments