Skip to content

Commit 28dd0ac

Browse files
james-martensKfacJaxDev
authored andcommitted
Improve logging of parameter registrations.
Log each line of the pretty-printed parameter tree separately to avoid truncation in logs. PiperOrigin-RevId: 857140589
1 parent 81ef38d commit 28dd0ac

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

kfac_jax/_src/tag_graph_matcher.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1784,12 +1784,11 @@ def print_parameter_tags(self):
17841784
labels = ["|".join(self._param_labels.get(p, ["Orphan"]))
17851785
for p in self._func_graph.params_vars]
17861786
logging.info("=" * 50)
1787-
logging.info(
1788-
"Graph parameter registrations:\n%s",
1789-
pprint.pformat(jax.tree_util.tree_unflatten(
1790-
self._func_graph.params_tree, labels,
1791-
))
1792-
)
1787+
logging.info("Graph parameter registrations:")
1788+
for line in pprint.pformat(jax.tree_util.tree_unflatten(
1789+
self._func_graph.params_tree, labels,
1790+
)).split("\n"):
1791+
logging.info(line)
17931792
logging.info("=" * 50)
17941793

17951794
def check_multiple_registrations(self):

0 commit comments

Comments
 (0)