Skip to content

Commit bc64a6b

Browse files
KfacJaxDev
authored and committed
Internal Change
PiperOrigin-RevId: 805859317
1 parent 72062e3 commit bc64a6b

File tree

3 files changed

+18
-0
lines changed

3 files changed

+18
-0
lines changed

examples/optimizers.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -114,6 +114,7 @@ def create_optimizer(
114114
total_steps: int | None,
115115
total_epochs: float | None,
116116
schedule_free_config: config_dict.ConfigDict,
117+
model_func_for_power_iteration: kfac_jax.optimizer.ValueFunc | None = None,
117118
) -> optax_wrapper.OptaxWrapper | kfac_jax.Optimizer:
118119
"""Creates an optimizer from the provided configuration."""
119120

@@ -154,6 +155,11 @@ def create_optimizer(
154155
**kwargs[sched_name]
155156
)
156157

158+
if model_func_for_power_iteration is not None:
159+
kwargs["value_func_for_power_iteration"] = (
160+
model_func_for_power_iteration
161+
)
162+
157163
return kfac_jax.Optimizer(
158164
value_and_grad_func=value_and_grad_func,
159165
l2_reg=l2_reg,

examples/training.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -149,6 +149,8 @@ class SupervisedExperiment(abc.ABC):
149149
computation of the loss of the model for the estimator.
150150
eval_splits: Evaluation splits of the evaluation dataset loader.
151151
batch_size: An instance of `ExperimentBatchSizes`.
152+
model_func_for_power_iteration: A function that allows a different
153+
computation of the loss of the model for power iteration.
152154
init_parameters_func: A function that initializes the parameters and
153155
optionally the state of the model if it has one.
154156
params_init: A function that initializes the model parameters.
@@ -173,6 +175,9 @@ def __init__(
173175
model_func_for_estimator: kfac_jax.optimizer.ValueFunc | None = None,
174176
eval_splits: tuple[str, ...] = ("train", "test"),
175177
batch_size_calculator_ctor: BatchSizeCalculatorCtor = BatchSizeCalculator,
178+
model_func_for_power_iteration: (
179+
kfac_jax.optimizer.ValueFunc | None
180+
) = None,
176181
):
177182
"""Initializes experiment.
178183
@@ -193,6 +198,8 @@ def __init__(
193198
eval_splits: Evaluation splits of the evaluation dataset loader.
194199
batch_size_calculator_ctor: A constructor function to create a batch size
195200
calculator.
201+
model_func_for_power_iteration: A function that allows a different
202+
computation of the loss of the model for power iteration.
196203
"""
197204
self.mode = mode
198205
self.init_rng, seed_rng = jax.random.split(init_rng)
@@ -215,6 +222,7 @@ def __init__(
215222
self.params_init = jax.pmap(init_parameters_func, axis_name="kfac_axis")
216223
self.model_loss_func = model_loss_func
217224
self.model_func_for_estimator = model_func_for_estimator
225+
self.model_func_for_power_iteration = model_func_for_power_iteration
218226

219227
self.train_model_func = functools.partial(
220228
self.model_loss_func, is_training=True
@@ -467,6 +475,9 @@ def create_optimizer(
467475
total_steps=self.config.training.steps,
468476
total_epochs=self.config.training.epochs,
469477
schedule_free_config=self._schedule_free_config,
478+
model_func_for_power_iteration=functools.partial(
479+
self.model_func_for_power_iteration, is_training=True
480+
) if self.model_func_for_power_iteration is not None else None,
470481
)
471482

472483
def maybe_initialize_state(self):

kfac_jax/_src/optimizer.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -545,6 +545,7 @@ def __init__(
545545
auto_register_tags=use_automatic_registration,
546546
auto_register_kwargs=auto_register_kwargs,
547547
)
548+
548549
self._implicit = curvature_estimator.ImplicitExactCurvature(
549550
self._value_func,
550551
params_index=self._params_index,

0 commit comments

Comments
 (0)