Skip to content

Commit 8bf8400

Browse files
KfacJaxDev
authored and committed
Create objects for value functions for the estimator object and the corresponding implicit GGN estimator
PiperOrigin-RevId: 805859317
1 parent f7bbd0f commit 8bf8400

File tree

3 files changed

+86
-9
lines changed

3 files changed

+86
-9
lines changed

kfac_jax/_src/optimizer.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,12 +545,23 @@ def __init__(
545545
auto_register_tags=use_automatic_registration,
546546
auto_register_kwargs=auto_register_kwargs,
547547
)
548+
549+
self._value_func_for_estimator = (
550+
self._value_func if value_func_for_estimator is None else
551+
value_func_for_estimator)
552+
548553
self._implicit = curvature_estimator.ImplicitExactCurvature(
549554
self._value_func,
550555
params_index=self._params_index,
551556
batch_size_extractor=batch_size_extractor,
552557
)
553558

559+
self._implicit_estimator = curvature_estimator.ImplicitExactCurvature(
560+
self._value_func_for_estimator,
561+
params_index=self._params_index,
562+
batch_size_extractor=batch_size_extractor,
563+
)
564+
554565
# Each subclass should call finalize on its own, so this gets called only
555566
# for instances of exactly this class type.
556567
if type(self) == Optimizer: # pylint: disable=unidiomatic-typecheck

