Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions kfac_jax/_src/curvature_blocks/tnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def update_curvature_matrix_estimate(
identity_weight: Numeric,
batch_size: Numeric,
) -> KroneckerFactored.State:

del identity_weight

# Copy this first since we mutate it later in this function.
Expand Down Expand Up @@ -119,6 +120,7 @@ def update_curvature_matrix_estimate(
identity_weight: Numeric,
batch_size: Numeric,
) -> KroneckerFactored.State:

del identity_weight

# Copy this first since we mutate it later in this function.
Expand Down Expand Up @@ -216,6 +218,7 @@ def update_curvature_matrix_estimate(
identity_weight: Numeric,
batch_size: Numeric,
) -> KroneckerFactored.State:

del identity_weight

# Copy this first since we mutate it later in this function.
Expand Down
2 changes: 1 addition & 1 deletion kfac_jax/_src/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,7 +887,7 @@ def _compute_preconditioned_gradient(
def _maybe_apply_norm_constraint(
self, grads: Params, preconditioned_grads: Params, coefficient: Array
) -> tuple[Params, Params | None]:
"""Scales precon grad to have F-weighted norm <= norm_constraint."""
"""Scales precon grad to have curvature-weighted norm <= norm_constraint."""

if self._norm_constraint is None:
return preconditioned_grads, None
Expand Down
2 changes: 1 addition & 1 deletion kfac_jax/_src/tag_graph_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ def add_vars_if_possible(
If at least one of the pattern variables is a parameter, but the
corresponding graph variable is not or vise-versa, the method does not
update the current variables map and returns ``False``. Similarly, if at
least one of the graph variables is a :class:`iteral` (meaning a
least one of the graph variables is a :class:`Literal` (meaning a
constant, independent of the function inputs) and the corresponding
pattern variable is not an input to the pattern, it returns ``False``. In
all other cases it updates the map and returns ``True``.
Expand Down
Loading