Skip to content

Commit 6dc7352

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Include source info annotations in state primitive custom pretty printer rules.
The standard pp_eqn rule includes a pretty printer annotation containing the source info. Add this annotation to several custom pp_eqn rules. PiperOrigin-RevId: 861788817
1 parent 3a38a4b commit 6dc7352

File tree

2 files changed

+24
-10
lines changed

2 files changed

+24
-10
lines changed

jax/_src/pallas/primitives.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from jax._src import effects
3737
from jax._src import linear_util as lu
3838
from jax._src import pretty_printer as pp
39+
from jax._src import source_info_util
3940
from jax._src import state
4041
from jax._src import util
4142
from jax._src.interpreters import ad
@@ -404,8 +405,11 @@ def _load_pp_rule(eqn, context, settings):
404405
eqn.params["args_tree"], eqn.invars
405406
)
406407
# TODO(sharadmv): pretty print mask and other
408+
annotation = (source_info_util.summarize(eqn.source_info)
409+
if settings.source_info else None)
407410
lhs = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes)
408-
result = [lhs, pp.text(" <- "), sp.pp_ref_transforms(context, x, transforms)]
411+
result = [lhs, pp.text(" <- ", annotation=annotation),
412+
sp.pp_ref_transforms(context, x, transforms)]
409413
if mask is not None:
410414
result += [
411415
pp.text(" "),
@@ -564,16 +568,19 @@ def _swap_pp_rule(eqn, context, settings):
564568
y, = eqn.outvars
565569
x, transforms, val, mask = eqn.params["args_tree"].unflatten(eqn.invars)
566570
x_i = sp.pp_ref_transforms(context, x, transforms)
571+
annotation = (source_info_util.summarize(eqn.source_info)
572+
if settings.source_info else None)
567573
if isinstance(y, jax_core.DropVar):
568574
return pp.concat([
569575
x_i,
570-
pp.text(" <- "), pp.text(jax_core.pp_var(val, context))])
576+
pp.text(" <- ", annotation=annotation),
577+
pp.text(jax_core.pp_var(val, context))])
571578
y = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes)
572579
result = [
573580
y,
574581
pp.text(", "),
575582
x_i,
576-
pp.text(" <- "),
583+
pp.text(" <- ", annotation=annotation),
577584
x_i,
578585
pp.text(", "),
579586
pp.text(jax_core.pp_var(val, context)),

jax/_src/state/primitives.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from jax._src import dispatch
2626
from jax._src import dtypes
2727
from jax._src import pretty_printer as pp
28+
from jax._src import source_info_util
2829
from jax._src import traceback_util
2930
from jax._src import tree_util
3031
from jax._src.interpreters import ad
@@ -532,42 +533,48 @@ def _get_pp_rule(eqn, context, settings) -> pp.Doc:
532533
x, *flat_idx = eqn.invars
533534
transforms = tree_util.tree_unflatten(eqn.params["tree"], flat_idx)
534535
lhs = core.pp_vars([y], context, print_shapes=settings.print_shapes)
536+
annotation = (source_info_util.summarize(eqn.source_info)
537+
if settings.source_info else None)
535538
return pp.concat(
536-
[lhs, pp.text(" <- "), pp_ref_transforms(context, x, transforms)]
539+
[lhs, pp.text(" <- ", annotation=annotation),
540+
pp_ref_transforms(context, x, transforms)]
537541
)
538542
core.pp_eqn_rules[get_p] = _get_pp_rule
539543

540544
def _swap_pp_rule(eqn, context, settings) -> pp.Doc:
541545
y, = eqn.outvars
542546
x, v, *flat_idx = eqn.invars
543547
transforms = tree_util.tree_unflatten(eqn.params["tree"], flat_idx)
548+
annotation = (source_info_util.summarize(eqn.source_info)
549+
if settings.source_info else None)
544550
if type(y) is core.DropVar:
545551
# In the case of a set (ignored return value),
546552
# pretty print `_ = swap x v i` as `x[i] <- v`
547553
del y
548554
return pp.concat([
549555
pp_ref_transforms(context, x, transforms),
550-
pp.text(" <- "),
556+
pp.text(" <- ", annotation=annotation),
551557
pp.text(core.pp_var(v, context)),
552558
])
553559
else:
554560
# pretty-print `y:T = swap x v i` as `y:T, x[i] <- x[i], v`
555561
x_i = pp_ref_transforms(context, x, transforms)
556562
y = core.pp_vars([y], context, print_shapes=settings.print_shapes)
557-
return pp.concat([y, pp.text(', '), x_i, pp.text(' <- '),
558-
x_i, pp.text(', '),
559-
pp.text(core.pp_var(v, context))])
563+
return pp.concat([y, pp.text(', '), x_i,
564+
pp.text(' <- ', annotation=annotation), x_i,
565+
pp.text(', '), pp.text(core.pp_var(v, context))])
560566
core.pp_eqn_rules[swap_p] = _swap_pp_rule
561567

562568
def _addupdate_pp_rule(eqn, context, settings) -> pp.Doc:
563-
del settings
564569
# pretty-print ` = addupdate x i v` as `x[i] += v`
565570
() = eqn.outvars
566571
x, v, *flat_idx = eqn.invars
567572
transforms = tree_util.tree_unflatten(eqn.params["tree"], flat_idx)
573+
annotation = (source_info_util.summarize(eqn.source_info)
574+
if settings.source_info else None)
568575
return pp.concat([
569576
pp_ref_transforms(context, x, transforms),
570-
pp.text(" += "),
577+
pp.text(" += ", annotation=annotation),
571578
pp.text(core.pp_var(v, context)),
572579
])
573580
core.pp_eqn_rules[addupdate_p] = _addupdate_pp_rule

0 commit comments

Comments
 (0)