kfac_jax/_src/tag_graph_matcher.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,6 +1055,69 @@ def clean_jaxpr(
10551055
return to_jaxpr_or_closed_jaxpr(closed_jaxpr, jaxpr)
10561056

10571057

1058+
def clean_layer_tags_jaxpr(
    jaxpr: J,
    only_remove_auto_tags: bool = False,
) -> tuple[J, tuple[tags.LayerTagEqn | JaxprEqn, ...]]:
  """Removes layer tags from a Jaxpr.

  Every removed layer tag equation is dropped from the equation list and each
  of its outputs is replaced, wherever it is consumed afterwards, by the tag
  input that ``meta.outputs_index`` pairs it with.

  Args:
    jaxpr: The Jaxpr (or ClosedJaxpr) to clean.
    only_remove_auto_tags: If ``True``, only layer tags whose ``meta.name`` is
      set and contains the substring ``"Auto"`` are removed; all other layer
      tags stay in the Jaxpr. If ``False``, every layer tag is removed.

  Returns:
    A pair ``(cleaned_jaxpr, layer_tag_eqns)``. ``cleaned_jaxpr`` is produced
    by ``to_jaxpr_or_closed_jaxpr`` from the cleaned closed Jaxpr and the
    original ``jaxpr`` argument. ``layer_tag_eqns`` holds the layer tag
    equations encountered, in program order, with their inputs remapped the
    same way as the remaining equations.
  """
  closed_jaxpr = to_closed_jaxpr(jaxpr)

  # Maps the output variable of a removed tag to the variable replacing it.
  var_map: dict[jex.core.Var, jex.core.Var] = {}

  def resolve(var):
    """Returns the replacement of ``var``, or ``var`` itself if unmapped."""
    # Literals are never outputs of an equation, so they are never keys of
    # ``var_map``; skipping them also avoids hashing a Literal.
    if isinstance(var, jex.core.Literal):
      return var
    return var_map.get(var, var)

  def remap_input_vars(eqns_to_remap):
    """Returns copies of the equations with their input variables remapped."""
    return [
        eqn.replace(invars=[resolve(var) for var in eqn.invars])
        for eqn in eqns_to_remap
    ]

  kept_eqns = []
  layer_tag_eqns = []

  for eqn in closed_jaxpr.jaxpr.eqns:
    is_layer_tag = isinstance(eqn.primitive, tags.LayerTag)

    if is_layer_tag:
      # Record every layer tag, removed or not: the returned tuple is what
      # callers use as the set of layer tags of the function.
      layer_tag_eqns.append(eqn)

    remove = is_layer_tag and (
        not only_remove_auto_tags
        or (
            eqn.params["meta"].name is not None
            and "Auto" in eqn.params["meta"].name
        )
    )

    if remove:
      for out_pos, in_pos in enumerate(eqn.params["meta"].outputs_index):
        # Equations are visited in program order, so resolving the source
        # here keeps the map pointing at a variable that still has a defining
        # equation even when the source is the output of an earlier removed
        # tag.
        var_map[eqn.outvars[out_pos]] = resolve(eqn.invars[in_pos])
    else:
      kept_eqns.append(eqn)

  eqns_new = remap_input_vars(kept_eqns)
  layer_tag_eqns_new = remap_input_vars(layer_tag_eqns)

  # The Jaxpr outputs may reference the output of a removed tag directly, so
  # they need the same remapping as the equation inputs.
  outvars_new = [resolve(var) for var in closed_jaxpr.jaxpr.outvars]

  closed_jaxpr = ClosedJaxpr(
      jaxpr=closed_jaxpr.jaxpr.replace(eqns=eqns_new, outvars=outvars_new),
      consts=closed_jaxpr.consts,
  )

  return (
      to_jaxpr_or_closed_jaxpr(closed_jaxpr, jaxpr),
      tuple(layer_tag_eqns_new),
  )
1119+
1120+
10581121
# Prototype for clean_jaxpr using JAX's dce_jaxpr. Doesn't work because
10591122
# dce_jaxpr will remove any equations with no used outputs, regardless of the
10601123
# dce_rule for that equation's primitive. Adding an "effect" to loss/layer

kfac_jax/_src/tracer.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,11 @@ def __init__(
225225
self.consts = consts
226226
self.in_tree = in_tree
227227
self.params_index = params_index
228-
self.layer_tags, self.loss_tags = extract_tags(jaxpr)
228+
# clean jaxpr of layer tags at this point
229+
closed_jaxpr = jex.core.ClosedJaxpr(jaxpr=self.jaxpr, consts=self.consts)
230+
self.jaxpr, self.layer_tags = tgm.clean_layer_tags_jaxpr(closed_jaxpr)
231+
self.jaxpr = self.jaxpr.jaxpr
232+
_, self.loss_tags = extract_tags(self.jaxpr)
229233
name_layer_tags(self.layer_tags)
230234
self.layer_tags, self.layer_indices = order_layer_tags(
231235
params_vars_flat=self.params_vars_flat,
@@ -579,7 +583,7 @@ def _loss_tags_vjp(
579583
p_jaxpr: ProcessedJaxpr,
580584
primal_func_args: FuncArgs,
581585
) -> LossTagsVjp:
582-
"""Computes a (backward-mode) vector-Jacobian product w.r.t. all loss tags.
586+
"""Computes a (backward-mode) vector-Jacobian product for the vector of losses given by the loss tags.
583587
584588
The function has similar interface to :func:`jax.vjp`. It takes as inputs the
585589
concrete values of the primals at which the Jacobian will be evaluated. It
@@ -651,7 +655,7 @@ def _loss_tags_jvp(
651655
primal_func_args: FuncArgs,
652656
params_tangents: Params,
653657
) -> LossTagsJvp:
654-
"""Computes a (forward-mode) Jacobian-vector product w.r.t. all loss tags.
658+
"""Computes a (forward-mode) Jacobian-vector product for the losses given by the loss tags.
655659
656660
The function has similar interface to :func:`jax.jvp`. It takes as inputs the
657661
concrete values of the primals at which the Jacobian will be evaluated at and
@@ -795,7 +799,7 @@ def _layer_tag_vjp(
795799
Args:
796800
processed_jaxpr: The :class:`~ProcessedJaxpr` representing the model
797801
function. This must include at least one loss tag.
798-
primal_func_args: The primals at which to evaluate the Hessian.
802+
primal_func_args: The primals at which to evaluate the Jacobian.
799803
800804
Returns:
801805
The computed ``losses`` and ``vjp_func`` pair.
@@ -824,14 +828,12 @@ def forward() -> tuple[Array, ...]:
824828
# Loop through equations and evaluate them
825829
num_losses_passed = 0
826830
for eqn in processed_jaxpr.jaxpr.eqns:
827-
828-
write(eqn.outvars, tgm.eval_jaxpr_eqn(eqn, read(eqn.invars)))
829-
830831
if isinstance(eqn.primitive, tags.LossTag):
831832
num_losses_passed += 1
832833
if num_losses_passed == len(processed_jaxpr.loss_tags):
833834
break
834-
835+
else:
836+
write(eqn.outvars, tgm.eval_jaxpr_eqn(eqn, read(eqn.invars)))
835837
assert num_losses_passed == len(processed_jaxpr.loss_tags)
836838

837839
return tuple(read(layer_input_vars))
@@ -884,7 +886,6 @@ def write(variables: list[jex.core.Var], values: list[Array]) -> None:
884886
for eqn in processed_jaxpr.jaxpr.eqns:
885887

886888
input_values = read(eqn.invars)
887-
write(eqn.outvars, tgm.eval_jaxpr_eqn(eqn, input_values))
888889

889890
if isinstance(eqn.primitive, tags.LossTag):
890891
loss: LossFunction = tags.loss_eqn_construct_loss(eqn, *input_values)
@@ -896,6 +897,8 @@ def write(variables: list[jex.core.Var], values: list[Array]) -> None:
896897

897898
if num_losses_passed == len(processed_jaxpr.loss_tags):
898899
break
900+
else:
901+
write(eqn.outvars, tgm.eval_jaxpr_eqn(eqn, input_values))
899902

900903
assert num_losses_passed == len(processed_jaxpr.loss_tags)
901904

0 commit comments

Comments
 (0)